Support early stopping with training continuation, correct num boosted rounds. (#6506)

* Implement early stopping with training continuation.

* Add new C API for obtaining boosted rounds.

* Fix off by 1 in `save_best`.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2020-12-17 19:59:19 +08:00
committed by GitHub
parent 125b3c0f2d
commit ca3da55de4
16 changed files with 210 additions and 118 deletions

View File

@@ -124,6 +124,20 @@ class TestModels:
predt_2 = bst.predict(dtrain)
assert np.all(np.abs(predt_2 - predt_1) < 1e-6)
def test_boost_from_existing_model(self):
X = xgb.DMatrix(dpath + 'agaricus.txt.train')
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4)
assert booster.num_boosted_rounds() == 4
booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4,
xgb_model=booster)
assert booster.num_boosted_rounds() == 8
booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X,
num_boost_round=4, xgb_model=booster)
# Trees are moved for update, the rounds is reduced. This test is
# written for being compatible with current code (1.0.0). If the
# behaviour is considered sub-optimal, feel free to change.
assert booster.num_boosted_rounds() == 4
def test_custom_objective(self):
param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'}
watchlist = [(dtest, 'eval'), (dtrain, 'train')]

View File

@@ -43,7 +43,7 @@ class TestCallbacks:
# Should print info by each period additionaly to first and latest iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
def test_evaluation_monitor(self):
@@ -63,7 +63,7 @@ class TestCallbacks:
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
@@ -81,6 +81,15 @@ class TestCallbacks:
dump = booster.get_dump(dump_format='json')
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
# No early stopping, best_iteration should be set to last epoch
booster = xgb.train({'objective': 'binary:logistic',
'eval_metric': 'error'}, D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
num_boost_round=10,
evals_result=evals_result,
verbose_eval=True)
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
def test_early_stopping_custom_eval(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
@@ -153,7 +162,7 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
booster = cls.get_booster()
dump = booster.get_dump(dump_format='json')
assert len(dump) == booster.best_iteration
assert len(dump) == booster.best_iteration + 1
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
@@ -170,6 +179,32 @@ class TestCallbacks:
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
def test_early_stopping_continuation(self):
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
cls = xgb.XGBClassifier()
early_stopping_rounds = 5
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
save_best=True)
cls.fit(X, y, eval_set=[(X, y)],
eval_metric=tm.eval_error_metric,
callbacks=[early_stop])
booster = cls.get_booster()
assert booster.num_boosted_rounds() == booster.best_iteration + 1
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
assert cls._Booster is not None
early_stopping_rounds = 3
cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
early_stopping_rounds=early_stopping_rounds)
booster = cls.get_booster()
assert booster.num_boosted_rounds() == \
booster.best_iteration + early_stopping_rounds + 1
def run_eta_decay(self, tree_method, deprecated_callback):
if deprecated_callback:
scheduler = xgb.callback.reset_learning_rate

View File

@@ -46,7 +46,7 @@ class TestEarlyStopping:
@staticmethod
def assert_metrics_length(cv, expected_length):
for key, value in cv.items():
assert len(value) == expected_length
assert len(value) == expected_length
@pytest.mark.skipif(**tm.no_sklearn())
def test_cv_early_stopping(self):

View File

@@ -62,6 +62,8 @@ def test_multiclass_classification():
kf = KFold(n_splits=2, shuffle=True, random_state=rng)
for train_index, test_index in kf.split(X, y):
xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index])
assert (xgb_model.get_booster().num_boosted_rounds() ==
xgb_model.n_estimators)
preds = xgb_model.predict(X[test_index])
# test other params in XGBClassifier().fit
preds2 = xgb_model.predict(X[test_index], output_margin=True,