Remove old callback deprecated in 1.3. (#7280)
This commit is contained in:
parent
578de9f762
commit
69d3b1b8b4
@ -10,262 +10,10 @@ from typing import Callable, List, Optional, Union, Dict, Tuple
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
|
||||
from .core import Booster, XGBoostError
|
||||
from .compat import STRING_TYPES
|
||||
|
||||
|
||||
def _get_callback_context(env):
|
||||
"""return whether the current callback context is cv or train"""
|
||||
if env.model is not None and env.cvfolds is None:
|
||||
context = 'train'
|
||||
elif env.model is None and env.cvfolds is not None:
|
||||
context = 'cv'
|
||||
else:
|
||||
raise ValueError("Unexpected input with both model and cvfolds.")
|
||||
return context
|
||||
|
||||
|
||||
def _fmt_metric(value, show_stdv=True):
|
||||
"""format metric string"""
|
||||
if len(value) == 2:
|
||||
return f"{value[0]}:{value[1]:.5f}"
|
||||
if len(value) == 3:
|
||||
if show_stdv:
|
||||
return f"{value[0]}:{value[1]:.5f}+{value[2]:.5f}"
|
||||
return f"{value[0]}:{value[1]:.5f}"
|
||||
raise ValueError("wrong metric value", value)
|
||||
|
||||
|
||||
def print_evaluation(period=1, show_stdv=True):
|
||||
"""Create a callback that print evaluation result.
|
||||
|
||||
We print the evaluation results every **period** iterations
|
||||
and on the first and the last iterations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
period : int
|
||||
The period to log the evaluation results
|
||||
|
||||
show_stdv : bool, optional
|
||||
Whether show stdv if provided
|
||||
|
||||
Returns
|
||||
-------
|
||||
callback : function
|
||||
A callback that print evaluation every period iterations.
|
||||
"""
|
||||
def callback(env):
|
||||
"""internal function"""
|
||||
if env.rank != 0 or (not env.evaluation_result_list) or period is False or period == 0:
|
||||
return
|
||||
i = env.iteration
|
||||
if i % period == 0 or i + 1 == env.begin_iteration or i + 1 == env.end_iteration:
|
||||
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
|
||||
rabit.tracker_print(f"{i}\t{msg}\n")
|
||||
return callback
|
||||
|
||||
|
||||
def record_evaluation(eval_result):
|
||||
"""Create a call back that records the evaluation history into **eval_result**.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
eval_result : dict
|
||||
A dictionary to store the evaluation results.
|
||||
|
||||
Returns
|
||||
-------
|
||||
callback : function
|
||||
The requested callback function.
|
||||
"""
|
||||
if not isinstance(eval_result, dict):
|
||||
raise TypeError('eval_result has to be a dictionary')
|
||||
eval_result.clear()
|
||||
|
||||
def init(env):
|
||||
"""internal function"""
|
||||
for k, _ in env.evaluation_result_list:
|
||||
pos = k.index('-')
|
||||
key = k[:pos]
|
||||
metric = k[pos + 1:]
|
||||
if key not in eval_result:
|
||||
eval_result[key] = {}
|
||||
if metric not in eval_result[key]:
|
||||
eval_result[key][metric] = []
|
||||
|
||||
def callback(env):
|
||||
"""internal function"""
|
||||
if not eval_result:
|
||||
init(env)
|
||||
for k, v in env.evaluation_result_list:
|
||||
pos = k.index('-')
|
||||
key = k[:pos]
|
||||
metric = k[pos + 1:]
|
||||
eval_result[key][metric].append(v)
|
||||
return callback
|
||||
|
||||
|
||||
def reset_learning_rate(learning_rates):
|
||||
"""Reset learning rate after iteration 1
|
||||
|
||||
NOTE: the initial learning rate will still take in-effect on first iteration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
learning_rates: list or function
|
||||
List of learning rate for each boosting round
|
||||
or a customized function that calculates eta in terms of
|
||||
current number of round and the total number of boosting round (e.g.
|
||||
yields learning rate decay)
|
||||
|
||||
* list ``l``: ``eta = l[boosting_round]``
|
||||
* function ``f``: ``eta = f(boosting_round, num_boost_round)``
|
||||
|
||||
Returns
|
||||
-------
|
||||
callback : function
|
||||
The requested callback function.
|
||||
"""
|
||||
def get_learning_rate(i, n, learning_rates):
|
||||
"""helper providing the learning rate"""
|
||||
if isinstance(learning_rates, list):
|
||||
if len(learning_rates) != n:
|
||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
||||
new_learning_rate = learning_rates[i]
|
||||
else:
|
||||
new_learning_rate = learning_rates(i, n)
|
||||
return new_learning_rate
|
||||
|
||||
def callback(env):
|
||||
"""internal function"""
|
||||
context = _get_callback_context(env)
|
||||
|
||||
if context == 'train':
|
||||
bst, i, n = env.model, env.iteration, env.end_iteration
|
||||
bst.set_param(
|
||||
'learning_rate', get_learning_rate(i, n, learning_rates))
|
||||
elif context == 'cv':
|
||||
i, n = env.iteration, env.end_iteration
|
||||
for cvpack in env.cvfolds:
|
||||
bst = cvpack.bst
|
||||
bst.set_param(
|
||||
'learning_rate', get_learning_rate(i, n, learning_rates))
|
||||
|
||||
callback.before_iteration = False
|
||||
return callback
|
||||
|
||||
|
||||
def early_stop(stopping_rounds, maximize=False, verbose=True):
|
||||
"""Create a callback that activates early stoppping.
|
||||
|
||||
Validation error needs to decrease at least
|
||||
every **stopping_rounds** round(s) to continue training.
|
||||
Requires at least one item in **evals**.
|
||||
If there's more than one, will use the last.
|
||||
Returns the model from the last iteration (not the best one).
|
||||
If early stopping occurs, the model will have three additional fields:
|
||||
``bst.best_score``, ``bst.best_iteration``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stopping_rounds : int
|
||||
The stopping rounds before the trend occur.
|
||||
|
||||
maximize : bool
|
||||
Whether to maximize evaluation metric.
|
||||
|
||||
verbose : optional, bool
|
||||
Whether to print message about early stopping information.
|
||||
|
||||
Returns
|
||||
-------
|
||||
callback : function
|
||||
The requested callback function.
|
||||
"""
|
||||
state = {}
|
||||
|
||||
def init(env):
|
||||
"""internal function"""
|
||||
bst = env.model
|
||||
|
||||
if not env.evaluation_result_list:
|
||||
raise ValueError('For early stopping you need at least one set in evals.')
|
||||
if len(env.evaluation_result_list) > 1 and verbose:
|
||||
msg = ("Multiple eval metrics have been passed: "
|
||||
"'{0}' will be used for early stopping.\n\n")
|
||||
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
|
||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg')
|
||||
maximize_at_n_metrics = ('auc@', 'aucpr@', 'map@', 'ndcg@')
|
||||
maximize_score = maximize
|
||||
metric_label = env.evaluation_result_list[-1][0]
|
||||
metric = metric_label.split('-', 1)[-1]
|
||||
|
||||
if any(metric.startswith(x) for x in maximize_at_n_metrics):
|
||||
maximize_score = True
|
||||
|
||||
if any(metric.split(":")[0] == x for x in maximize_metrics):
|
||||
maximize_score = True
|
||||
|
||||
if verbose and env.rank == 0:
|
||||
msg = "Will train until {} hasn't improved in {} rounds.\n"
|
||||
rabit.tracker_print(msg.format(metric_label, stopping_rounds))
|
||||
|
||||
state['maximize_score'] = maximize_score
|
||||
state['best_iteration'] = 0
|
||||
if maximize_score:
|
||||
state['best_score'] = float('-inf')
|
||||
else:
|
||||
state['best_score'] = float('inf')
|
||||
# pylint: disable=consider-using-f-string
|
||||
msg = '[%d]\t%s' % (
|
||||
env.iteration,
|
||||
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list])
|
||||
)
|
||||
state['best_msg'] = msg
|
||||
|
||||
if bst is not None:
|
||||
if bst.attr('best_score') is not None:
|
||||
state['best_score'] = float(bst.attr('best_score'))
|
||||
state['best_iteration'] = int(bst.attr('best_iteration'))
|
||||
state['best_msg'] = bst.attr('best_msg')
|
||||
else:
|
||||
bst.set_attr(best_iteration=str(state['best_iteration']))
|
||||
bst.set_attr(best_score=str(state['best_score']))
|
||||
else:
|
||||
assert env.cvfolds is not None
|
||||
|
||||
def callback(env):
|
||||
"""internal function"""
|
||||
if not state:
|
||||
init(env)
|
||||
score = env.evaluation_result_list[-1][1]
|
||||
best_score = state['best_score']
|
||||
best_iteration = state['best_iteration']
|
||||
maximize_score = state['maximize_score']
|
||||
if (maximize_score and score > best_score) or \
|
||||
(not maximize_score and score < best_score):
|
||||
# pylint: disable=consider-using-f-string
|
||||
msg = '[%d]\t%s' % (
|
||||
env.iteration,
|
||||
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
|
||||
state['best_msg'] = msg
|
||||
state['best_score'] = score
|
||||
state['best_iteration'] = env.iteration
|
||||
# save the property to attributes, so they will occur in checkpoint.
|
||||
if env.model is not None:
|
||||
env.model.set_attr(best_score=str(state['best_score']),
|
||||
best_iteration=str(state['best_iteration']),
|
||||
best_msg=state['best_msg'])
|
||||
elif env.iteration - best_iteration >= stopping_rounds:
|
||||
best_msg = state['best_msg']
|
||||
if verbose and env.rank == 0:
|
||||
msg = "Stopping. Best iteration:\n{}\n\n"
|
||||
rabit.tracker_print(msg.format(best_msg))
|
||||
raise EarlyStopException(best_iteration)
|
||||
return callback
|
||||
|
||||
|
||||
# The new implementation of callback functions.
|
||||
# Breaking:
|
||||
# - reset learning rate no longer accepts total boosting rounds
|
||||
@ -741,100 +489,3 @@ class TrainingCheckPoint(TrainingCallback):
|
||||
model.save_model(path)
|
||||
self._epoch += 1
|
||||
return False
|
||||
|
||||
|
||||
class LegacyCallbacks:
|
||||
'''Adapter for legacy callback functions.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
callbacks : Sequence
|
||||
A sequence of legacy callbacks (callbacks that are not instance of
|
||||
TrainingCallback)
|
||||
start_iteration : int
|
||||
Begining iteration.
|
||||
end_iteration : int
|
||||
End iteration, normally is the number of boosting rounds.
|
||||
evals : Sequence
|
||||
Sequence of evaluation dataset tuples.
|
||||
feval : Custom evaluation metric.
|
||||
'''
|
||||
def __init__(self, callbacks, start_iteration, end_iteration,
|
||||
feval, cvfolds=None):
|
||||
self.callbacks_before_iter = [
|
||||
cb for cb in callbacks
|
||||
if cb.__dict__.get('before_iteration', False)]
|
||||
self.callbacks_after_iter = [
|
||||
cb for cb in callbacks
|
||||
if not cb.__dict__.get('before_iteration', False)]
|
||||
|
||||
self.start_iteration = start_iteration
|
||||
self.end_iteration = end_iteration
|
||||
self.cvfolds = cvfolds
|
||||
|
||||
self.feval = feval
|
||||
assert self.feval is None or callable(self.feval)
|
||||
|
||||
if cvfolds is not None:
|
||||
self.aggregated_cv = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
def before_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
return model
|
||||
|
||||
def after_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
return model
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Called before each iteration.'''
|
||||
for cb in self.callbacks_before_iter:
|
||||
rank = rabit.get_rank()
|
||||
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
|
||||
cvfolds=self.cvfolds,
|
||||
iteration=epoch,
|
||||
begin_iteration=self.start_iteration,
|
||||
end_iteration=self.end_iteration,
|
||||
rank=rank,
|
||||
evaluation_result_list=None))
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Called after each iteration.'''
|
||||
evaluation_result_list = []
|
||||
if self.cvfolds is not None:
|
||||
# dtrain is not used here.
|
||||
scores = model.eval(epoch, self.feval)
|
||||
self.aggregated_cv = _aggcv(scores)
|
||||
evaluation_result_list = self.aggregated_cv
|
||||
|
||||
if evals:
|
||||
# When cv is used, evals are embedded into folds.
|
||||
assert self.cvfolds is None
|
||||
bst_eval_set = model.eval_set(evals, epoch, self.feval)
|
||||
if isinstance(bst_eval_set, STRING_TYPES):
|
||||
msg = bst_eval_set
|
||||
else:
|
||||
msg = bst_eval_set.decode()
|
||||
res = [x.split(':') for x in msg.split()]
|
||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
||||
|
||||
try:
|
||||
for cb in self.callbacks_after_iter:
|
||||
rank = rabit.get_rank()
|
||||
cb(CallbackEnv(model=None if self.cvfolds is not None else model,
|
||||
cvfolds=self.cvfolds,
|
||||
iteration=epoch,
|
||||
begin_iteration=self.start_iteration,
|
||||
end_iteration=self.end_iteration,
|
||||
rank=rank,
|
||||
evaluation_result_list=evaluation_result_list))
|
||||
except EarlyStopException:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
|
||||
# pylint: disable=too-many-lines, too-many-locals, no-self-use
|
||||
"""Core XGBoost Library."""
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||
@ -46,18 +45,6 @@ class EarlyStopException(Exception):
|
||||
self.best_iteration = best_iteration
|
||||
|
||||
|
||||
# Callback environment used by callbacks
|
||||
CallbackEnv = collections.namedtuple(
|
||||
"XGBoostCallbackEnv",
|
||||
["model",
|
||||
"cvfolds",
|
||||
"iteration",
|
||||
"begin_iteration",
|
||||
"end_iteration",
|
||||
"rank",
|
||||
"evaluation_result_list"])
|
||||
|
||||
|
||||
def from_pystr_to_cstr(data: Union[str, List[str]]):
|
||||
"""Convert a Python str or list of Python str to C pointer
|
||||
|
||||
|
||||
@ -2,40 +2,24 @@
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""Training Library containing training routines."""
|
||||
import warnings
|
||||
import copy
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
from .core import Booster, XGBoostError, _get_booster_layer_trees
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||
from . import callback
|
||||
|
||||
|
||||
def _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||
num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds):
|
||||
link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html'
|
||||
warnings.warn(f'Old style callback is deprecated. See: {link}', UserWarning)
|
||||
# Most of legacy advanced options becomes callbacks
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=bool(verbose_eval)))
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
callbacks.append(callback.print_evaluation(verbose_eval,
|
||||
show_stdv=show_stdv))
|
||||
if evals_result is not None:
|
||||
callbacks.append(callback.record_evaluation(evals_result))
|
||||
callbacks = callback.LegacyCallbacks(
|
||||
callbacks, start_iteration, num_boost_round, feval, cvfolds=cvfolds)
|
||||
return callbacks
|
||||
|
||||
|
||||
def _is_new_callback(callbacks):
|
||||
return any(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks) or not callbacks
|
||||
def _assert_new_callback(callbacks: Optional[List[callback.TrainingCallback]]) -> None:
|
||||
is_new_callback: bool = not callbacks or all(
|
||||
isinstance(c, callback.TrainingCallback) for c in callbacks
|
||||
)
|
||||
if not is_new_callback:
|
||||
link = "https://xgboost.readthedocs.io/en/latest/python/callbacks.html"
|
||||
raise ValueError(
|
||||
f"Old style callback was removed in version 1.6. See: {link}."
|
||||
)
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
@ -56,22 +40,15 @@ def _train_internal(params, dtrain,
|
||||
|
||||
start_iteration = 0
|
||||
|
||||
is_new_callback = _is_new_callback(callbacks)
|
||||
if is_new_callback:
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if verbose_eval:
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(callback.EarlyStopping(
|
||||
rounds=early_stopping_rounds, maximize=maximize))
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
||||
else:
|
||||
callbacks = _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||
num_boost_round, feval, evals_result, callbacks,
|
||||
show_stdv=False, cvfolds=None)
|
||||
_assert_new_callback(callbacks)
|
||||
if verbose_eval:
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(callback.EvaluationMonitor(period=verbose_eval))
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(
|
||||
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||
)
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
||||
|
||||
bst = callbacks.before_training(bst)
|
||||
|
||||
@ -84,7 +61,7 @@ def _train_internal(params, dtrain,
|
||||
|
||||
bst = callbacks.after_training(bst)
|
||||
|
||||
if evals_result is not None and is_new_callback:
|
||||
if evals_result is not None:
|
||||
evals_result.update(callbacks.history)
|
||||
|
||||
# These should be moved into callback functions `after_training`, but until old
|
||||
@ -468,25 +445,19 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
|
||||
# setup callbacks
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
is_new_callback = _is_new_callback(callbacks)
|
||||
if is_new_callback:
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if verbose_eval:
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(
|
||||
callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
||||
)
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(
|
||||
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||
)
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
||||
else:
|
||||
callbacks = _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, 0,
|
||||
num_boost_round, feval, None, callbacks,
|
||||
show_stdv=show_stdv, cvfolds=cvfolds)
|
||||
_assert_new_callback(callbacks)
|
||||
|
||||
if verbose_eval:
|
||||
verbose_eval = 1 if verbose_eval is True else verbose_eval
|
||||
callbacks.append(
|
||||
callback.EvaluationMonitor(period=verbose_eval, show_stdv=show_stdv)
|
||||
)
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(
|
||||
callback.EarlyStopping(rounds=early_stopping_rounds, maximize=maximize)
|
||||
)
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
||||
|
||||
booster = _PackedBooster(cvfolds)
|
||||
callbacks.before_training(booster)
|
||||
|
||||
|
||||
@ -41,8 +41,7 @@ class TestGPUBasicModels:
|
||||
self.cpu_test_bm.run_custom_objective("gpu_hist")
|
||||
|
||||
def test_eta_decay_gpu_hist(self):
|
||||
self.cpu_test_cb.run_eta_decay('gpu_hist', True)
|
||||
self.cpu_test_cb.run_eta_decay('gpu_hist', False)
|
||||
self.cpu_test_cb.run_eta_decay('gpu_hist')
|
||||
|
||||
def test_deterministic_gpu_hist(self):
|
||||
kRows = 1000
|
||||
|
||||
@ -76,23 +76,6 @@ class TestBasic:
|
||||
predt_1 = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
|
||||
def test_record_results(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
param = {'max_depth': 2, 'eta': 1, 'verbosity': 0,
|
||||
'objective': 'binary:logistic', 'eval_metric': 'error'}
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 2
|
||||
result = {}
|
||||
res2 = {}
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
callbacks=[xgb.callback.record_evaluation(result)])
|
||||
xgb.train(param, dtrain, num_round, watchlist,
|
||||
evals_result=res2)
|
||||
assert result['train']['error'][0] < 0.1
|
||||
assert res2 == result
|
||||
|
||||
def test_multiclass(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
|
||||
@ -254,8 +237,18 @@ class TestBasic:
|
||||
]
|
||||
|
||||
# Use callback to log the test labels in each fold
|
||||
def cb(cbackenv):
|
||||
print([fold.dtest.get_label() for fold in cbackenv.cvfolds])
|
||||
class Callback(xgb.callback.TrainingCallback):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(
|
||||
self, model,
|
||||
epoch: int,
|
||||
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
):
|
||||
print([fold.dtest.get_label() for fold in model.cvfolds])
|
||||
|
||||
cb = Callback()
|
||||
|
||||
# Run cross validation and capture standard out to test callback result
|
||||
with tm.captured_output() as (out, err):
|
||||
|
||||
@ -249,12 +249,9 @@ class TestCallbacks:
|
||||
assert booster.num_boosted_rounds() == \
|
||||
booster.best_iteration + early_stopping_rounds + 1
|
||||
|
||||
def run_eta_decay(self, tree_method, deprecated_callback):
|
||||
def run_eta_decay(self, tree_method):
|
||||
"""Test learning rate scheduler, used by both CPU and GPU tests."""
|
||||
if deprecated_callback:
|
||||
scheduler = xgb.callback.reset_learning_rate
|
||||
else:
|
||||
scheduler = xgb.callback.LearningRateScheduler
|
||||
scheduler = xgb.callback.LearningRateScheduler
|
||||
|
||||
dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
@ -262,10 +259,7 @@ class TestCallbacks:
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
num_round = 4
|
||||
|
||||
if deprecated_callback:
|
||||
warning_check = pytest.warns(UserWarning)
|
||||
else:
|
||||
warning_check = tm.noop_context()
|
||||
warning_check = tm.noop_context()
|
||||
|
||||
# learning_rates as a list
|
||||
# init eta with 0 to check whether learning_rates work
|
||||
@ -339,19 +333,9 @@ class TestCallbacks:
|
||||
with warning_check:
|
||||
xgb.cv(param, dtrain, num_round, callbacks=[scheduler(eta_decay)])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"tree_method, deprecated_callback",
|
||||
[
|
||||
("hist", True),
|
||||
("hist", False),
|
||||
("approx", True),
|
||||
("approx", False),
|
||||
("exact", True),
|
||||
("exact", False),
|
||||
],
|
||||
)
|
||||
def test_eta_decay(self, tree_method, deprecated_callback):
|
||||
self.run_eta_decay(tree_method, deprecated_callback)
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||
def test_eta_decay(self, tree_method):
|
||||
self.run_eta_decay(tree_method)
|
||||
|
||||
def test_check_point(self):
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
@ -22,15 +22,25 @@ def test_aft_survival_toy_data():
|
||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||
# the corresponding predicted label (y_pred)
|
||||
acc_rec = []
|
||||
def my_callback(env):
|
||||
y_pred = env.model.predict(dmat)
|
||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
||||
acc_rec.append(acc)
|
||||
|
||||
class Callback(xgb.callback.TrainingCallback):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(
|
||||
self, model: xgb.Booster,
|
||||
epoch: int,
|
||||
evals_log: xgb.callback.TrainingCallback.EvalsLog
|
||||
):
|
||||
y_pred = model.predict(dmat)
|
||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X))
|
||||
acc_rec.append(acc)
|
||||
return False
|
||||
|
||||
evals_result = {}
|
||||
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
|
||||
params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
|
||||
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
|
||||
callbacks=[my_callback])
|
||||
callbacks=[Callback()])
|
||||
|
||||
nloglik_rec = evals_result['train']['aft-nloglik']
|
||||
# AFT metric (negative log likelihood) improve monotonically
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user