[PYTHON] Refactor trainnig API to use callback
This commit is contained in:
@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
|
||||
# assert they are the same
|
||||
assert np.sum(np.abs(preds2 - preds)) == 0
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
result = {}
|
||||
res2 = {}
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
callbacks=[xgb.callback.record_evaluation(result)])
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
evals_result=res2)
|
||||
assert result['train']['error'][0] < 0.1
|
||||
assert res2 == result
|
||||
|
||||
def test_multiclass(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):
|
||||
|
||||
# return np.ndarray
|
||||
cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
|
||||
assert isinstance(cv, np.ndarray)
|
||||
assert cv.shape == (10, 4)
|
||||
assert isinstance(cv, dict)
|
||||
assert len(cv) == (4)
|
||||
|
||||
Reference in New Issue
Block a user