[PYTHON] Refactor trainnig API to use callback
This commit is contained in:
@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
result = {}
|
||||
res2 = {}
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
callbacks=[xgb.callback.record_evaluation(result)])
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
evals_result=res2)
|
||||
assert result['train']['error'][0] < 0.1
|
||||
assert res2 == result
|
||||
|
||||
def test_multiclass(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):
|
||||
|
||||
# return np.ndarray
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, np.ndarray)
|
||||
assert cv.shape == (10, 4)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
import unittest
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
import unittest
|
||||
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import random
|
||||
import xgboost as xgb
|
||||
import xgboost.testing as tm
|
||||
import testing as tm
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
22
tests/python/testing.py
Normal file
22
tests/python/testing.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# coding: utf-8
|
||||
|
||||
import nose
|
||||
|
||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||
|
||||
|
||||
def _skip_if_no_sklearn():
|
||||
if not SKLEARN_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_pandas():
|
||||
if not PANDAS_INSTALLED:
|
||||
raise nose.SkipTest()
|
||||
|
||||
|
||||
def _skip_if_no_matplotlib():
|
||||
try:
|
||||
import matplotlib.pyplot as _ # noqa
|
||||
except ImportError:
|
||||
raise nose.SkipTest()
|
||||
Reference in New Issue
Block a user