Add tolerance to early stopping. (#6942)
This commit is contained in:
parent
894e9bc5d4
commit
d245bc891e
@ -487,25 +487,44 @@ class EarlyStopping(TrainingCallback):
|
|||||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||||
save_best
|
save_best
|
||||||
Whether training should return the best model or the last model.
|
Whether training should return the best model or the last model.
|
||||||
|
abs_tol
|
||||||
|
Absolute tolerance for early stopping condition.
|
||||||
|
|
||||||
|
.. versionadded:: 1.5.0
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
clf = xgboost.XGBClassifier(tree_method="gpu_hist")
|
||||||
|
es = xgboost.callback.EarlyStopping(
|
||||||
|
rounds=2,
|
||||||
|
abs_tol=1e-3,
|
||||||
|
save_best=True,
|
||||||
|
maximize=False,
|
||||||
|
data_name="validation_0",
|
||||||
|
metric_name="mlogloss",
|
||||||
|
)
|
||||||
|
|
||||||
|
X, y = load_digits(return_X_y=True)
|
||||||
|
clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
rounds: int,
|
rounds: int,
|
||||||
metric_name: Optional[str] = None,
|
metric_name: Optional[str] = None,
|
||||||
data_name: Optional[str] = None,
|
data_name: Optional[str] = None,
|
||||||
maximize: Optional[bool] = None,
|
maximize: Optional[bool] = None,
|
||||||
save_best: Optional[bool] = False) -> None:
|
save_best: Optional[bool] = False,
|
||||||
|
abs_tol: float = 0) -> None:
|
||||||
self.data = data_name
|
self.data = data_name
|
||||||
self.metric_name = metric_name
|
self.metric_name = metric_name
|
||||||
self.rounds = rounds
|
self.rounds = rounds
|
||||||
self.save_best = save_best
|
self.save_best = save_best
|
||||||
self.maximize = maximize
|
self.maximize = maximize
|
||||||
self.stopping_history: CallbackContainer.EvalsLog = {}
|
self.stopping_history: CallbackContainer.EvalsLog = {}
|
||||||
|
self._tol = abs_tol
|
||||||
|
if self._tol < 0:
|
||||||
|
raise ValueError("tolerance must be greater or equal to 0.")
|
||||||
|
|
||||||
if self.maximize is not None:
|
self.improve_op = None
|
||||||
if self.maximize:
|
|
||||||
self.improve_op = lambda x, y: x > y
|
|
||||||
else:
|
|
||||||
self.improve_op = lambda x, y: x < y
|
|
||||||
|
|
||||||
self.current_rounds: int = 0
|
self.current_rounds: int = 0
|
||||||
self.best_scores: dict = {}
|
self.best_scores: dict = {}
|
||||||
@ -517,18 +536,33 @@ class EarlyStopping(TrainingCallback):
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
|
def _update_rounds(self, score, name, metric, model, epoch) -> bool:
|
||||||
# Just to be compatibility with old behavior before 1.3. We should let
|
def get_s(x):
|
||||||
# user to decide.
|
"""get score if it's cross validation history."""
|
||||||
|
return x[0] if isinstance(x, tuple) else x
|
||||||
|
|
||||||
|
def maximize(new, best):
|
||||||
|
return numpy.greater(get_s(new) + self._tol, get_s(best))
|
||||||
|
|
||||||
|
def minimize(new, best):
|
||||||
|
return numpy.greater(get_s(best) + self._tol, get_s(new))
|
||||||
|
|
||||||
if self.maximize is None:
|
if self.maximize is None:
|
||||||
|
# Just to be compatibility with old behavior before 1.3. We should let
|
||||||
|
# user to decide.
|
||||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
|
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
|
||||||
'aucpr@', 'map@', 'ndcg@')
|
'aucpr@', 'map@', 'ndcg@')
|
||||||
if any(metric.startswith(x) for x in maximize_metrics):
|
if any(metric.startswith(x) for x in maximize_metrics):
|
||||||
self.improve_op = lambda x, y: x > y
|
|
||||||
self.maximize = True
|
self.maximize = True
|
||||||
else:
|
else:
|
||||||
self.improve_op = lambda x, y: x < y
|
|
||||||
self.maximize = False
|
self.maximize = False
|
||||||
|
|
||||||
|
if self.maximize:
|
||||||
|
self.improve_op = maximize
|
||||||
|
else:
|
||||||
|
self.improve_op = minimize
|
||||||
|
|
||||||
|
assert self.improve_op
|
||||||
|
|
||||||
if not self.stopping_history: # First round
|
if not self.stopping_history: # First round
|
||||||
self.current_rounds = 0
|
self.current_rounds = 0
|
||||||
self.stopping_history[name] = {}
|
self.stopping_history[name] = {}
|
||||||
|
|||||||
@ -126,6 +126,27 @@ class TestCallbacks:
|
|||||||
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
|
||||||
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
|
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)
|
||||||
|
|
||||||
|
# test tolerance, early stop won't occur with high tolerance.
|
||||||
|
tol = 10
|
||||||
|
rounds = 100
|
||||||
|
early_stop = xgb.callback.EarlyStopping(
|
||||||
|
rounds=early_stopping_rounds,
|
||||||
|
metric_name='CustomErr',
|
||||||
|
data_name='Train',
|
||||||
|
abs_tol=tol
|
||||||
|
)
|
||||||
|
booster = xgb.train(
|
||||||
|
{'objective': 'binary:logistic',
|
||||||
|
'eval_metric': ['error', 'rmse'],
|
||||||
|
'tree_method': 'hist'}, D_train,
|
||||||
|
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
|
||||||
|
feval=tm.eval_error_metric,
|
||||||
|
num_boost_round=rounds,
|
||||||
|
callbacks=[early_stop],
|
||||||
|
verbose_eval=False)
|
||||||
|
# 0 based index
|
||||||
|
assert booster.best_iteration == rounds - 1
|
||||||
|
|
||||||
def test_early_stopping_skl(self):
|
def test_early_stopping_skl(self):
|
||||||
from sklearn.datasets import load_breast_cancer
|
from sklearn.datasets import load_breast_cancer
|
||||||
X, y = load_breast_cancer(return_X_y=True)
|
X, y = load_breast_cancer(return_X_y=True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user