Cleanup the callback module. (#8702)

- Clean up pylint markers.
- Run the formatter.
- Update the examples of using callbacks.
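For reference, the updated pattern moves callback construction ahead of the estimator and passes the callbacks to the constructor rather than to fit(). A minimal sketch, assuming the scikit-learn interface and sklearn.datasets.load_digits; tree_method="hist" stands in for the GPU variant used in the docstring, and min_delta follows the __init__ signature shown in the diff (the docstring's abs_tol does not appear there):

    from sklearn.datasets import load_digits

    import xgboost

    # Configure early stopping up front.
    es = xgboost.callback.EarlyStopping(
        rounds=2,
        min_delta=1e-3,
        save_best=True,
        data_name="validation_0",
        metric_name="mlogloss",
    )

    # Callbacks now go to the constructor, not to fit().
    clf = xgboost.XGBClassifier(tree_method="hist", callbacks=[es])

    X, y = load_digits(return_X_y=True)
    clf.fit(X, y, eval_set=[(X, y)])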
Author: Jiaming Yuan, 2023-01-22 00:13:49 +08:00 (committed by GitHub)
parent 34eee56256
commit 9fb12b20a4
4 changed files with 106 additions and 90 deletions

python-package/xgboost/callback.py

@@ -1,6 +1,3 @@
-# coding: utf-8
-# pylint: disable=invalid-name, too-many-statements
-# pylint: disable=too-many-arguments
 """Callback library containing training routines. See :doc:`Callback Functions
 </python/callbacks>` for a quick introduction.
@@ -34,7 +31,7 @@ __all__ = [
     "EarlyStopping",
     "EvaluationMonitor",
     "TrainingCheckPoint",
-    "CallbackContainer"
+    "CallbackContainer",
 ]
 
 _Score = Union[float, Tuple[float, float]]
@@ -45,39 +42,37 @@ _Model = Any  # real type is Union[Booster, CVPack]; need more work
 # pylint: disable=unused-argument
 class TrainingCallback(ABC):
-    '''Interface for training callback.
+    """Interface for training callback.
 
     .. versionadded:: 1.3.0
-    '''
-    EvalsLog = Dict[str, Dict[str, _ScoreList]]
+    """
+
+    EvalsLog = Dict[str, Dict[str, _ScoreList]]  # pylint: disable=invalid-name
 
     def __init__(self) -> None:
         pass
 
     def before_training(self, model: _Model) -> _Model:
-        '''Run before training starts.'''
+        """Run before training starts."""
         return model
 
     def after_training(self, model: _Model) -> _Model:
-        '''Run after training is finished.'''
+        """Run after training is finished."""
         return model
 
     def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
-        '''Run before each iteration. Return True when training should stop.'''
+        """Run before each iteration. Return True when training should stop."""
         return False
 
     def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
-        '''Run after each iteration. Return True when training should stop.'''
+        """Run after each iteration. Return True when training should stop."""
         return False
 
 
 def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
     # pylint: disable=invalid-name, too-many-locals
-    """Aggregate cross-validation results.
-    """
+    """Aggregate cross-validation results."""
     cvmap: Dict[Tuple[int, str], List[float]] = {}
     idx = rlist[0].split()[0]
     for line in rlist:
@@ -86,7 +81,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
         for metric_idx, it in enumerate(arr[1:]):
             if not isinstance(it, str):
                 it = it.decode()
-            k, v = it.split(':')
+            k, v = it.split(":")
             if (metric_idx, k) not in cvmap:
                 cvmap[(metric_idx, k)] = []
             cvmap[(metric_idx, k)].append(float(v))
@@ -106,44 +101,45 @@ _ART = TypeVar("_ART")
 def _allreduce_metric(score: _ART) -> _ART:
-    '''Helper function for computing customized metric in distributed
+    """Helper function for computing customized metric in distributed
     environment. Not strictly correct as many functions don't use mean value
     as final result.
-    '''
+    """
     world = collective.get_world_size()
     assert world != 0
     if world == 1:
         return score
     if isinstance(score, tuple):  # has mean and stdv
         raise ValueError(
-            'xgboost.cv function should not be used in distributed environment.')
+            "xgboost.cv function should not be used in distributed environment."
+        )
     arr = numpy.array([score])
     arr = collective.allreduce(arr, collective.Op.SUM) / world
     return arr[0]
 
 
 class CallbackContainer:
-    '''A special internal callback for invoking a list of other callbacks.
+    """A special internal callback for invoking a list of other callbacks.
 
     .. versionadded:: 1.3.0
-    '''
+    """
+
+    EvalsLog = TrainingCallback.EvalsLog
 
     def __init__(
         self,
         callbacks: Sequence[TrainingCallback],
         metric: Optional[Callable] = None,
         output_margin: bool = True,
-        is_cv: bool = False
+        is_cv: bool = False,
     ) -> None:
         self.callbacks = set(callbacks)
         if metric is not None:
-            msg = 'metric must be callable object for monitoring. For ' + \
-                'builtin metrics, passing them in training parameter' + \
-                ' will invoke monitor automatically.'
+            msg = (
+                "metric must be callable object for monitoring. For "
+                + "builtin metrics, passing them in training parameter"
+                + " will invoke monitor automatically."
+            )
             assert callable(metric), msg
         self.metric = metric
         self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
@@ -154,10 +150,10 @@ class CallbackContainer:
         self.aggregated_cv = None
 
     def before_training(self, model: _Model) -> _Model:
-        '''Function called before training.'''
+        """Function called before training."""
         for c in self.callbacks:
             model = c.before_training(model=model)
-            msg = 'before_training should return the model'
+            msg = "before_training should return the model"
         if self.is_cv:
             assert isinstance(model.cvfolds, list), msg
         else:
@@ -165,10 +161,10 @@ class CallbackContainer:
         return model
 
     def after_training(self, model: _Model) -> _Model:
-        '''Function called after training.'''
+        """Function called after training."""
        for c in self.callbacks:
             model = c.after_training(model=model)
-            msg = 'after_training should return the model'
+            msg = "after_training should return the model"
         if self.is_cv:
             assert isinstance(model.cvfolds, list), msg
         else:
@@ -176,9 +172,9 @@ class CallbackContainer:
         if not self.is_cv:
             num_parallel_tree, _ = _get_booster_layer_trees(model)
-            if model.attr('best_score') is not None:
-                model.best_score = float(cast(str, model.attr('best_score')))
-                model.best_iteration = int(cast(str, model.attr('best_iteration')))
+            if model.attr("best_score") is not None:
+                model.best_score = float(cast(str, model.attr("best_score")))
+                model.best_iteration = int(cast(str, model.attr("best_iteration")))
                 # num_class is handled internally
                 model.set_attr(
                     best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
@@ -195,16 +191,21 @@ class CallbackContainer:
         return model
 
     def before_iteration(
-        self, model: _Model, epoch: int, dtrain: DMatrix, evals: Optional[List[Tuple[DMatrix, str]]]
+        self,
+        model: _Model,
+        epoch: int,
+        dtrain: DMatrix,
+        evals: Optional[List[Tuple[DMatrix, str]]],
     ) -> bool:
-        '''Function called before training iteration.'''
-        return any(c.before_iteration(model, epoch, self.history)
-                   for c in self.callbacks)
+        """Function called before training iteration."""
+        return any(
+            c.before_iteration(model, epoch, self.history) for c in self.callbacks
+        )
 
     def _update_history(
         self,
         score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
-        epoch: int
+        epoch: int,
     ) -> None:
         for d in score:
             name: str = d[0]
@@ -214,9 +215,9 @@ class CallbackContainer:
                 x: _Score = (s, std)
             else:
                 x = s
-            splited_names = name.split('-')
+            splited_names = name.split("-")
             data_name = splited_names[0]
-            metric_name = '-'.join(splited_names[1:])
+            metric_name = "-".join(splited_names[1:])
             x = _allreduce_metric(x)
             if data_name not in self.history:
                 self.history[data_name] = collections.OrderedDict()
@@ -238,7 +239,7 @@ class CallbackContainer:
         dtrain: DMatrix,
         evals: Optional[List[Tuple[DMatrix, str]]],
     ) -> bool:
-        '''Function called after training iteration.'''
+        """Function called after training iteration."""
         if self.is_cv:
             scores = model.eval(epoch, self.metric, self._output_margin)
             scores = _aggcv(scores)
@@ -247,16 +248,15 @@ class CallbackContainer:
         else:
             evals = [] if evals is None else evals
             for _, name in evals:
-                assert name.find('-') == -1, 'Dataset name should not contain `-`'
+                assert name.find("-") == -1, "Dataset name should not contain `-`"
             score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
             splited = score.split()[1:]  # into datasets
             # split up `test-error:0.1234`
-            metric_score_str = [tuple(s.split(':')) for s in splited]
+            metric_score_str = [tuple(s.split(":")) for s in splited]
             # convert to float
             metric_score = [(n, float(s)) for n, s in metric_score_str]
             self._update_history(metric_score, epoch)
-        ret = any(c.after_iteration(model, epoch, self.history)
-                  for c in self.callbacks)
+        ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
         return ret
@@ -320,7 +320,6 @@ class EarlyStopping(TrainingCallback):
     .. code-block:: python
 
-        clf = xgboost.XGBClassifier(tree_method="gpu_hist")
         es = xgboost.callback.EarlyStopping(
             rounds=2,
             abs_tol=1e-3,
@@ -329,10 +328,13 @@ class EarlyStopping(TrainingCallback):
             data_name="validation_0",
             metric_name="mlogloss",
         )
+        clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
 
         X, y = load_digits(return_X_y=True)
-        clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
+        clf.fit(X, y, eval_set=[(X, y)])
     """
 
+    # pylint: disable=too-many-arguments
     def __init__(
         self,
         rounds: int,
@@ -340,7 +342,7 @@ class EarlyStopping(TrainingCallback):
         data_name: Optional[str] = None,
         maximize: Optional[bool] = None,
         save_best: Optional[bool] = False,
-        min_delta: float = 0.0
+        min_delta: float = 0.0,
     ) -> None:
         self.data = data_name
         self.metric_name = metric_name
@@ -364,9 +366,9 @@ class EarlyStopping(TrainingCallback):
     def _update_rounds(
         self, score: _Score, name: str, metric: str, model: _Model, epoch: int
     ) -> bool:
-        def get_s(x: _Score) -> float:
+        def get_s(value: _Score) -> float:
             """get score if it's cross validation history."""
-            return x[0] if isinstance(x, tuple) else x
+            return value[0] if isinstance(value, tuple) else value
 
         def maximize(new: _Score, best: _Score) -> bool:
             """New score should be greater than the old one."""
@@ -379,9 +381,17 @@ class EarlyStopping(TrainingCallback):
         if self.maximize is None:
             # Just to be compatibility with old behavior before 1.3. We should let
             # user to decide.
-            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
-                                'aucpr@', 'map@', 'ndcg@')
-            if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
+            maximize_metrics = (
+                "auc",
+                "aucpr",
+                "map",
+                "ndcg",
+                "auc@",
+                "aucpr@",
+                "map@",
+                "ndcg@",
+            )
+            if metric != "mape" and any(metric.startswith(x) for x in maximize_metrics):
                 self.maximize = True
             else:
                 self.maximize = False
@@ -414,18 +424,19 @@ class EarlyStopping(TrainingCallback):
                 return True
         return False
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         epoch += self.starting_round  # training continuation
-        msg = 'Must have at least 1 validation dataset for early stopping.'
+        msg = "Must have at least 1 validation dataset for early stopping."
         assert len(evals_log.keys()) >= 1, msg
-        data_name = ''
+        data_name = ""
         if self.data:
             for d, _ in evals_log.items():
                 if d == self.data:
                     data_name = d
             if not data_name:
-                raise ValueError('No dataset named:', self.data)
+                raise ValueError("No dataset named:", self.data)
         else:
             # Use the last one as default.
             data_name = list(evals_log.keys())[-1]
@@ -454,7 +465,7 @@ class EarlyStopping(TrainingCallback):
 class EvaluationMonitor(TrainingCallback):
-    '''Print the evaluation result at each iteration.
+    """Print the evaluation result at each iteration.
 
     .. versionadded:: 1.3.0
@@ -469,7 +480,8 @@ class EvaluationMonitor(TrainingCallback):
         How many epoches between printing.
     show_stdv :
         Used in cv to show standard deviation. Users should not specify it.
-    '''
+    """
+
     def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
         self.printer_rank = rank
         self.show_stdv = show_stdv
@@ -488,12 +500,13 @@ class EvaluationMonitor(TrainingCallback):
             msg = f"\t{data + '-' + metric}:{score:.5f}"
         return msg
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         if not evals_log:
             return False
 
-        msg: str = f'[{epoch}]'
+        msg: str = f"[{epoch}]"
         if collective.get_rank() == self.printer_rank:
             for data, metric in evals_log.items():
                 for metric_name, log in metric.items():
@@ -504,7 +517,7 @@ class EvaluationMonitor(TrainingCallback):
                 else:
                     score = log[-1]
                 msg += self._fmt_metric(data, metric_name, score, stdv)
-            msg += '\n'
+            msg += "\n"
         if (epoch % self.period) == 0 or self.period == 1:
             collective.communicator_print(msg)
@@ -521,7 +534,7 @@ class EvaluationMonitor(TrainingCallback):
 class TrainingCheckPoint(TrainingCallback):
-    '''Checkpointing operation.
+    """Checkpointing operation.
 
     .. versionadded:: 1.3.0
@@ -540,13 +553,14 @@ class TrainingCheckPoint(TrainingCallback):
         Interval of checkpointing. Checkpointing is slow so setting a larger number can
         reduce performance hit.
-    '''
+    """
+
     def __init__(
         self,
         directory: Union[str, os.PathLike],
-        name: str = 'model',
+        name: str = "model",
         as_pickle: bool = False,
-        iterations: int = 100
+        iterations: int = 100,
     ) -> None:
         self._path = os.fspath(directory)
         self._name = name
@@ -555,15 +569,21 @@ class TrainingCheckPoint(TrainingCallback):
         self._epoch = 0
         super().__init__()
 
-    def after_iteration(self, model: _Model, epoch: int,
-                        evals_log: TrainingCallback.EvalsLog) -> bool:
+    def after_iteration(
+        self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
+    ) -> bool:
         if self._epoch == self._iterations:
-            path = os.path.join(self._path, self._name + '_' + str(epoch) +
-                                ('.pkl' if self._as_pickle else '.json'))
+            path = os.path.join(
+                self._path,
+                self._name
+                + "_"
+                + str(epoch)
+                + (".pkl" if self._as_pickle else ".json"),
+            )
             self._epoch = 0
             if collective.get_rank() == 0:
                 if self._as_pickle:
-                    with open(path, 'wb') as fd:
+                    with open(path, "wb") as fd:
                         pickle.dump(model, fd)
                 else:
                     model.save_model(path)
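Since the TrainingCallback interface above is the intended extension point, here is a minimal sketch of a user-defined callback under the same four-hook interface; the LossLogger name and the random training data are illustrative only:

    import numpy as np

    import xgboost
    from xgboost.callback import TrainingCallback

    class LossLogger(TrainingCallback):
        """Record the latest value of every metric after each boosting round."""

        def __init__(self) -> None:
            self.records = []
            super().__init__()

        def after_iteration(
            self, model, epoch: int, evals_log: TrainingCallback.EvalsLog
        ) -> bool:
            # evals_log maps data name -> metric name -> list of scores.
            for data, metrics in evals_log.items():
                for metric_name, log in metrics.items():
                    self.records.append((epoch, data, metric_name, log[-1]))
            return False  # returning True would stop training

    X, y = np.random.rand(64, 4), np.random.rand(64)
    dtrain = xgboost.DMatrix(X, label=y)
    logger = LossLogger()
    xgboost.train(
        {"objective": "reg:squarederror"},
        dtrain,
        num_boost_round=5,
        evals=[(dtrain, "train")],
        callbacks=[logger],
    )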

python-package/xgboost/sklearn.py

@@ -288,9 +288,9 @@ __model_doc = f"""
         .. note::
 
-            This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
-            receives un-transformed prediction regardless of whether custom objective is
-            being used.
+            This parameter replaces `eval_metric` in :py:meth:`fit` method. The old
+            one receives un-transformed prediction regardless of whether custom
+            objective is being used.
 
         .. code-block:: python
@@ -340,7 +340,8 @@ __model_doc = f"""
             for params in parameters_grid:
                 # be sure to (re)initialize the callbacks before each run
                 callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
-                xgboost.train(params, Xy, callbacks=callbacks)
+                reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
+                reg.fit(X, y)
 
     kwargs : dict, optional
         Keyword arguments for XGBoost Booster object. Full documentation of parameters
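Made self-contained, the rewritten docstring example would run roughly as follows; the data, the two-point grid, and the decay schedule are stand-ins:

    import xgboost as xgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=256, n_features=8, random_state=0)

    def custom_rates(epoch: int) -> float:
        # illustrative exponential decay of the learning rate
        return 0.3 * (0.9 ** epoch)

    parameters_grid = [{"max_depth": d} for d in (2, 4)]
    for params in parameters_grid:
        # (re)initialize the callbacks before each run, as the docstring advises
        callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
        reg = xgb.XGBRegressor(**params, callbacks=callbacks, n_estimators=10)
        reg.fit(X, y)

The re-initialization matters because a callback object carries state between rounds; sharing one instance across runs would leak that state.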

python-package/xgboost/training.py

@@ -28,9 +28,7 @@ from .core import (
 _CVFolds = Sequence["CVPack"]
 
-def _assert_new_callback(
-    callbacks: Optional[Sequence[TrainingCallback]]
-) -> None:
+def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
     is_new_callback: bool = not callbacks or all(
         isinstance(c, TrainingCallback) for c in callbacks
     )
@@ -45,7 +43,9 @@ def _configure_custom_metric(
     feval: Optional[Metric], custom_metric: Optional[Metric]
 ) -> Optional[Metric]:
     if feval is not None:
-        link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
+        link = (
+            "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
+        )
         warnings.warn(
             "`feval` is deprecated, use `custom_metric` instead. They have "
             "different behavior when custom objective is also used."
@@ -175,9 +175,7 @@ def train(
         verbose_eval = 1 if verbose_eval is True else verbose_eval
         callbacks.append(EvaluationMonitor(period=verbose_eval))
     if early_stopping_rounds:
-        callbacks.append(
-            EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
-        )
+        callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
     cb_container = CallbackContainer(
         callbacks,
         metric=metric_fn,
@@ -536,13 +534,9 @@ def cv(
     if verbose_eval:
         verbose_eval = 1 if verbose_eval is True else verbose_eval
-        callbacks.append(
-            EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
-        )
+        callbacks.append(EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv))
     if early_stopping_rounds:
-        callbacks.append(
-            EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
-        )
+        callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
     callbacks_container = CallbackContainer(
         callbacks,
         metric=metric_fn,
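The collapsed calls above are the shorthand path: early_stopping_rounds and verbose_eval are translated into EarlyStopping and EvaluationMonitor callbacks internally. A sketch of the equivalence for train(), with illustrative data and parameters:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(128, 4), np.random.rand(128)
    dtrain = xgb.DMatrix(X, label=y)
    params = {"objective": "reg:squarederror"}

    # Shorthand: train() appends EarlyStopping(rounds=5) internally.
    xgb.train(params, dtrain, num_boost_round=50,
              evals=[(dtrain, "train")], early_stopping_rounds=5)

    # Explicit form of the same configuration.
    xgb.train(params, dtrain, num_boost_round=50,
              evals=[(dtrain, "train")],
              callbacks=[xgb.callback.EarlyStopping(rounds=5)])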

tests/ci_build/lint_python.py

@@ -134,6 +134,7 @@ def main(args: argparse.Namespace) -> None:
             # core
             "python-package/xgboost/__init__.py",
             "python-package/xgboost/_typing.py",
+            "python-package/xgboost/callback.py",
             "python-package/xgboost/compat.py",
             "python-package/xgboost/config.py",
             "python-package/xgboost/dask.py",