Support min_delta in early stopping. (#7137)

* Support `min_delta` in early stopping.

* Remove abs_tol.
This commit is contained in:
Jiaming Yuan 2021-08-03 14:29:17 +08:00 committed by GitHub
parent 7bdedacb54
commit e2c406f5c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 23 deletions

View File

@ -485,8 +485,8 @@ class EarlyStopping(TrainingCallback):
Whether to maximize evaluation metric. None means auto (discouraged). Whether to maximize evaluation metric. None means auto (discouraged).
save_best save_best
Whether training should return the best model or the last model. Whether training should return the best model or the last model.
abs_tol min_delta
Absolute tolerance for early stopping condition. Minimum absolute change in score to qualify as an improvement.
.. versionadded:: 1.5.0 .. versionadded:: 1.5.0
@ -505,22 +505,24 @@ class EarlyStopping(TrainingCallback):
X, y = load_digits(return_X_y=True) X, y = load_digits(return_X_y=True)
clf.fit(X, y, eval_set=[(X, y)], callbacks=[es]) clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
""" """
def __init__(self, def __init__(
rounds: int, self,
metric_name: Optional[str] = None, rounds: int,
data_name: Optional[str] = None, metric_name: Optional[str] = None,
maximize: Optional[bool] = None, data_name: Optional[str] = None,
save_best: Optional[bool] = False, maximize: Optional[bool] = None,
abs_tol: float = 0) -> None: save_best: Optional[bool] = False,
min_delta: float = 0.0
) -> None:
self.data = data_name self.data = data_name
self.metric_name = metric_name self.metric_name = metric_name
self.rounds = rounds self.rounds = rounds
self.save_best = save_best self.save_best = save_best
self.maximize = maximize self.maximize = maximize
self.stopping_history: CallbackContainer.EvalsLog = {} self.stopping_history: CallbackContainer.EvalsLog = {}
self._tol = abs_tol self._min_delta = min_delta
if self._tol < 0: if self._min_delta < 0:
raise ValueError("tolerance must be greater or equal to 0.") raise ValueError("min_delta must be greater or equal to 0.")
self.improve_op = None self.improve_op = None
@ -539,10 +541,12 @@ class EarlyStopping(TrainingCallback):
return x[0] if isinstance(x, tuple) else x return x[0] if isinstance(x, tuple) else x
def maximize(new, best): def maximize(new, best):
return numpy.greater(get_s(new) + self._tol, get_s(best)) """New score should be greater than the old one."""
return numpy.greater(get_s(new) - self._min_delta, get_s(best))
def minimize(new, best): def minimize(new, best):
return numpy.greater(get_s(best) + self._tol, get_s(new)) """New score should be smaller than the old one."""
return numpy.greater(get_s(best) - self._min_delta, get_s(new))
if self.maximize is None: if self.maximize is None:
# Just to be compatible with old behavior before 1.3. We should let # Just to be compatible with old behavior before 1.3. We should let

View File

@ -126,26 +126,30 @@ class TestCallbacks:
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump) assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
# test tolerance, early stop won't occur with high tolerance.
tol = 10
rounds = 100 rounds = 100
early_stop = xgb.callback.EarlyStopping( early_stop = xgb.callback.EarlyStopping(
rounds=early_stopping_rounds, rounds=early_stopping_rounds,
metric_name='CustomErr', metric_name='CustomErr',
data_name='Train', data_name='Train',
abs_tol=tol min_delta=100,
save_best=True,
) )
booster = xgb.train( booster = xgb.train(
{'objective': 'binary:logistic', {
'eval_metric': ['error', 'rmse'], 'objective': 'binary:logistic',
'tree_method': 'hist'}, D_train, 'eval_metric': ['error', 'rmse'],
'tree_method': 'hist'
},
D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')], evals=[(D_train, 'Train'), (D_valid, 'Valid')],
feval=tm.eval_error_metric, feval=tm.eval_error_metric,
num_boost_round=rounds, num_boost_round=rounds,
callbacks=[early_stop], callbacks=[early_stop],
verbose_eval=False) verbose_eval=False
# 0 based index )
assert booster.best_iteration == rounds - 1 # No iteration can be made with min_delta == 100
assert booster.best_iteration == 0
assert booster.num_boosted_rounds() == 1
def test_early_stopping_skl(self): def test_early_stopping_skl(self):
from sklearn.datasets import load_breast_cancer from sklearn.datasets import load_breast_cancer