Cleanup the callback module. (#8702)
- Clean up pylint markers.
- Run formatter.
- Update examples of using callback.
This commit is contained in:
parent 34eee56256
commit 9fb12b20a4
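
The user-visible change in this commit is the updated EarlyStopping example: the callback now goes to the estimator constructor rather than to fit(). A minimal runnable sketch of that new pattern, adapted from the docstring hunks below (tree_method="hist" and the min_delta value are substitutions for the docstring's values, not part of the commit):

    import xgboost
    from sklearn.datasets import load_digits

    # Configure early stopping once, then hand it to the constructor; the
    # updated example no longer passes callbacks through fit().
    es = xgboost.callback.EarlyStopping(
        rounds=2,
        min_delta=1e-3,
        save_best=True,
        data_name="validation_0",
        metric_name="mlogloss",
    )
    clf = xgboost.XGBClassifier(tree_method="hist", callbacks=[es])

    X, y = load_digits(return_X_y=True)
    clf.fit(X, y, eval_set=[(X, y)])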
@@ -1,6 +1,3 @@
-# coding: utf-8
-# pylint: disable=invalid-name, too-many-statements
-# pylint: disable=too-many-arguments
 """Callback library containing training routines. See :doc:`Callback Functions
 </python/callbacks>` for a quick introduction.
 
@@ -34,7 +31,7 @@ __all__ = [
     "EarlyStopping",
     "EvaluationMonitor",
     "TrainingCheckPoint",
-    "CallbackContainer"
+    "CallbackContainer",
 ]
 
 _Score = Union[float, Tuple[float, float]]
@@ -45,39 +42,37 @@ _Model = Any  # real type is Union[Booster, CVPack]; need more work
 
+# pylint: disable=unused-argument
 class TrainingCallback(ABC):
-    '''Interface for training callback.
+    """Interface for training callback.
 
     .. versionadded:: 1.3.0
 
-    '''
+    """
 
-    EvalsLog = Dict[str, Dict[str, _ScoreList]]
+    EvalsLog = Dict[str, Dict[str, _ScoreList]]  # pylint: disable=invalid-name
 
     def __init__(self) -> None:
        pass
 
     def before_training(self, model: _Model) -> _Model:
-        '''Run before training starts.'''
+        """Run before training starts."""
         return model
 
     def after_training(self, model: _Model) -> _Model:
-        '''Run after training is finished.'''
+        """Run after training is finished."""
         return model
 
     def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
-        '''Run before each iteration. Return True when training should stop.'''
+        """Run before each iteration. Return True when training should stop."""
         return False
 
     def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
-        '''Run after each iteration. Return True when training should stop.'''
+        """Run after each iteration. Return True when training should stop."""
         return False
 
 
 def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
-    # pylint: disable=invalid-name, too-many-locals
-    """Aggregate cross-validation results.
-
-    """
+    """Aggregate cross-validation results."""
     cvmap: Dict[Tuple[int, str], List[float]] = {}
     idx = rlist[0].split()[0]
     for line in rlist:
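
The four hooks above are the whole contract a user-defined callback implements; a subclass overrides only what it needs. A hedged sketch (the class and its printing behavior are invented for illustration, not part of the commit):

    import xgboost

    class MetricPrinter(xgboost.callback.TrainingCallback):
        """Illustrative callback: print the newest score of every metric."""

        def after_iteration(self, model, epoch, evals_log) -> bool:
            # evals_log maps data name -> metric name -> list of scores.
            for data, metrics in evals_log.items():
                for metric_name, history in metrics.items():
                    print(f"[{epoch}] {data}-{metric_name}: {history[-1]}")
            return False  # returning True would stop training early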
@@ -86,7 +81,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
         for metric_idx, it in enumerate(arr[1:]):
             if not isinstance(it, str):
                 it = it.decode()
-            k, v = it.split(':')
+            k, v = it.split(":")
             if (metric_idx, k) not in cvmap:
                 cvmap[(metric_idx, k)] = []
             cvmap[(metric_idx, k)].append(float(v))
@@ -106,44 +101,45 @@ _ART = TypeVar("_ART")
 
 
 def _allreduce_metric(score: _ART) -> _ART:
-    '''Helper function for computing customized metric in distributed
+    """Helper function for computing customized metric in distributed
     environment. Not strictly correct as many functions don't use mean value
     as final result.
 
-    '''
+    """
     world = collective.get_world_size()
     assert world != 0
     if world == 1:
         return score
     if isinstance(score, tuple):  # has mean and stdv
         raise ValueError(
-            'xgboost.cv function should not be used in distributed environment.')
+            "xgboost.cv function should not be used in distributed environment."
+        )
     arr = numpy.array([score])
     arr = collective.allreduce(arr, collective.Op.SUM) / world
     return arr[0]
 
 
 class CallbackContainer:
-    '''A special internal callback for invoking a list of other callbacks.
+    """A special internal callback for invoking a list of other callbacks.
 
     .. versionadded:: 1.3.0
 
-    '''
-
-    EvalsLog = TrainingCallback.EvalsLog
+    """
 
     def __init__(
         self,
         callbacks: Sequence[TrainingCallback],
         metric: Optional[Callable] = None,
         output_margin: bool = True,
-        is_cv: bool = False
+        is_cv: bool = False,
     ) -> None:
         self.callbacks = set(callbacks)
         if metric is not None:
-            msg = 'metric must be callable object for monitoring. For ' + \
-                'builtin metrics, passing them in training parameter' + \
-                ' will invoke monitor automatically.'
+            msg = (
+                "metric must be callable object for monitoring. For "
+                + "builtin metrics, passing them in training parameter"
+                + " will invoke monitor automatically."
+            )
             assert callable(metric), msg
         self.metric = metric
         self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
@@ -154,10 +150,10 @@ class CallbackContainer:
         self.aggregated_cv = None
 
     def before_training(self, model: _Model) -> _Model:
-        '''Function called before training.'''
+        """Function called before training."""
         for c in self.callbacks:
             model = c.before_training(model=model)
-            msg = 'before_training should return the model'
+            msg = "before_training should return the model"
             if self.is_cv:
                 assert isinstance(model.cvfolds, list), msg
             else:
@@ -165,10 +161,10 @@ class CallbackContainer:
         return model
 
     def after_training(self, model: _Model) -> _Model:
-        '''Function called after training.'''
+        """Function called after training."""
         for c in self.callbacks:
             model = c.after_training(model=model)
-            msg = 'after_training should return the model'
+            msg = "after_training should return the model"
             if self.is_cv:
                 assert isinstance(model.cvfolds, list), msg
             else:
@@ -176,9 +172,9 @@ class CallbackContainer:
 
         if not self.is_cv:
             num_parallel_tree, _ = _get_booster_layer_trees(model)
-            if model.attr('best_score') is not None:
-                model.best_score = float(cast(str, model.attr('best_score')))
-                model.best_iteration = int(cast(str, model.attr('best_iteration')))
+            if model.attr("best_score") is not None:
+                model.best_score = float(cast(str, model.attr("best_score")))
+                model.best_iteration = int(cast(str, model.attr("best_iteration")))
                 # num_class is handled internally
                 model.set_attr(
                     best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
@@ -195,16 +191,21 @@ class CallbackContainer:
         return model
 
     def before_iteration(
-        self, model: _Model, epoch: int, dtrain: DMatrix, evals: Optional[List[Tuple[DMatrix, str]]]
+        self,
+        model: _Model,
+        epoch: int,
+        dtrain: DMatrix,
+        evals: Optional[List[Tuple[DMatrix, str]]],
     ) -> bool:
-        '''Function called before training iteration.'''
-        return any(c.before_iteration(model, epoch, self.history)
-                   for c in self.callbacks)
+        """Function called before training iteration."""
+        return any(
+            c.before_iteration(model, epoch, self.history) for c in self.callbacks
+        )
 
     def _update_history(
         self,
         score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
-        epoch: int
+        epoch: int,
     ) -> None:
         for d in score:
             name: str = d[0]
@@ -214,9 +215,9 @@ class CallbackContainer:
                 x: _Score = (s, std)
             else:
                 x = s
-            splited_names = name.split('-')
+            splited_names = name.split("-")
             data_name = splited_names[0]
-            metric_name = '-'.join(splited_names[1:])
+            metric_name = "-".join(splited_names[1:])
             x = _allreduce_metric(x)
             if data_name not in self.history:
                 self.history[data_name] = collections.OrderedDict()
@@ -238,7 +239,7 @@ class CallbackContainer:
         dtrain: DMatrix,
         evals: Optional[List[Tuple[DMatrix, str]]],
     ) -> bool:
-        '''Function called after training iteration.'''
+        """Function called after training iteration."""
         if self.is_cv:
             scores = model.eval(epoch, self.metric, self._output_margin)
             scores = _aggcv(scores)
@@ -247,16 +248,15 @@ class CallbackContainer:
         else:
             evals = [] if evals is None else evals
             for _, name in evals:
-                assert name.find('-') == -1, 'Dataset name should not contain `-`'
+                assert name.find("-") == -1, "Dataset name should not contain `-`"
             score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
             splited = score.split()[1:]  # into datasets
             # split up `test-error:0.1234`
-            metric_score_str = [tuple(s.split(':')) for s in splited]
+            metric_score_str = [tuple(s.split(":")) for s in splited]
             # convert to float
             metric_score = [(n, float(s)) for n, s in metric_score_str]
             self._update_history(metric_score, epoch)
-        ret = any(c.after_iteration(model, epoch, self.history)
-                  for c in self.callbacks)
+        ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
         return ret
 
 
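
The assert above pairs with the parsing right below it: evaluation lines such as `test-error:0.1234` are decomposed by splitting on `-` and `:`, so a dataset name containing `-` would corrupt the history keys. A small sketch of compliant naming (synthetic data, illustrative values):

    import numpy as np
    import xgboost

    rng = np.random.default_rng(0)
    dtrain = xgboost.DMatrix(rng.random((100, 4)), label=rng.random(100))
    dvalid = xgboost.DMatrix(rng.random((50, 4)), label=rng.random(50))

    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=10,
        # Name it "valid", not "valid-set": dataset names must not contain `-`.
        evals=[(dtrain, "train"), (dvalid, "valid")],
    )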
@@ -320,7 +320,6 @@ class EarlyStopping(TrainingCallback):
 
     .. code-block:: python
 
-        clf = xgboost.XGBClassifier(tree_method="gpu_hist")
         es = xgboost.callback.EarlyStopping(
             rounds=2,
             abs_tol=1e-3,
@@ -329,10 +328,13 @@ class EarlyStopping(TrainingCallback):
             data_name="validation_0",
             metric_name="mlogloss",
         )
+        clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
 
         X, y = load_digits(return_X_y=True)
-        clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
+        clf.fit(X, y, eval_set=[(X, y)])
     """
+
+    # pylint: disable=too-many-arguments
     def __init__(
         self,
         rounds: int,
@@ -340,7 +342,7 @@ class EarlyStopping(TrainingCallback):
         data_name: Optional[str] = None,
         maximize: Optional[bool] = None,
         save_best: Optional[bool] = False,
-        min_delta: float = 0.0
+        min_delta: float = 0.0,
     ) -> None:
         self.data = data_name
         self.metric_name = metric_name
@@ -364,9 +366,9 @@ class EarlyStopping(TrainingCallback):
     def _update_rounds(
         self, score: _Score, name: str, metric: str, model: _Model, epoch: int
     ) -> bool:
-        def get_s(x: _Score) -> float:
+        def get_s(value: _Score) -> float:
             """get score if it's cross validation history."""
-            return x[0] if isinstance(x, tuple) else x
+            return value[0] if isinstance(value, tuple) else value
 
         def maximize(new: _Score, best: _Score) -> bool:
             """New score should be greater than the old one."""
@@ -379,9 +381,17 @@ class EarlyStopping(TrainingCallback):
         if self.maximize is None:
             # Just to be compatibility with old behavior before 1.3. We should let
             # user to decide.
-            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
-                                'aucpr@', 'map@', 'ndcg@')
-            if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
+            maximize_metrics = (
+                "auc",
+                "aucpr",
+                "map",
+                "ndcg",
+                "auc@",
+                "aucpr@",
+                "map@",
+                "ndcg@",
+            )
+            if metric != "mape" and any(metric.startswith(x) for x in maximize_metrics):
                 self.maximize = True
             else:
                 self.maximize = False
@@ -414,18 +424,19 @@ class EarlyStopping(TrainingCallback):
             return True
         return False
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         epoch += self.starting_round  # training continuation
-        msg = 'Must have at least 1 validation dataset for early stopping.'
+        msg = "Must have at least 1 validation dataset for early stopping."
         assert len(evals_log.keys()) >= 1, msg
-        data_name = ''
+        data_name = ""
         if self.data:
             for d, _ in evals_log.items():
                 if d == self.data:
                     data_name = d
             if not data_name:
-                raise ValueError('No dataset named:', self.data)
+                raise ValueError("No dataset named:", self.data)
         else:
             # Use the last one as default.
             data_name = list(evals_log.keys())[-1]
@@ -454,7 +465,7 @@ class EarlyStopping(TrainingCallback):
 
 
 class EvaluationMonitor(TrainingCallback):
-    '''Print the evaluation result at each iteration.
+    """Print the evaluation result at each iteration.
 
     .. versionadded:: 1.3.0
 
@@ -469,7 +480,8 @@ class EvaluationMonitor(TrainingCallback):
         How many epoches between printing.
     show_stdv :
         Used in cv to show standard deviation. Users should not specify it.
-    '''
+    """
+
     def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
         self.printer_rank = rank
         self.show_stdv = show_stdv
@@ -488,12 +500,13 @@ class EvaluationMonitor(TrainingCallback):
             msg = f"\t{data + '-' + metric}:{score:.5f}"
         return msg
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         if not evals_log:
             return False
 
-        msg: str = f'[{epoch}]'
+        msg: str = f"[{epoch}]"
         if collective.get_rank() == self.printer_rank:
             for data, metric in evals_log.items():
                 for metric_name, log in metric.items():
@@ -504,7 +517,7 @@ class EvaluationMonitor(TrainingCallback):
                     else:
                         score = log[-1]
                     msg += self._fmt_metric(data, metric_name, score, stdv)
-            msg += '\n'
+            msg += "\n"
 
             if (epoch % self.period) == 0 or self.period == 1:
                 collective.communicator_print(msg)
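
EvaluationMonitor, whose printing logic is reflowed above, prints one line per `period` rounds from the designated rank. A usage sketch with the native API (synthetic data, illustrative values):

    import numpy as np
    import xgboost

    rng = np.random.default_rng(0)
    dtrain = xgboost.DMatrix(rng.random((100, 4)), label=rng.random(100))

    # Print the evaluation line every 5 rounds instead of every round.
    monitor = xgboost.callback.EvaluationMonitor(rank=0, period=5)
    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=20,
        evals=[(dtrain, "train")],
        callbacks=[monitor],
        verbose_eval=False,  # suppress the default per-round monitor
    )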
@@ -521,7 +534,7 @@ class EvaluationMonitor(TrainingCallback):
 
 
 class TrainingCheckPoint(TrainingCallback):
-    '''Checkpointing operation.
+    """Checkpointing operation.
 
     .. versionadded:: 1.3.0
 
@@ -540,13 +553,14 @@ class TrainingCheckPoint(TrainingCallback):
         Interval of checkpointing. Checkpointing is slow so setting a larger number can
         reduce performance hit.
 
-    '''
+    """
+
     def __init__(
         self,
         directory: Union[str, os.PathLike],
-        name: str = 'model',
+        name: str = "model",
         as_pickle: bool = False,
-        iterations: int = 100
+        iterations: int = 100,
     ) -> None:
         self._path = os.fspath(directory)
         self._name = name
@@ -555,15 +569,21 @@ class TrainingCheckPoint(TrainingCallback):
         self._epoch = 0
         super().__init__()
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         if self._epoch == self._iterations:
-            path = os.path.join(self._path, self._name + '_' + str(epoch) +
-                                ('.pkl' if self._as_pickle else '.json'))
+            path = os.path.join(
+                self._path,
+                self._name
+                + "_"
+                + str(epoch)
+                + (".pkl" if self._as_pickle else ".json"),
+            )
             self._epoch = 0
             if collective.get_rank() == 0:
                 if self._as_pickle:
-                    with open(path, 'wb') as fd:
+                    with open(path, "wb") as fd:
                         pickle.dump(model, fd)
                 else:
                     model.save_model(path)
 
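
That ends the callback module hunks; the remaining hunks reflow related code elsewhere in the package. As reshaped above, TrainingCheckPoint writes `<name>_<epoch>.json` (or `.pkl` with as_pickle=True) into the given directory every `iterations` rounds. A sketch of plugging it in (directory and interval are illustrative; make sure the directory exists):

    import os
    import numpy as np
    import xgboost

    rng = np.random.default_rng(0)
    dtrain = xgboost.DMatrix(rng.random((100, 4)), label=rng.random(100))

    os.makedirs("checkpoints", exist_ok=True)
    ckpt = xgboost.callback.TrainingCheckPoint(
        directory="checkpoints", name="model", iterations=10
    )
    xgboost.train(
        {"objective": "reg:squarederror"}, dtrain, num_boost_round=50, callbacks=[ckpt]
    )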
@@ -288,9 +288,9 @@ __model_doc = f"""
 
     .. note::
 
-        This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
-        receives un-transformed prediction regardless of whether custom objective is
-        being used.
+        This parameter replaces `eval_metric` in :py:meth:`fit` method. The old
+        one receives un-transformed prediction regardless of whether custom
+        objective is being used.
 
     .. code-block:: python
 
@@ -340,7 +340,8 @@ __model_doc = f"""
         for params in parameters_grid:
             # be sure to (re)initialize the callbacks before each run
             callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
-            xgboost.train(params, Xy, callbacks=callbacks)
+            reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
+            reg.fit(X, y)
 
     kwargs : dict, optional
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
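
The docstring example above uses `custom_rates` without showing its definition. One plausible definition, since LearningRateScheduler accepts either a sequence of rates or a callable mapping the round index to a rate (the decay schedule here is an assumption for illustration):

    import xgboost as xgb

    def custom_rates(epoch: int) -> float:
        # Hypothetical schedule: start at 0.3 and halve every 10 rounds.
        return 0.3 * 0.5 ** (epoch // 10)

    callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]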
@@ -28,9 +28,7 @@ from .core import (
 _CVFolds = Sequence["CVPack"]
 
 
-def _assert_new_callback(
-    callbacks: Optional[Sequence[TrainingCallback]]
-) -> None:
+def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
     is_new_callback: bool = not callbacks or all(
         isinstance(c, TrainingCallback) for c in callbacks
     )
@@ -45,7 +43,9 @@ def _configure_custom_metric(
     feval: Optional[Metric], custom_metric: Optional[Metric]
 ) -> Optional[Metric]:
     if feval is not None:
-        link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
+        link = (
+            "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
+        )
         warnings.warn(
             "`feval` is deprecated, use `custom_metric` instead. They have "
             "different behavior when custom objective is also used."
@@ -175,9 +175,7 @@ def train(
         verbose_eval = 1 if verbose_eval is True else verbose_eval
         callbacks.append(EvaluationMonitor(period=verbose_eval))
     if early_stopping_rounds:
-        callbacks.append(
-            EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
-        )
+        callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
     cb_container = CallbackContainer(
         callbacks,
         metric=metric_fn,
@@ -536,13 +534,9 @@ def cv(
 
     if verbose_eval:
         verbose_eval = 1 if verbose_eval is True else verbose_eval
-        callbacks.append(
-            EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
-        )
+        callbacks.append(EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv))
     if early_stopping_rounds:
-        callbacks.append(
-            EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
-        )
+        callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
     callbacks_container = CallbackContainer(
         callbacks,
         metric=metric_fn,
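
The two branches reflowed above are where `cv` turns its `verbose_eval` and `early_stopping_rounds` convenience arguments into an EvaluationMonitor and an EarlyStopping callback. A sketch that exercises that path (synthetic data, illustrative values):

    import numpy as np
    import xgboost

    rng = np.random.default_rng(0)
    dtrain = xgboost.DMatrix(rng.random((200, 4)), label=rng.random(200))

    # verbose_eval=10 appends an EvaluationMonitor with period=10;
    # early_stopping_rounds=5 appends an EarlyStopping with rounds=5.
    history = xgboost.cv(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=100,
        nfold=5,
        verbose_eval=10,
        early_stopping_rounds=5,
    )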
@@ -134,6 +134,7 @@ def main(args: argparse.Namespace) -> None:
             # core
             "python-package/xgboost/__init__.py",
             "python-package/xgboost/_typing.py",
+            "python-package/xgboost/callback.py",
             "python-package/xgboost/compat.py",
             "python-package/xgboost/config.py",
             "python-package/xgboost/dask.py",