Support early stopping with training continuation, correct num boosted rounds. (#6506)
* Implement early stopping with training continuation. * Add new C API for obtaining boosted rounds. * Fix off by 1 in `save_best`. Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -43,7 +43,7 @@ class TestCallbacks:
|
||||
# Should print info by each period additionaly to first and latest iteration
|
||||
num_periods = rounds // int(verbose_eval)
|
||||
# Extra information is required for latest iteration
|
||||
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
|
||||
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)
|
||||
|
||||
def test_evaluation_monitor(self):
|
||||
@@ -63,7 +63,7 @@ class TestCallbacks:
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
||||
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)
|
||||
|
||||
def test_early_stopping(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
@@ -81,6 +81,15 @@ class TestCallbacks:
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||
|
||||
# No early stopping, best_iteration should be set to last epoch
|
||||
booster = xgb.train({'objective': 'binary:logistic',
|
||||
'eval_metric': 'error'}, D_train,
|
||||
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||
num_boost_round=10,
|
||||
evals_result=evals_result,
|
||||
verbose_eval=True)
|
||||
assert booster.num_boosted_rounds() - 1 == booster.best_iteration
|
||||
|
||||
def test_early_stopping_custom_eval(self):
|
||||
D_train = xgb.DMatrix(self.X_train, self.y_train)
|
||||
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
|
||||
@@ -153,7 +162,7 @@ class TestCallbacks:
|
||||
eval_metric=tm.eval_error_metric, callbacks=[early_stop])
|
||||
booster = cls.get_booster()
|
||||
dump = booster.get_dump(dump_format='json')
|
||||
assert len(dump) == booster.best_iteration
|
||||
assert len(dump) == booster.best_iteration + 1
|
||||
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)
|
||||
@@ -170,6 +179,32 @@ class TestCallbacks:
|
||||
eval_metric=tm.eval_error_metric,
|
||||
callbacks=[early_stop])
|
||||
|
||||
def test_early_stopping_continuation(self):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
cls = xgb.XGBClassifier()
|
||||
early_stopping_rounds = 5
|
||||
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
|
||||
save_best=True)
|
||||
cls.fit(X, y, eval_set=[(X, y)],
|
||||
eval_metric=tm.eval_error_metric,
|
||||
callbacks=[early_stop])
|
||||
booster = cls.get_booster()
|
||||
assert booster.num_boosted_rounds() == booster.best_iteration + 1
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'model.json')
|
||||
cls.save_model(path)
|
||||
cls = xgb.XGBClassifier()
|
||||
cls.load_model(path)
|
||||
assert cls._Booster is not None
|
||||
early_stopping_rounds = 3
|
||||
cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric,
|
||||
early_stopping_rounds=early_stopping_rounds)
|
||||
booster = cls.get_booster()
|
||||
assert booster.num_boosted_rounds() == \
|
||||
booster.best_iteration + early_stopping_rounds + 1
|
||||
|
||||
def run_eta_decay(self, tree_method, deprecated_callback):
|
||||
if deprecated_callback:
|
||||
scheduler = xgb.callback.reset_learning_rate
|
||||
|
||||
Reference in New Issue
Block a user