Rework Python callback functions. (#6199)
* Define a new callback interface for Python.
* Deprecate the old callbacks.
* Enable early stopping on dask.
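The reworked interface introduced by this commit is built around a `TrainingCallback` base class in `xgb.callback`: user callbacks subclass it and override hooks such as `before_training`, `after_training`, `before_iteration`, and `after_iteration`. A minimal sketch of a custom callback against the new API (the class name `PrintEvalError` and the printed message are illustrative, not part of the commit):

import xgboost as xgb

class PrintEvalError(xgb.callback.TrainingCallback):
    """Print the latest 'eval' error after every boosting round."""

    def after_iteration(self, model, epoch, evals_log):
        # evals_log is a nested dict: {dataset name: {metric name: [history]}}.
        error = evals_log['eval']['error'][-1]
        print(f'[{epoch}] eval-error: {error}')
        # Returning True asks the training loop to stop early;
        # False continues to the next round.
        return False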
@@ -8,7 +8,7 @@ import pytest
 import locale
 import tempfile

-dpath = 'demo/data/'
+dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
 dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
 dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')

@@ -110,84 +110,6 @@ class TestModels(unittest.TestCase):
         for jj in range(ii + 1, len(preds_list)):
             assert np.sum(np.abs(preds_list[ii] - preds_list[jj])) > 0

-    def run_eta_decay(self, tree_method):
-        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
-        num_round = 4
-
-        # learning_rates as a list
-        # init eta with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
-                 'objective': 'binary:logistic', 'eval_metric': 'error',
-                 'tree_method': tree_method}
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate([
-                            0.8, 0.7, 0.6, 0.5
-                        ])],
-                        evals_result=evals_result)
-        eval_errors_0 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should decrease, if eta > 0
-        assert eval_errors_0[0] > eval_errors_0[-1]
-
-        # init learning_rate with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
-                 'objective': 'binary:logistic', 'eval_metric': 'error',
-                 'tree_method': tree_method}
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate(
-                            [0.8, 0.7, 0.6, 0.5])],
-                        evals_result=evals_result)
-        eval_errors_1 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should decrease, if learning_rate > 0
-        assert eval_errors_1[0] > eval_errors_1[-1]
-
-        # check if learning_rates override default value of eta/learning_rate
-        param = {
-            'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic',
-            'eval_metric': 'error', 'tree_method': tree_method
-        }
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate(
-                            [0, 0, 0, 0]
-                        )],
-                        evals_result=evals_result)
-        eval_errors_2 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should not decrease, if eta/learning_rate = 0
-        assert eval_errors_2[0] == eval_errors_2[-1]
-
-        # learning_rates as a customized decay function
-        def eta_decay(ithround, num_boost_round):
-            return num_boost_round / (ithround + 1)
-
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[
-                            xgb.callback.reset_learning_rate(eta_decay)
-                        ],
-                        evals_result=evals_result)
-        eval_errors_3 = list(map(float, evals_result['eval']['error']))
-
-        assert isinstance(bst, xgb.core.Booster)
-
-        assert eval_errors_3[0] == eval_errors_2[0]
-
-        for i in range(1, len(eval_errors_0)):
-            assert eval_errors_3[i] != eval_errors_2[i]
-
-    def test_eta_decay_hist(self):
-        self.run_eta_decay('hist')
-
-    def test_eta_decay_approx(self):
-        self.run_eta_decay('approx')
-
-    def test_eta_decay_exact(self):
-        self.run_eta_decay('exact')
-
     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
         margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
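The eta-decay tests removed above exercised the old `xgb.callback.reset_learning_rate`, which this commit deprecates. Its counterpart in the reworked interface is `xgb.callback.LearningRateScheduler`, which accepts either a list of per-round rates or a callable mapping the epoch to a rate. A minimal sketch of the same schedule the removed test used (reusing `dtrain` and `dtest` as defined in the test file above):

import xgboost as xgb

param = {'max_depth': 2, 'objective': 'binary:logistic',
         'eval_metric': 'error'}

# One learning rate per boosting round, as in the removed test.
scheduler = xgb.callback.LearningRateScheduler([0.8, 0.7, 0.6, 0.5])

evals_result = {}
bst = xgb.train(param, dtrain, num_boost_round=4,
                evals=[(dtest, 'eval'), (dtrain, 'train')],
                callbacks=[scheduler],
                evals_result=evals_result)

# A callable taking the epoch works as well, e.g. exponential decay:
decaying = xgb.callback.LearningRateScheduler(lambda epoch: 0.8 * 0.9 ** epoch)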
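For the last item in the commit message, early stopping on dask, the reworked `EarlyStopping` callback can be passed to `xgb.dask.train` through its `callbacks` argument. A sketch under stated assumptions: an existing `dask.distributed` client, dask collections `X` and `y`, and illustrative parameters and round counts.

from dask.distributed import Client

import xgboost as xgb


def train_with_early_stop(client: Client, X, y):
    # X and y are assumed to be dask arrays or dataframes.
    dtrain = xgb.dask.DaskDMatrix(client, X, y)
    output = xgb.dask.train(
        client,
        {'objective': 'binary:logistic', 'eval_metric': 'error'},
        dtrain,
        num_boost_round=100,
        evals=[(dtrain, 'train')],
        # Stop when the metric has not improved for 5 consecutive rounds.
        callbacks=[xgb.callback.EarlyStopping(rounds=5)],
    )
    # xgb.dask.train returns a dict holding the booster and the eval history.
    return output['booster'], output['history']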