Cleanup the callback module. (#8702)
- Cleanup pylint markers. - run formatter. - Update examples of using callback.
This commit is contained in:
parent
34eee56256
commit
9fb12b20a4
@ -1,6 +1,3 @@
|
|||||||
# coding: utf-8
|
|
||||||
# pylint: disable=invalid-name, too-many-statements
|
|
||||||
# pylint: disable=too-many-arguments
|
|
||||||
"""Callback library containing training routines. See :doc:`Callback Functions
|
"""Callback library containing training routines. See :doc:`Callback Functions
|
||||||
</python/callbacks>` for a quick introduction.
|
</python/callbacks>` for a quick introduction.
|
||||||
|
|
||||||
@ -34,7 +31,7 @@ __all__ = [
|
|||||||
"EarlyStopping",
|
"EarlyStopping",
|
||||||
"EvaluationMonitor",
|
"EvaluationMonitor",
|
||||||
"TrainingCheckPoint",
|
"TrainingCheckPoint",
|
||||||
"CallbackContainer"
|
"CallbackContainer",
|
||||||
]
|
]
|
||||||
|
|
||||||
_Score = Union[float, Tuple[float, float]]
|
_Score = Union[float, Tuple[float, float]]
|
||||||
@ -45,39 +42,37 @@ _Model = Any # real type is Union[Booster, CVPack]; need more work
|
|||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
class TrainingCallback(ABC):
|
class TrainingCallback(ABC):
|
||||||
'''Interface for training callback.
|
"""Interface for training callback.
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
EvalsLog = Dict[str, Dict[str, _ScoreList]]
|
EvalsLog = Dict[str, Dict[str, _ScoreList]] # pylint: disable=invalid-name
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def before_training(self, model: _Model) -> _Model:
|
def before_training(self, model: _Model) -> _Model:
|
||||||
'''Run before training starts.'''
|
"""Run before training starts."""
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def after_training(self, model: _Model) -> _Model:
|
def after_training(self, model: _Model) -> _Model:
|
||||||
'''Run after training is finished.'''
|
"""Run after training is finished."""
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
|
def before_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
|
||||||
'''Run before each iteration. Return True when training should stop.'''
|
"""Run before each iteration. Return True when training should stop."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
|
def after_iteration(self, model: _Model, epoch: int, evals_log: EvalsLog) -> bool:
|
||||||
'''Run after each iteration. Return True when training should stop.'''
|
"""Run after each iteration. Return True when training should stop."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
|
def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
|
||||||
# pylint: disable=invalid-name, too-many-locals
|
# pylint: disable=invalid-name, too-many-locals
|
||||||
"""Aggregate cross-validation results.
|
"""Aggregate cross-validation results."""
|
||||||
|
|
||||||
"""
|
|
||||||
cvmap: Dict[Tuple[int, str], List[float]] = {}
|
cvmap: Dict[Tuple[int, str], List[float]] = {}
|
||||||
idx = rlist[0].split()[0]
|
idx = rlist[0].split()[0]
|
||||||
for line in rlist:
|
for line in rlist:
|
||||||
@ -86,7 +81,7 @@ def _aggcv(rlist: List[str]) -> List[Tuple[str, float, float]]:
|
|||||||
for metric_idx, it in enumerate(arr[1:]):
|
for metric_idx, it in enumerate(arr[1:]):
|
||||||
if not isinstance(it, str):
|
if not isinstance(it, str):
|
||||||
it = it.decode()
|
it = it.decode()
|
||||||
k, v = it.split(':')
|
k, v = it.split(":")
|
||||||
if (metric_idx, k) not in cvmap:
|
if (metric_idx, k) not in cvmap:
|
||||||
cvmap[(metric_idx, k)] = []
|
cvmap[(metric_idx, k)] = []
|
||||||
cvmap[(metric_idx, k)].append(float(v))
|
cvmap[(metric_idx, k)].append(float(v))
|
||||||
@ -106,44 +101,45 @@ _ART = TypeVar("_ART")
|
|||||||
|
|
||||||
|
|
||||||
def _allreduce_metric(score: _ART) -> _ART:
|
def _allreduce_metric(score: _ART) -> _ART:
|
||||||
'''Helper function for computing customized metric in distributed
|
"""Helper function for computing customized metric in distributed
|
||||||
environment. Not strictly correct as many functions don't use mean value
|
environment. Not strictly correct as many functions don't use mean value
|
||||||
as final result.
|
as final result.
|
||||||
|
|
||||||
'''
|
"""
|
||||||
world = collective.get_world_size()
|
world = collective.get_world_size()
|
||||||
assert world != 0
|
assert world != 0
|
||||||
if world == 1:
|
if world == 1:
|
||||||
return score
|
return score
|
||||||
if isinstance(score, tuple): # has mean and stdv
|
if isinstance(score, tuple): # has mean and stdv
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'xgboost.cv function should not be used in distributed environment.')
|
"xgboost.cv function should not be used in distributed environment."
|
||||||
|
)
|
||||||
arr = numpy.array([score])
|
arr = numpy.array([score])
|
||||||
arr = collective.allreduce(arr, collective.Op.SUM) / world
|
arr = collective.allreduce(arr, collective.Op.SUM) / world
|
||||||
return arr[0]
|
return arr[0]
|
||||||
|
|
||||||
|
|
||||||
class CallbackContainer:
|
class CallbackContainer:
|
||||||
'''A special internal callback for invoking a list of other callbacks.
|
"""A special internal callback for invoking a list of other callbacks.
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
EvalsLog = TrainingCallback.EvalsLog
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
callbacks: Sequence[TrainingCallback],
|
callbacks: Sequence[TrainingCallback],
|
||||||
metric: Optional[Callable] = None,
|
metric: Optional[Callable] = None,
|
||||||
output_margin: bool = True,
|
output_margin: bool = True,
|
||||||
is_cv: bool = False
|
is_cv: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.callbacks = set(callbacks)
|
self.callbacks = set(callbacks)
|
||||||
if metric is not None:
|
if metric is not None:
|
||||||
msg = 'metric must be callable object for monitoring. For ' + \
|
msg = (
|
||||||
'builtin metrics, passing them in training parameter' + \
|
"metric must be callable object for monitoring. For "
|
||||||
' will invoke monitor automatically.'
|
+ "builtin metrics, passing them in training parameter"
|
||||||
|
+ " will invoke monitor automatically."
|
||||||
|
)
|
||||||
assert callable(metric), msg
|
assert callable(metric), msg
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
|
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
|
||||||
@ -154,10 +150,10 @@ class CallbackContainer:
|
|||||||
self.aggregated_cv = None
|
self.aggregated_cv = None
|
||||||
|
|
||||||
def before_training(self, model: _Model) -> _Model:
|
def before_training(self, model: _Model) -> _Model:
|
||||||
'''Function called before training.'''
|
"""Function called before training."""
|
||||||
for c in self.callbacks:
|
for c in self.callbacks:
|
||||||
model = c.before_training(model=model)
|
model = c.before_training(model=model)
|
||||||
msg = 'before_training should return the model'
|
msg = "before_training should return the model"
|
||||||
if self.is_cv:
|
if self.is_cv:
|
||||||
assert isinstance(model.cvfolds, list), msg
|
assert isinstance(model.cvfolds, list), msg
|
||||||
else:
|
else:
|
||||||
@ -165,10 +161,10 @@ class CallbackContainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def after_training(self, model: _Model) -> _Model:
|
def after_training(self, model: _Model) -> _Model:
|
||||||
'''Function called after training.'''
|
"""Function called after training."""
|
||||||
for c in self.callbacks:
|
for c in self.callbacks:
|
||||||
model = c.after_training(model=model)
|
model = c.after_training(model=model)
|
||||||
msg = 'after_training should return the model'
|
msg = "after_training should return the model"
|
||||||
if self.is_cv:
|
if self.is_cv:
|
||||||
assert isinstance(model.cvfolds, list), msg
|
assert isinstance(model.cvfolds, list), msg
|
||||||
else:
|
else:
|
||||||
@ -176,9 +172,9 @@ class CallbackContainer:
|
|||||||
|
|
||||||
if not self.is_cv:
|
if not self.is_cv:
|
||||||
num_parallel_tree, _ = _get_booster_layer_trees(model)
|
num_parallel_tree, _ = _get_booster_layer_trees(model)
|
||||||
if model.attr('best_score') is not None:
|
if model.attr("best_score") is not None:
|
||||||
model.best_score = float(cast(str, model.attr('best_score')))
|
model.best_score = float(cast(str, model.attr("best_score")))
|
||||||
model.best_iteration = int(cast(str, model.attr('best_iteration')))
|
model.best_iteration = int(cast(str, model.attr("best_iteration")))
|
||||||
# num_class is handled internally
|
# num_class is handled internally
|
||||||
model.set_attr(
|
model.set_attr(
|
||||||
best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
|
best_ntree_limit=str((model.best_iteration + 1) * num_parallel_tree)
|
||||||
@ -195,16 +191,21 @@ class CallbackContainer:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
def before_iteration(
|
def before_iteration(
|
||||||
self, model: _Model, epoch: int, dtrain: DMatrix, evals: Optional[List[Tuple[DMatrix, str]]]
|
self,
|
||||||
|
model: _Model,
|
||||||
|
epoch: int,
|
||||||
|
dtrain: DMatrix,
|
||||||
|
evals: Optional[List[Tuple[DMatrix, str]]],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
'''Function called before training iteration.'''
|
"""Function called before training iteration."""
|
||||||
return any(c.before_iteration(model, epoch, self.history)
|
return any(
|
||||||
for c in self.callbacks)
|
c.before_iteration(model, epoch, self.history) for c in self.callbacks
|
||||||
|
)
|
||||||
|
|
||||||
def _update_history(
|
def _update_history(
|
||||||
self,
|
self,
|
||||||
score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
|
score: Union[List[Tuple[str, float]], List[Tuple[str, float, float]]],
|
||||||
epoch: int
|
epoch: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
for d in score:
|
for d in score:
|
||||||
name: str = d[0]
|
name: str = d[0]
|
||||||
@ -214,9 +215,9 @@ class CallbackContainer:
|
|||||||
x: _Score = (s, std)
|
x: _Score = (s, std)
|
||||||
else:
|
else:
|
||||||
x = s
|
x = s
|
||||||
splited_names = name.split('-')
|
splited_names = name.split("-")
|
||||||
data_name = splited_names[0]
|
data_name = splited_names[0]
|
||||||
metric_name = '-'.join(splited_names[1:])
|
metric_name = "-".join(splited_names[1:])
|
||||||
x = _allreduce_metric(x)
|
x = _allreduce_metric(x)
|
||||||
if data_name not in self.history:
|
if data_name not in self.history:
|
||||||
self.history[data_name] = collections.OrderedDict()
|
self.history[data_name] = collections.OrderedDict()
|
||||||
@ -238,7 +239,7 @@ class CallbackContainer:
|
|||||||
dtrain: DMatrix,
|
dtrain: DMatrix,
|
||||||
evals: Optional[List[Tuple[DMatrix, str]]],
|
evals: Optional[List[Tuple[DMatrix, str]]],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
'''Function called after training iteration.'''
|
"""Function called after training iteration."""
|
||||||
if self.is_cv:
|
if self.is_cv:
|
||||||
scores = model.eval(epoch, self.metric, self._output_margin)
|
scores = model.eval(epoch, self.metric, self._output_margin)
|
||||||
scores = _aggcv(scores)
|
scores = _aggcv(scores)
|
||||||
@ -247,16 +248,15 @@ class CallbackContainer:
|
|||||||
else:
|
else:
|
||||||
evals = [] if evals is None else evals
|
evals = [] if evals is None else evals
|
||||||
for _, name in evals:
|
for _, name in evals:
|
||||||
assert name.find('-') == -1, 'Dataset name should not contain `-`'
|
assert name.find("-") == -1, "Dataset name should not contain `-`"
|
||||||
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
|
||||||
splited = score.split()[1:] # into datasets
|
splited = score.split()[1:] # into datasets
|
||||||
# split up `test-error:0.1234`
|
# split up `test-error:0.1234`
|
||||||
metric_score_str = [tuple(s.split(':')) for s in splited]
|
metric_score_str = [tuple(s.split(":")) for s in splited]
|
||||||
# convert to float
|
# convert to float
|
||||||
metric_score = [(n, float(s)) for n, s in metric_score_str]
|
metric_score = [(n, float(s)) for n, s in metric_score_str]
|
||||||
self._update_history(metric_score, epoch)
|
self._update_history(metric_score, epoch)
|
||||||
ret = any(c.after_iteration(model, epoch, self.history)
|
ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
|
||||||
for c in self.callbacks)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@ -320,7 +320,6 @@ class EarlyStopping(TrainingCallback):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
clf = xgboost.XGBClassifier(tree_method="gpu_hist")
|
|
||||||
es = xgboost.callback.EarlyStopping(
|
es = xgboost.callback.EarlyStopping(
|
||||||
rounds=2,
|
rounds=2,
|
||||||
abs_tol=1e-3,
|
abs_tol=1e-3,
|
||||||
@ -329,10 +328,13 @@ class EarlyStopping(TrainingCallback):
|
|||||||
data_name="validation_0",
|
data_name="validation_0",
|
||||||
metric_name="mlogloss",
|
metric_name="mlogloss",
|
||||||
)
|
)
|
||||||
|
clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
|
||||||
|
|
||||||
X, y = load_digits(return_X_y=True)
|
X, y = load_digits(return_X_y=True)
|
||||||
clf.fit(X, y, eval_set=[(X, y)], callbacks=[es])
|
clf.fit(X, y, eval_set=[(X, y)])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=too-many-arguments
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
rounds: int,
|
rounds: int,
|
||||||
@ -340,7 +342,7 @@ class EarlyStopping(TrainingCallback):
|
|||||||
data_name: Optional[str] = None,
|
data_name: Optional[str] = None,
|
||||||
maximize: Optional[bool] = None,
|
maximize: Optional[bool] = None,
|
||||||
save_best: Optional[bool] = False,
|
save_best: Optional[bool] = False,
|
||||||
min_delta: float = 0.0
|
min_delta: float = 0.0,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.data = data_name
|
self.data = data_name
|
||||||
self.metric_name = metric_name
|
self.metric_name = metric_name
|
||||||
@ -364,9 +366,9 @@ class EarlyStopping(TrainingCallback):
|
|||||||
def _update_rounds(
|
def _update_rounds(
|
||||||
self, score: _Score, name: str, metric: str, model: _Model, epoch: int
|
self, score: _Score, name: str, metric: str, model: _Model, epoch: int
|
||||||
) -> bool:
|
) -> bool:
|
||||||
def get_s(x: _Score) -> float:
|
def get_s(value: _Score) -> float:
|
||||||
"""get score if it's cross validation history."""
|
"""get score if it's cross validation history."""
|
||||||
return x[0] if isinstance(x, tuple) else x
|
return value[0] if isinstance(value, tuple) else value
|
||||||
|
|
||||||
def maximize(new: _Score, best: _Score) -> bool:
|
def maximize(new: _Score, best: _Score) -> bool:
|
||||||
"""New score should be greater than the old one."""
|
"""New score should be greater than the old one."""
|
||||||
@ -379,9 +381,17 @@ class EarlyStopping(TrainingCallback):
|
|||||||
if self.maximize is None:
|
if self.maximize is None:
|
||||||
# Just to be compatibility with old behavior before 1.3. We should let
|
# Just to be compatibility with old behavior before 1.3. We should let
|
||||||
# user to decide.
|
# user to decide.
|
||||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
|
maximize_metrics = (
|
||||||
'aucpr@', 'map@', 'ndcg@')
|
"auc",
|
||||||
if metric != 'mape' and any(metric.startswith(x) for x in maximize_metrics):
|
"aucpr",
|
||||||
|
"map",
|
||||||
|
"ndcg",
|
||||||
|
"auc@",
|
||||||
|
"aucpr@",
|
||||||
|
"map@",
|
||||||
|
"ndcg@",
|
||||||
|
)
|
||||||
|
if metric != "mape" and any(metric.startswith(x) for x in maximize_metrics):
|
||||||
self.maximize = True
|
self.maximize = True
|
||||||
else:
|
else:
|
||||||
self.maximize = False
|
self.maximize = False
|
||||||
@ -414,18 +424,19 @@ class EarlyStopping(TrainingCallback):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def after_iteration(self, model: _Model, epoch: int,
|
def after_iteration(
|
||||||
evals_log: TrainingCallback.EvalsLog) -> bool:
|
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||||
|
) -> bool:
|
||||||
epoch += self.starting_round # training continuation
|
epoch += self.starting_round # training continuation
|
||||||
msg = 'Must have at least 1 validation dataset for early stopping.'
|
msg = "Must have at least 1 validation dataset for early stopping."
|
||||||
assert len(evals_log.keys()) >= 1, msg
|
assert len(evals_log.keys()) >= 1, msg
|
||||||
data_name = ''
|
data_name = ""
|
||||||
if self.data:
|
if self.data:
|
||||||
for d, _ in evals_log.items():
|
for d, _ in evals_log.items():
|
||||||
if d == self.data:
|
if d == self.data:
|
||||||
data_name = d
|
data_name = d
|
||||||
if not data_name:
|
if not data_name:
|
||||||
raise ValueError('No dataset named:', self.data)
|
raise ValueError("No dataset named:", self.data)
|
||||||
else:
|
else:
|
||||||
# Use the last one as default.
|
# Use the last one as default.
|
||||||
data_name = list(evals_log.keys())[-1]
|
data_name = list(evals_log.keys())[-1]
|
||||||
@ -454,7 +465,7 @@ class EarlyStopping(TrainingCallback):
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationMonitor(TrainingCallback):
|
class EvaluationMonitor(TrainingCallback):
|
||||||
'''Print the evaluation result at each iteration.
|
"""Print the evaluation result at each iteration.
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
@ -469,7 +480,8 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
How many epoches between printing.
|
How many epoches between printing.
|
||||||
show_stdv :
|
show_stdv :
|
||||||
Used in cv to show standard deviation. Users should not specify it.
|
Used in cv to show standard deviation. Users should not specify it.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
|
def __init__(self, rank: int = 0, period: int = 1, show_stdv: bool = False) -> None:
|
||||||
self.printer_rank = rank
|
self.printer_rank = rank
|
||||||
self.show_stdv = show_stdv
|
self.show_stdv = show_stdv
|
||||||
@ -488,12 +500,13 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
msg = f"\t{data + '-' + metric}:{score:.5f}"
|
msg = f"\t{data + '-' + metric}:{score:.5f}"
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def after_iteration(self, model: _Model, epoch: int,
|
def after_iteration(
|
||||||
evals_log: TrainingCallback.EvalsLog) -> bool:
|
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||||
|
) -> bool:
|
||||||
if not evals_log:
|
if not evals_log:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
msg: str = f'[{epoch}]'
|
msg: str = f"[{epoch}]"
|
||||||
if collective.get_rank() == self.printer_rank:
|
if collective.get_rank() == self.printer_rank:
|
||||||
for data, metric in evals_log.items():
|
for data, metric in evals_log.items():
|
||||||
for metric_name, log in metric.items():
|
for metric_name, log in metric.items():
|
||||||
@ -504,7 +517,7 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
else:
|
else:
|
||||||
score = log[-1]
|
score = log[-1]
|
||||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||||
msg += '\n'
|
msg += "\n"
|
||||||
|
|
||||||
if (epoch % self.period) == 0 or self.period == 1:
|
if (epoch % self.period) == 0 or self.period == 1:
|
||||||
collective.communicator_print(msg)
|
collective.communicator_print(msg)
|
||||||
@ -521,7 +534,7 @@ class EvaluationMonitor(TrainingCallback):
|
|||||||
|
|
||||||
|
|
||||||
class TrainingCheckPoint(TrainingCallback):
|
class TrainingCheckPoint(TrainingCallback):
|
||||||
'''Checkpointing operation.
|
"""Checkpointing operation.
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
.. versionadded:: 1.3.0
|
||||||
|
|
||||||
@ -540,13 +553,14 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||||
reduce performance hit.
|
reduce performance hit.
|
||||||
|
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
directory: Union[str, os.PathLike],
|
directory: Union[str, os.PathLike],
|
||||||
name: str = 'model',
|
name: str = "model",
|
||||||
as_pickle: bool = False,
|
as_pickle: bool = False,
|
||||||
iterations: int = 100
|
iterations: int = 100,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._path = os.fspath(directory)
|
self._path = os.fspath(directory)
|
||||||
self._name = name
|
self._name = name
|
||||||
@ -555,15 +569,21 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def after_iteration(self, model: _Model, epoch: int,
|
def after_iteration(
|
||||||
evals_log: TrainingCallback.EvalsLog) -> bool:
|
self, model: _Model, epoch: int, evals_log: TrainingCallback.EvalsLog
|
||||||
|
) -> bool:
|
||||||
if self._epoch == self._iterations:
|
if self._epoch == self._iterations:
|
||||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
path = os.path.join(
|
||||||
('.pkl' if self._as_pickle else '.json'))
|
self._path,
|
||||||
|
self._name
|
||||||
|
+ "_"
|
||||||
|
+ str(epoch)
|
||||||
|
+ (".pkl" if self._as_pickle else ".json"),
|
||||||
|
)
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
if collective.get_rank() == 0:
|
if collective.get_rank() == 0:
|
||||||
if self._as_pickle:
|
if self._as_pickle:
|
||||||
with open(path, 'wb') as fd:
|
with open(path, "wb") as fd:
|
||||||
pickle.dump(model, fd)
|
pickle.dump(model, fd)
|
||||||
else:
|
else:
|
||||||
model.save_model(path)
|
model.save_model(path)
|
||||||
|
|||||||
@ -288,9 +288,9 @@ __model_doc = f"""
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old one
|
This parameter replaces `eval_metric` in :py:meth:`fit` method. The old
|
||||||
receives un-transformed prediction regardless of whether custom objective is
|
one receives un-transformed prediction regardless of whether custom
|
||||||
being used.
|
objective is being used.
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
@ -340,7 +340,8 @@ __model_doc = f"""
|
|||||||
for params in parameters_grid:
|
for params in parameters_grid:
|
||||||
# be sure to (re)initialize the callbacks before each run
|
# be sure to (re)initialize the callbacks before each run
|
||||||
callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
|
callbacks = [xgb.callback.LearningRateScheduler(custom_rates)]
|
||||||
xgboost.train(params, Xy, callbacks=callbacks)
|
reg = xgboost.XGBRegressor(**params, callbacks=callbacks)
|
||||||
|
reg.fit(X, y)
|
||||||
|
|
||||||
kwargs : dict, optional
|
kwargs : dict, optional
|
||||||
Keyword arguments for XGBoost Booster object. Full documentation of parameters
|
Keyword arguments for XGBoost Booster object. Full documentation of parameters
|
||||||
|
|||||||
@ -28,9 +28,7 @@ from .core import (
|
|||||||
_CVFolds = Sequence["CVPack"]
|
_CVFolds = Sequence["CVPack"]
|
||||||
|
|
||||||
|
|
||||||
def _assert_new_callback(
|
def _assert_new_callback(callbacks: Optional[Sequence[TrainingCallback]]) -> None:
|
||||||
callbacks: Optional[Sequence[TrainingCallback]]
|
|
||||||
) -> None:
|
|
||||||
is_new_callback: bool = not callbacks or all(
|
is_new_callback: bool = not callbacks or all(
|
||||||
isinstance(c, TrainingCallback) for c in callbacks
|
isinstance(c, TrainingCallback) for c in callbacks
|
||||||
)
|
)
|
||||||
@ -45,7 +43,9 @@ def _configure_custom_metric(
|
|||||||
feval: Optional[Metric], custom_metric: Optional[Metric]
|
feval: Optional[Metric], custom_metric: Optional[Metric]
|
||||||
) -> Optional[Metric]:
|
) -> Optional[Metric]:
|
||||||
if feval is not None:
|
if feval is not None:
|
||||||
link = "https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
|
link = (
|
||||||
|
"https://xgboost.readthedocs.io/en/latest/tutorials/custom_metric_obj.html"
|
||||||
|
)
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"`feval` is deprecated, use `custom_metric` instead. They have "
|
"`feval` is deprecated, use `custom_metric` instead. They have "
|
||||||
"different behavior when custom objective is also used."
|
"different behavior when custom objective is also used."
|
||||||
@ -175,9 +175,7 @@ def train(
|
|||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
callbacks.append(EvaluationMonitor(period=verbose_eval))
|
callbacks.append(EvaluationMonitor(period=verbose_eval))
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
callbacks.append(
|
callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
|
||||||
EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
|
||||||
)
|
|
||||||
cb_container = CallbackContainer(
|
cb_container = CallbackContainer(
|
||||||
callbacks,
|
callbacks,
|
||||||
metric=metric_fn,
|
metric=metric_fn,
|
||||||
@ -536,13 +534,9 @@ def cv(
|
|||||||
|
|
||||||
if verbose_eval:
|
if verbose_eval:
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
callbacks.append(
|
callbacks.append(EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv))
|
||||||
EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
|
||||||
)
|
|
||||||
if early_stopping_rounds:
|
if early_stopping_rounds:
|
||||||
callbacks.append(
|
callbacks.append(EarlyStopping(rounds=early_stopping_rounds, maximize=maximize))
|
||||||
EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
|
||||||
)
|
|
||||||
callbacks_container = CallbackContainer(
|
callbacks_container = CallbackContainer(
|
||||||
callbacks,
|
callbacks,
|
||||||
metric=metric_fn,
|
metric=metric_fn,
|
||||||
|
|||||||
@ -134,6 +134,7 @@ def main(args: argparse.Namespace) -> None:
|
|||||||
# core
|
# core
|
||||||
"python-package/xgboost/__init__.py",
|
"python-package/xgboost/__init__.py",
|
||||||
"python-package/xgboost/_typing.py",
|
"python-package/xgboost/_typing.py",
|
||||||
|
"python-package/xgboost/callback.py",
|
||||||
"python-package/xgboost/compat.py",
|
"python-package/xgboost/compat.py",
|
||||||
"python-package/xgboost/config.py",
|
"python-package/xgboost/config.py",
|
||||||
"python-package/xgboost/dask.py",
|
"python-package/xgboost/dask.py",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user