Remove old callback deprecated in 1.3. (#7280)
This commit is contained in:
@@ -76,23 +76,6 @@ class TestBasic:
|
||||
predt_1 = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
'objective': 'binary:logistic', 'eval_metric': 'error'}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
result = {}
|
||||
res2 = {}
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
callbacks=[xgb.callback.record_evaluation(result)])
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
evals_result=res2)
|
||||
assert result['train']['error'][0] < 0.1
|
||||
assert res2 == result
|
||||
|
||||
def test_multiclass(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
@@ -254,8 +237,18 @@ class TestBasic:
|
||||
]
|
||||
|
||||
# Use callback to log the test labels in each fold
|
||||
def cb(cbackenv):
|
||||
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
|
||||
class Callback(xgb.callback.TrainingCallback):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(
|
||||
self, model,
|
||||
epoch: int,
|
||||
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
):
|
||||
print([fold.dtest.get_label() for fold in model.cvfolds])
|
||||
|
||||
cb = Callback()
|
||||
|
||||
# Run cross validation and capture standard out to test callback result
|
||||
with tm.captured_output() as (out, err):
|
||||
|
||||
Reference in New Issue
Block a user