Remove old callback deprecated in 1.3. (#7280)
This commit is contained in:
parent
578de9f762
commit
69d3b1b8b4
@ -10,262 +10,10 @@ from typing import Callable, List, Optional, Union, Dict, Tuple
|
|||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from . import rabit
|
from . import rabit
|
||||||
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
|
from .core import Booster, XGBoostError
|
||||||
from .compat import STRING_TYPES
|
from .compat import STRING_TYPES
|
||||||
|
|
||||||
|
|
||||||
def _get_callback_context(env):
|
|
||||||
"""return whether the current callback context is cv or train"""
|
|
||||||
if env.model is not None and env.cvfolds is None:
|
|
||||||
context = 'train'
|
|
||||||
elif env.model is None and env.cvfolds is not None:
|
|
||||||
context = 'cv'
|
|
||||||
else:
|
|
||||||
raise ValueError("Unexpected input with both model and cvfolds.")
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_metric(value, show_stdv=True):
|
|
||||||
"""format metric string"""
|
|
||||||
if len(value) == 2:
|
|
||||||
return f"{value[0]}:{value[1]:.5f}"
|
|
||||||
if len(value) == 3:
|
|
||||||
if show_stdv:
|
|
||||||
return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
|
|
||||||
return f"{value[0]}:{value[1]:.5f}"
|
|
||||||
raise ValueError("wrong metric value", value)
|
|
||||||
|
|
||||||
|
|
||||||
def print_evaluation(period=1, show_stdv=True):
|
|
||||||
"""Create a callback that print evaluation result.
|
|
||||||
|
|
||||||
We print the evaluation results every **period** iterations
|
|
||||||
and on the first and the last iterations.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
period : int
|
|
||||||
The period to log the evaluation results
|
|
||||||
|
|
||||||
show_stdv : bool, optional
|
|
||||||
Whether show stdv if provided
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
callback : function
|
|
||||||
A callback that print evaluation every period iterations.
|
|
||||||
"""
|
|
||||||
def callback(env):
|
|
||||||
"""internal function"""
|
|
||||||
if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
|
|
||||||
return
|
|
||||||
i = env.iteration
|
|
||||||
if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
|
|
||||||
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
|
|
||||||
rabit.tracker_print(f"{i}\t{msg}\n")
|
|
||||||
return callback
|
|
||||||
|
|
||||||
|
|
||||||
def record_evaluation(eval_result):
|
|
||||||
"""Create a call back that records the evaluation history into **eval_result**.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
eval_result : dict
|
|
||||||
A dictionary to store the evaluation results.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
callback : function
|
|
||||||
The requested callback function.
|
|
||||||
"""
|
|
||||||
if not isinstance(eval_result, dict):
|
|
||||||
raise TypeError('eval_result has to be a dictionary')
|
|
||||||
eval_result.clear()
|
|
||||||
|
|
||||||
def init(env):
|
|
||||||
"""internal function"""
|
|
||||||
for k, _ in env.evaluation_result_list:
|
|
||||||
pos = k.index('-')
|
|
||||||
key = k[:pos]
|
|
||||||
metric = k[pos + 1:]
|
|
||||||
if key not in eval_result:
|
|
||||||
eval_result[key] = {}
|
|
||||||
if metric not in eval_result[key]:
|
|
||||||
eval_result[key][metric] = []
|
|
||||||
|
|
||||||
def callback(env):
|
|
||||||
"""internal function"""
|
|
||||||
if not eval_result:
|
|
||||||
init(env)
|
|
||||||
for k, v in env.evaluation_result_list:
|
|
||||||
pos = k.index('-')
|
|
||||||
key = k[:pos]
|
|
||||||
metric = k[pos + 1:]
|
|
||||||
eval_result[key][metric].append(v)
|
|
||||||
return callback
|
|
||||||
|
|
||||||
|
|
||||||
def reset_learning_rate(learning_rates):
|
|
||||||
"""Reset learning rate after iteration 1
|
|
||||||
|
|
||||||
NOTE: the initial learning rate will still take in-effect on first iteration.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
learning_rates: list or function
|
|
||||||
List of learning rate for each boosting round
|
|
||||||
or a customized function that calculates eta in terms of
|
|
||||||
current number of round and the total number of boosting round (e.g.
|
|
||||||
yields learning rate decay)
|
|
||||||
|
|
||||||
* list ``l``: ``eta = l[boosting_round]``
|
|
||||||
* function ``f``: ``eta = f(boosting_round, num_boost_round)``
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
callback : function
|
|
||||||
The requested callback function.
|
|
||||||
"""
|
|
||||||
def get_learning_rate(i, n, learning_rates):
|
|
||||||
"""helper providing the learning rate"""
|
|
||||||
if isinstance(learning_rates, list):
|
|
||||||
if len(learning_rates) != n:
|
|
||||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
|
||||||
new_learning_rate = learning_rates[i]
|
|
||||||
else:
|
|
||||||
new_learning_rate = learning_rates(i, n)
|
|
||||||
return new_learning_rate
|
|
||||||
|
|
||||||
def callback(env):
|
|
||||||
"""internal function"""
|
|
||||||
context = _get_callback_context(env)
|
|
||||||
|
|
||||||
if context == 'train':
|
|
||||||
bst, i, n = env.model, env.iteration, env.end_iteration
|
|
||||||
bst.set_param(
|
|
||||||
'learning_rate', get_learning_rate(i, n, learning_rates))
|
|
||||||
elif context == 'cv':
|
|
||||||
i, n = env.iteration, env.end_iteration
|
|
||||||
for cvpack in env.cvfolds:
|
|
||||||
bst = cvpack.bst
|
|
||||||
bst.set_param(
|
|
||||||
'learning_rate', get_learning_rate(i, n, learning_rates))
|
|
||||||
|
|
||||||
callback.before_iteration = False
|
|
||||||
return callback
|
|
||||||
|
|
||||||
|
|
||||||
def early_stop(stopping_rounds, maximize=False, verbose=True):
|
|
||||||
"""Create a callback that activates early stoppping.
|
|
||||||
|
|
||||||
Validation error needs to decrease at least
|
|
||||||
every **stopping_rounds** round(s) to continue training.
|
|
||||||
Requires at least one item in **evals**.
|
|
||||||
If there's more than one, will use the last.
|
|
||||||
Returns the model from the last iteration (not the best one).
|
|
||||||
If early stopping occurs, the model will have three additional fields:
|
|
||||||
``bst.best_score``, ``bst.best_iteration``.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
stopping_rounds : int
|
|
||||||
The stopping rounds before the trend occur.
|
|
||||||
|
|
||||||
maximize : bool
|
|
||||||
Whether to maximize evaluation metric.
|
|
||||||
|
|
||||||
verbose : optional, bool
|
|
||||||
Whether to print message about early stopping information.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
callback : function
|
|
||||||
The requested callback function.
|
|
||||||
"""
|
|
||||||
state = {}
|
|
||||||
|
|
||||||
def init(env):
|
|
||||||
"""internal function"""
|
|
||||||
bst = env.model
|
|
||||||
|
|
||||||
if not env.evaluation_result_list:
|
|
||||||
raise ValueError('For early stopping you need at least one set in evals.')
|
|
||||||
if len(env.evaluation_result_list) > 1 and verbose:
|
|
||||||
msg = ("Multiple eval metrics have been passed: "
|
|
||||||
"'{0}' will be used for early stopping.\n\n")
|
|
||||||
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
|
|
||||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
|
|
||||||
maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
|
|
||||||
maximize_score = maximize
|
|
||||||
metric_label = env.evaluation_result_list[-1][0]
|
|
||||||
metric = metric_label.split('-', 1)[-1]
|
|
||||||
|
|
||||||
if any(metric.startswith(x) for x in maximize_at_n_metrics):
|
|
||||||
maximize_score = True
|
|
||||||
|
|
||||||
if any(metric.split(":")[0] == x for x in maximize_metrics):
|
|
||||||
maximize_score = True
|
|
||||||
|
|
||||||
if verbose and env.rank == 0:
|
|
||||||
msg = "Will train until {} hasn't improved in {} rounds.\n"
|
|
||||||
rabit.tracker_print(msg.format(metric_label, stopping_rounds))
|
|
||||||
|
|
||||||
state['maximize_score'] = maximize_score
|
|
||||||
state['best_iteration'] = 0
|
|
||||||
if maximize_score:
|
|
||||||
state['best_score'] = float('-inf')
|
|
||||||
else:
|
|
||||||
state['best_score'] = float('inf')
|
|
||||||
# pylint: disable=consider-using-f-string
|
|
||||||
msg = '[%d]\t%s' % (
|
|
||||||
env.iteration,
|
|
||||||
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
|
|
||||||
)
|
|
||||||
state['best_msg'] = msg
|
|
||||||
|
|
||||||
if bst is not None:
|
|
||||||
if bst.attr('best_score') is not None:
|
|
||||||
state['best_score'] = float(bst.attr('best_score'))
|
|
||||||
state['best_iteration'] = int(bst.attr('best_iteration'))
|
|
||||||
state['best_msg'] = bst.attr('best_msg')
|
|
||||||
else:
|
|
||||||
bst.set_attr(best_iteration=str(state['best_iteration']))
|
|
||||||
bst.set_attr(best_score=str(state['best_score']))
|
|
||||||
else:
|
|
||||||
assert env.cvfolds is not None
|
|
||||||
|
|
||||||
def callback(env):
|
|
||||||
"""internal function"""
|
|
||||||
if not state:
|
|
||||||
init(env)
|
|
||||||
score = env.evaluation_result_list[-1][1]
|
|
||||||
best_score = state['best_score']
|
|
||||||
best_iteration = state['best_iteration']
|
|
||||||
maximize_score = state['maximize_score']
|
|
||||||
if (maximize_score and score > best_score) or \
|
|
||||||
(not maximize_score and score < best_score):
|
|
||||||
# pylint: disable=consider-using-f-string
|
|
||||||
msg = '[%d]\t%s' % (
|
|
||||||
env.iteration,
|
|
||||||
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
|
|
||||||
state['best_msg'] = msg
|
|
||||||
state['best_score'] = score
|
|
||||||
state['best_iteration'] = env.iteration
|
|
||||||
# save the property to attributes, so they will occur in checkpoint.
|
|
||||||
if env.model is not None:
|
|
||||||
env.model.set_attr(best_score=str(state['best_score']),
|
|
||||||
best_iteration=str(state['best_iteration']),
|
|
||||||
best_msg=state['best_msg'])
|
|
||||||
elif env.iteration - best_iteration >= stopping_rounds:
|
|
||||||
best_msg = state['best_msg']
|
|
||||||
if verbose and env.rank == 0:
|
|
||||||
msg = "Stopping. Best iteration:\n{}\n\n"
|
|
||||||
rabit.tracker_print(msg.format(best_msg))
|
|
||||||
raise EarlyStopException(best_iteration)
|
|
||||||
return callback
|
|
||||||
|
|
||||||
|
|
||||||
# The new implementation of callback functions.
|
# The new implementation of callback functions.
|
||||||
# Breaking:
|
# Breaking:
|
||||||
# - reset learning rate no longer accepts total boosting rounds
|
# - reset learning rate no longer accepts total boosting rounds
|
||||||
@ -741,100 +489,3 @@ class TrainingCheckPoint(TrainingCallback):
|
|||||||
model.save_model(path)
|
model.save_model(path)
|
||||||
self._epoch += 1
|
self._epoch += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
class LegacyCallbacks:
|
|
||||||
'''Adapter for legacy callback functions.
|
|
||||||
|
|
||||||
.. versionadded:: 1.3.0
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
|
|
||||||
callbacks : Sequence
|
|
||||||
A sequence of legacy callbacks (callbacks that are not instance of
|
|
||||||
TrainingCallback)
|
|
||||||
start_iteration : int
|
|
||||||
Begining iteration.
|
|
||||||
end_iteration : int
|
|
||||||
End iteration, normally is the number of boosting rounds.
|
|
||||||
evals : Sequence
|
|
||||||
Sequence of evaluation dataset tuples.
|
|
||||||
feval : Custom evaluation metric.
|
|
||||||
'''
|
|
||||||
def __init__(self, callbacks, start_iteration, end_iteration,
|
|
||||||
feval, cvfolds=None):
|
|
||||||
self.callbacks_before_iter = [
|
|
||||||
cb for cb in callbacks
|
|
||||||
if cb.__dict__.get('before_iteration', False)]
|
|
||||||
self.callbacks_after_iter = [
|
|
||||||
cb for cb in callbacks
|
|
||||||
if not cb.__dict__.get('before_iteration', False)]
|
|
||||||
|
|
||||||
self.start_iteration = start_iteration
|
|
||||||
self.end_iteration = end_iteration
|
|
||||||
self.cvfolds = cvfolds
|
|
||||||
|
|
||||||
self.feval = feval
|
|
||||||
assert self.feval is None or callable(self.feval)
|
|
||||||
|
|
||||||
if cvfolds is not None:
|
|
||||||
self.aggregated_cv = None
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def before_training(self, model):
|
|
||||||
'''Nothing to do for legacy callbacks'''
|
|
||||||
return model
|
|
||||||
|
|
||||||
def after_training(self, model):
|
|
||||||
'''Nothing to do for legacy callbacks'''
|
|
||||||
return model
|
|
||||||
|
|
||||||
def before_iteration(self, model, epoch, dtrain, evals):
|
|
||||||
'''Called before each iteration.'''
|
|
||||||
for cb in self.callbacks_before_iter:
|
|
||||||
rank = rabit.get_rank()
|
|
||||||
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
|
|
||||||
cvfolds=self.cvfolds,
|
|
||||||
iteration=epoch,
|
|
||||||
begin_iteration=self.start_iteration,
|
|
||||||
end_iteration=self.end_iteration,
|
|
||||||
rank=rank,
|
|
||||||
evaluation_result_list=None))
|
|
||||||
return False
|
|
||||||
|
|
||||||
def after_iteration(self, model, epoch, dtrain, evals):
|
|
||||||
'''Called after each iteration.'''
|
|
||||||
evaluation_result_list = []
|
|
||||||
if self.cvfolds is not None:
|
|
||||||
# dtrain is not used here.
|
|
||||||
scores = model.eval(epoch, self.feval)
|
|
||||||
self.aggregated_cv = _aggcv(scores)
|
|
||||||
evaluation_result_list = self.aggregated_cv
|
|
||||||
|
|
||||||
if evals:
|
|
||||||
# When cv is used, evals are embedded into folds.
|
|
||||||
assert self.cvfolds is None
|
|
||||||
bst_eval_set = model.eval_set(evals, epoch, self.feval)
|
|
||||||
if isinstance(bst_eval_set, STRING_TYPES):
|
|
||||||
msg = bst_eval_set
|
|
||||||
else:
|
|
||||||
msg = bst_eval_set.decode()
|
|
||||||
res = [x.split(':') for x in msg.split()]
|
|
||||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
|
||||||
|
|
||||||
try:
|
|
||||||
for cb in self.callbacks_after_iter:
|
|
||||||
rank = rabit.get_rank()
|
|
||||||
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
|
|
||||||
cvfolds=self.cvfolds,
|
|
||||||
iteration=epoch,
|
|
||||||
begin_iteration=self.start_iteration,
|
|
||||||
end_iteration=self.end_iteration,
|
|
||||||
rank=rank,
|
|
||||||
evaluation_result_list=evaluation_result_list))
|
|
||||||
except EarlyStopException:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||||
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||||
"""Core XGBoost Library."""
|
"""Core XGBoost Library."""
|
||||||
import collections
|
|
||||||
# pylint: disable=no-name-in-module,import-error
|
# pylint: disable=no-name-in-module,import-error
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||||
@ -46,18 +45,6 @@ class EarlyStopException(Exception):
|
|||||||
self.best_iteration = best_iteration
|
self.best_iteration = best_iteration
|
||||||
|
|
||||||
|
|
||||||
# Callback environment used by callbacks
|
|
||||||
CallbackEnv = collections.namedtuple(
|
|
||||||
"XGBoostCallbackEnv",
|
|
||||||
["model",
|
|
||||||
"cvfolds",
|
|
||||||
"iteration",
|
|
||||||
"begin_iteration",
|
|
||||||
"end_iteration",
|
|
||||||
"rank",
|
|
||||||
"evaluation_result_list"])
|
|
||||||
|
|
||||||
|
|
||||||
def from_pystr_to_cstr(data: Union[str, List[str]]):
|
def from_pystr_to_cstr(data: Union[str, List[str]]):
|
||||||
"""Convert a Python str or list of Python str to C pointer
|
"""Convert a Python str or list of Python str to C pointer
|
||||||
|
|
||||||
|
|||||||
@ -2,40 +2,24 @@
|
|||||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||||
# pylint: disable=too-many-branches, too-many-statements
|
# pylint: disable=too-many-branches, too-many-statements
|
||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
import warnings
|
|
||||||
import copy
|
import copy
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, XGBoostError, _get_booster_layer_trees
|
from .core import Booster, XGBoostError, _get_booster_layer_trees
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||||
from . import callback
|
from . import callback
|
||||||
|
|
||||||
|
|
||||||
def _configure_deprecated_callbacks(
|
def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None:
|
||||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
is_new_callback: bool = not callbacks or all(
|
||||||
num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds):
|
isinstance(c, callback.TrainingCallback) for c in callbacks
|
||||||
link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html'
|
)
|
||||||
warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)
|
if not is_new_callback:
|
||||||
# Most of legacy advanced options becomes callbacks
|
link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
|
||||||
if early_stopping_rounds is not None:
|
raise ValueError(
|
||||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
f"Old style callback was removed in version 1.6. See: {link}."
|
||||||
maximize=maximize,
|
)
|
||||||
verbose=bool(verbose_eval)))
|
|
||||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
|
||||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
|
||||||
else:
|
|
||||||
if isinstance(verbose_eval, int):
|
|
||||||
callbacks.append(callback.print_evaluation(verbose_eval,
|
|
||||||
show_stdv=show_stdv))
|
|
||||||
if evals_result is not None:
|
|
||||||
callbacks.append(callback.record_evaluation(evals_result))
|
|
||||||
callbacks = callback.LegacyCallbacks(
|
|
||||||
callbacks, start_iteration, num_boost_round, feval, cvfolds=cvfolds)
|
|
||||||
return callbacks
|
|
||||||
|
|
||||||
|
|
||||||
def _is_new_callback(callbacks):
|
|
||||||
return any(isinstance(c, callback.TrainingCallback)
|
|
||||||
for c in callbacks) or not callbacks
|
|
||||||
|
|
||||||
|
|
||||||
def _train_internal(params, dtrain,
|
def _train_internal(params, dtrain,
|
||||||
@ -56,22 +40,15 @@ def _train_internal(params, dtrain,
|
|||||||
|
|
||||||
start_iteration = 0
|
start_iteration = 0
|
||||||
|
|
||||||
is_new_callback = _is_new_callback(callbacks)
|
_assert_new_callback(callbacks)
|
||||||
if is_new_callback:
|
if verbose_eval:
|
||||||
assert all(isinstance(c, callback.TrainingCallback)
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
for c in callbacks), "You can't mix new and old callback styles."
|
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
|
||||||
if verbose_eval:
|
if early_stopping_rounds:
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
callbacks.append(
|
||||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
|
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||||
if early_stopping_rounds:
|
)
|
||||||
callbacks.append(callback.EarlyStopping(
|
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
||||||
rounds=early_stopping_rounds, maximize=maximize))
|
|
||||||
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
|
||||||
else:
|
|
||||||
callbacks = _configure_deprecated_callbacks(
|
|
||||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
|
||||||
num_boost_round, feval, evals_result, callbacks,
|
|
||||||
show_stdv=False, cvfolds=None)
|
|
||||||
|
|
||||||
bst = callbacks.before_training(bst)
|
bst = callbacks.before_training(bst)
|
||||||
|
|
||||||
@ -84,7 +61,7 @@ def _train_internal(params, dtrain,
|
|||||||
|
|
||||||
bst = callbacks.after_training(bst)
|
bst = callbacks.after_training(bst)
|
||||||
|
|
||||||
if evals_result is not None and is_new_callback:
|
if evals_result is not None:
|
||||||
evals_result.update(callbacks.history)
|
evals_result.update(callbacks.history)
|
||||||
|
|
||||||
# These should be moved into callback functions `after_training`, but until old
|
# These should be moved into callback functions `after_training`, but until old
|
||||||
@ -468,25 +445,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
callbacks = [] if callbacks is None else callbacks
|
callbacks = [] if callbacks is None else callbacks
|
||||||
is_new_callback = _is_new_callback(callbacks)
|
_assert_new_callback(callbacks)
|
||||||
if is_new_callback:
|
|
||||||
assert all(isinstance(c, callback.TrainingCallback)
|
if verbose_eval:
|
||||||
for c in callbacks), "You can't mix new and old callback styles."
|
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||||
if verbose_eval:
|
callbacks.append(
|
||||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
||||||
callbacks.append(
|
)
|
||||||
callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
if early_stopping_rounds:
|
||||||
)
|
callbacks.append(
|
||||||
if early_stopping_rounds:
|
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||||
callbacks.append(
|
)
|
||||||
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
||||||
)
|
|
||||||
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
|
||||||
else:
|
|
||||||
callbacks = _configure_deprecated_callbacks(
|
|
||||||
verbose_eval, early_stopping_rounds, maximize, 0,
|
|
||||||
num_boost_round, feval, None, callbacks,
|
|
||||||
show_stdv=show_stdv, cvfolds=cvfolds)
|
|
||||||
booster = _PackedBooster(cvfolds)
|
booster = _PackedBooster(cvfolds)
|
||||||
callbacks.before_training(booster)
|
callbacks.before_training(booster)
|
||||||
|
|
||||||
|
|||||||
@ -41,8 +41,7 @@ class TestGPUBasicModels:
|
|||||||
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
||||||
|
|
||||||
def test_eta_decay_gpu_hist(self):
|
def test_eta_decay_gpu_hist(self):
|
||||||
self.cpu_test_cb.run_eta_decay('gpu_hist', True)
|
self.cpu_test_cb.run_eta_decay('gpu_hist')
|
||||||
self.cpu_test_cb.run_eta_decay('gpu_hist', False)
|
|
||||||
|
|
||||||
def test_deterministic_gpu_hist(self):
|
def test_deterministic_gpu_hist(self):
|
||||||
kRows = 1000
|
kRows = 1000
|
||||||
|
|||||||
@ -76,23 +76,6 @@ class TestBasic:
|
|||||||
predt_1 = booster.predict(dtrain)
|
predt_1 = booster.predict(dtrain)
|
||||||
np.testing.assert_allclose(predt_0, predt_1)
|
np.testing.assert_allclose(predt_0, predt_1)
|
||||||
|
|
||||||
def test_record_results(self):
|
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
|
||||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
|
||||||
'objective': 'binary:logistic', 'eval_metric': 'error'}
|
|
||||||
# specify validations set to watch performance
|
|
||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
|
||||||
num_round = 2
|
|
||||||
result = {}
|
|
||||||
res2 = {}
|
|
||||||
xgb.train(param, dtrain, num_round, watchlist,
|
|
||||||
callbacks=[xgb.callback.record_evaluation(result)])
|
|
||||||
xgb.train(param, dtrain, num_round, watchlist,
|
|
||||||
evals_result=res2)
|
|
||||||
assert result['train']['error'][0] < 0.1
|
|
||||||
assert res2 == result
|
|
||||||
|
|
||||||
def test_multiclass(self):
|
def test_multiclass(self):
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||||
@ -254,8 +237,18 @@ class TestBasic:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Use callback to log the test labels in each fold
|
# Use callback to log the test labels in each fold
|
||||||
def cb(cbackenv):
|
class Callback(xgb.callback.TrainingCallback):
|
||||||
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def after_iteration(
|
||||||
|
self, model,
|
||||||
|
epoch: int,
|
||||||
|
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||||
|
):
|
||||||
|
print([fold.dtest.get_label() for fold in model.cvfolds])
|
||||||
|
|
||||||
|
cb = Callback()
|
||||||
|
|
||||||
# Run cross validation and capture standard out to test callback result
|
# Run cross validation and capture standard out to test callback result
|
||||||
with tm.captured_output() as (out, err):
|
with tm.captured_output() as (out, err):
|
||||||
|
|||||||
@ -249,12 +249,9 @@ class TestCallbacks:
|
|||||||
assert booster.num_boosted_rounds() == \
|
assert booster.num_boosted_rounds() == \
|
||||||
booster.best_iteration + early_stopping_rounds + 1
|
booster.best_iteration + early_stopping_rounds + 1
|
||||||
|
|
||||||
def run_eta_decay(self, tree_method, deprecated_callback):
|
def run_eta_decay(self, tree_method):
|
||||||
"""Test learning rate scheduler, used by both CPU and GPU tests."""
|
"""Test learning rate scheduler, used by both CPU and GPU tests."""
|
||||||
if deprecated_callback:
|
scheduler = xgb.callback.LearningRateScheduler
|
||||||
scheduler = xgb.callback.reset_learning_rate
|
|
||||||
else:
|
|
||||||
scheduler = xgb.callback.LearningRateScheduler
|
|
||||||
|
|
||||||
dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
|
dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
@ -262,10 +259,7 @@ class TestCallbacks:
|
|||||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||||
num_round = 4
|
num_round = 4
|
||||||
|
|
||||||
if deprecated_callback:
|
warning_check = tm.noop_context()
|
||||||
warning_check = pytest.warns(UserWarning)
|
|
||||||
else:
|
|
||||||
warning_check = tm.noop_context()
|
|
||||||
|
|
||||||
# learning_rates as a list
|
# learning_rates as a list
|
||||||
# init eta with 0 to check whether learning_rates work
|
# init eta with 0 to check whether learning_rates work
|
||||||
@ -339,19 +333,9 @@ class TestCallbacks:
|
|||||||
with warning_check:
|
with warning_check:
|
||||||
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||||
"tree_method, deprecated_callback",
|
def test_eta_decay(self, tree_method):
|
||||||
[
|
self.run_eta_decay(tree_method)
|
||||||
("hist", True),
|
|
||||||
("hist", False),
|
|
||||||
("approx", True),
|
|
||||||
("approx", False),
|
|
||||||
("exact", True),
|
|
||||||
("exact", False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_eta_decay(self, tree_method, deprecated_callback):
|
|
||||||
self.run_eta_decay(tree_method, deprecated_callback)
|
|
||||||
|
|
||||||
def test_check_point(self):
|
def test_check_point(self):
|
||||||
from sklearn.datasets import load_breast_cancer
|
from sklearn.datasets import load_breast_cancer
|
||||||
|
|||||||
@ -22,15 +22,25 @@ def test_aft_survival_toy_data():
|
|||||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||||
# the corresponding predicted label (y_pred)
|
# the corresponding predicted label (y_pred)
|
||||||
acc_rec = []
|
acc_rec = []
|
||||||
def my_callback(env):
|
|
||||||
y_pred = env.model.predict(dmat)
|
class Callback(xgb.callback.TrainingCallback):
|
||||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
def __init__(self):
|
||||||
acc_rec.append(acc)
|
super().__init__()
|
||||||
|
|
||||||
|
def after_iteration(
|
||||||
|
self, model: xgb.Booster,
|
||||||
|
epoch: int,
|
||||||
|
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||||
|
):
|
||||||
|
y_pred = model.predict(dmat)
|
||||||
|
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
||||||
|
acc_rec.append(acc)
|
||||||
|
return False
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
|
params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
|
||||||
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
|
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
|
||||||
callbacks=[my_callback])
|
callbacks=[Callback()])
|
||||||
|
|
||||||
nloglik_rec = evals_result['train']['aft-nloglik']
|
nloglik_rec = evals_result['train']['aft-nloglik']
|
||||||
# AFT metric (negative log likelihood) improve monotonically
|
# AFT metric (negative log likelihood) improve monotonically
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user