Rework Python callback functions. (#6199)
* Define a new callback interface for Python.
* Deprecate the old callbacks.
* Enable early stopping on dask.
This commit is contained in:
parent
b5b24354b8
commit
ab5b35134f
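At its core, this change replaces the old ``CallbackEnv``-based callback functions with a class interface: subclass ``xgboost.callback.TrainingCallback``, override the hooks you need, and pass instances through the ``callbacks`` argument. A minimal sketch of the new interface (the timing callback and the synthetic data are illustrative, not part of this commit):

import time

import numpy as np
import xgboost as xgb


class IterationTimer(xgb.callback.TrainingCallback):
    '''Record the wall-clock duration of each boosting round.'''
    def __init__(self):
        self.durations = []
        self._start = 0.0
        super().__init__()

    def before_iteration(self, model, epoch, evals_log):
        self._start = time.perf_counter()
        return False  # returning True would stop training

    def after_iteration(self, model, epoch, evals_log):
        self.durations.append(time.perf_counter() - self._start)
        return False


X = np.random.randn(256, 16)
y = np.random.randint(0, 2, size=256)
timer = IterationTimer()
xgb.train({'objective': 'binary:logistic'}, xgb.DMatrix(X, y),
          num_boost_round=10, callbacks=[timer])
print('total update time:', sum(timer.durations))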
demo/guide-python/callbacks.py (new file, 130 lines)
@@ -0,0 +1,130 @@
'''
Demo for using and defining callback functions.

.. versionadded:: 1.3.0
'''
import xgboost as xgb
import tempfile
import os
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import argparse


class Plotting(xgb.callback.TrainingCallback):
    '''Plot evaluation result during training.  Only for demonstration purposes as
    it's quite slow to draw.
    '''
    def __init__(self, rounds):
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        self.lines = {}
        self.fig.show()
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data, metric):
        return f'{data}-{metric}'

    def after_iteration(self, model, epoch, evals_log):
        '''Update the plot.'''
        if not self.lines:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key], = self.ax.plot(self.x, expanded, label=key)
                    self.ax.legend()
        else:
            # https://pythonspot.com/matplotlib-update-plot/
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key].set_ydata(expanded)
            self.fig.canvas.draw()
        # False to indicate training should not stop.
        return False


def custom_callback():
    '''Demo for defining a custom callback function that plots evaluation result
    during training.'''
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            'objective': 'binary:logistic',
            'eval_metric': ['error', 'rmse'],
            'tree_method': 'gpu_hist'
        },
        D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        num_boost_round=num_boost_round,
        callbacks=[plotting])


def check_point_callback():
    # Only for demo.  Set a larger value (like 100) in practice, as checkpointing
    # is quite slow.
    rounds = 2

    def check(as_pickle):
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
            else:
                path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)
    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                      iterations=rounds,
                                                      as_pickle=True,
                                                      name='model')
        xgb.train({'objective': 'binary:logistic'}, m,
                  num_boost_round=10,
                  verbose_eval=False,
                  callbacks=[check_point])
        check(True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot', default=1, type=int)
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback()
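A usage note on the checkpoints written above (the directory path is illustrative): ``TrainingCheckPoint`` emits ``name_<iteration>.json`` files, which load back through the normal model-loading APIs, while the ``as_pickle=True`` variant stores the pickled ``Booster``:

import pickle

import xgboost as xgb

# JSON checkpoint: load the model into a fresh Booster.
bst = xgb.Booster(model_file='/tmp/checkpoints/model_8.json')

# Pickle checkpoint: the Booster object itself was serialised.
with open('/tmp/checkpoints/model_8.pkl', 'rb') as fd:
    bst = pickle.load(fd)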
@@ -1,5 +1,7 @@
 '''A demo for defining data iterator.
+
+.. versionadded:: 1.2.0
 
 The demo that defines a customized iterator for passing batches of data into
 `xgboost.DeviceQuantileDMatrix` and use this `DeviceQuantileDMatrix` for
 training.  The feature is primarily designed to reduce the required GPU
doc/python/callbacks.rst (new file, 59 lines)
@@ -0,0 +1,59 @@
##################
Callback Functions
##################

This document gives a basic walkthrough of the callback functions used in the XGBoost
Python package.  In XGBoost 1.3, a new callback interface is designed for the Python
package, which provides the flexibility of designing various extensions for training.
Also, XGBoost has a number of pre-defined callbacks for supporting early stopping,
checkpoints, etc.

#######################
Using builtin callbacks
#######################

By default, training methods in XGBoost have parameters like ``early_stopping_rounds``
and ``verbose``/``verbose_eval``; when specified, the training procedure will define
the corresponding callbacks internally.  For example, when ``early_stopping_rounds``
is specified, the ``EarlyStopping`` callback is invoked inside the iteration loop.
You can also pass this callback function directly into XGBoost:

.. code-block:: python

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    # Define a custom evaluation metric used for early stopping.
    def eval_error_metric(predt, dtrain: xgb.DMatrix):
        label = dtrain.get_label()
        r = np.zeros(predt.shape)
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return 'CustomErr', np.sum(r)

    # Specify which dataset and which metric should be used for early stopping.
    early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                            metric_name='CustomErr',
                                            data_name='Train')

    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        feval=eval_error_metric,
        num_boost_round=1000,
        callbacks=[early_stop],
        verbose_eval=False)

    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)

##########################
Defining your own callback
##########################

XGBoost provides a callback interface class: ``xgboost.callback.TrainingCallback``.
User-defined callbacks should inherit this class and override the corresponding
methods.  There's a working example in `demo/guide-python/callbacks.py <https://github.com/dmlc/xgboost/tree/master/demo/guide-python/callbacks.py>`_
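One detail the document leaves implicit, spelled out here from the implementation: the ``evals_log`` handed to ``after_iteration`` is nested as dataset name, then metric name, then a list of per-iteration scores (in ``xgboost.cv`` each entry is a ``(mean, std)`` tuple instead of a float). A sketch of what a callback sees with the ``evals`` list used above:

# evals=[(D_train, 'Train'), (D_valid, 'Valid')] with two metrics yields:
#
# evals_log = {
#     'Train': {'error': [0.35, 0.28, ...], 'rmse': [0.48, 0.44, ...]},
#     'Valid': {'error': [0.37, 0.31, ...], 'rmse': [0.49, 0.46, ...]},
# }
#
# so the latest validation score is:
latest = evals_log['Valid']['rmse'][-1]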
@@ -11,4 +11,5 @@ Contents
 .. toctree::
   python_intro
   python_api
+  callbacks
   Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
python-package/xgboost/callback.py
@@ -1,9 +1,17 @@
 # coding: utf-8
-# pylint: disable=invalid-name, too-many-statements
+# pylint: disable=invalid-name, too-many-statements, no-self-use
+# pylint: disable=too-many-arguments
 """Training Library containing training routines."""
+from abc import ABC
+import collections
+import os
+import pickle
+from typing import Callable, List
+import numpy
 
 from . import rabit
-from .core import EarlyStopException
+from .core import EarlyStopException, CallbackEnv
+from .compat import STRING_TYPES
 
 
 def _get_callback_context(env):
@@ -21,9 +29,9 @@ def _fmt_metric(value, show_stdv=True):
         return '{0}:{1:.5f}'.format(value[0], value[1])
     if len(value) == 3:
         if show_stdv:
             return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
         return '{0}:{1:.5f}'.format(value[0], value[1])
-    raise ValueError("wrong metric value")
+    raise ValueError("wrong metric value", value)
 
 
 def print_evaluation(period=1, show_stdv=True):
@@ -253,3 +261,476 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
            rabit.tracker_print(msg.format(best_msg))
        raise EarlyStopException(best_iteration)
    return callback


# The new implementation of callback functions.
# Breaking:
# - reset learning rate no longer accepts total boosting rounds

# pylint: disable=unused-argument
class TrainingCallback(ABC):
    '''Interface for training callback.

    .. versionadded:: 1.3.0
    '''
    def __init__(self):
        pass

    def before_training(self, model):
        '''Run before training starts.'''

    def after_training(self, model):
        '''Run after training is finished.'''

    def before_iteration(self, model, epoch, evals_log):
        '''Run before each iteration.  Return True when training should stop.'''
        return False

    def after_iteration(self, model, epoch, evals_log):
        '''Run after each iteration.  Return True when training should stop.'''
        return False


def _aggcv(rlist):
    # pylint: disable=invalid-name
    """Aggregate cross-validation results."""
    cvmap = {}
    idx = rlist[0].split()[0]
    for line in rlist:
        arr = line.split()
        assert idx == arr[0]
        for metric_idx, it in enumerate(arr[1:]):
            if not isinstance(it, STRING_TYPES):
                it = it.decode()
            k, v = it.split(':')
            if (metric_idx, k) not in cvmap:
                cvmap[(metric_idx, k)] = []
            cvmap[(metric_idx, k)].append(float(v))
    msg = idx
    results = []
    for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
        v = numpy.array(v)
        if not isinstance(msg, STRING_TYPES):
            msg = msg.decode()
        mean, std = numpy.mean(v), numpy.std(v)
        results.extend([(k, mean, std)])
    return results


def _allreduce_metric(score):
    '''Helper function for computing a customized metric in a distributed
    environment.  Not strictly correct, as many metrics don't use the mean
    value as the final result.
    '''
    world = rabit.get_world_size()
    assert world != 0
    if world == 1:
        return score
    if isinstance(score, tuple):  # has mean and stdv
        raise ValueError(
            'xgboost.cv function should not be used in distributed environment.')
    score = numpy.array([score])
    score = rabit.allreduce(score, rabit.Op.SUM) / world
    return score[0]


class CallbackContainer:
    '''A special callback for invoking a list of other callbacks.

    .. versionadded:: 1.3.0
    '''
    def __init__(self, callbacks: List[TrainingCallback],
                 metric: Callable = None, is_cv: bool = False):
        self.callbacks = set(callbacks)
        if metric is not None:
            msg = 'metric must be callable object for monitoring.  For ' + \
                'builtin metrics, passing them in training parameter' + \
                ' will invoke monitor automatically.'
            assert callable(metric), msg
        self.metric = metric
        self.history = collections.OrderedDict()
        self.is_cv = is_cv

        if self.is_cv:
            self.aggregated_cv = None

    def before_training(self, model):
        '''Function called before training.'''
        for c in self.callbacks:
            c.before_training(model=model)

    def after_training(self, model):
        '''Function called after training.'''
        for c in self.callbacks:
            c.after_training(model)

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Function called before training iteration.'''
        return any(c.before_iteration(model, epoch, self.history)
                   for c in self.callbacks)

    def _update_history(self, score, epoch):
        for d in score:
            name, s = d[0], float(d[1])
            if self.is_cv:
                std = float(d[2])
                s = (s, std)
            splited_names = name.split('-')
            data_name = splited_names[0]
            metric_name = '-'.join(splited_names[1:])
            s = _allreduce_metric(s)
            if data_name in self.history:
                data_history = self.history[data_name]
                if metric_name in data_history:
                    data_history[metric_name].append(s)
                else:
                    data_history[metric_name] = [s]
            else:
                self.history[data_name] = collections.OrderedDict()
                self.history[data_name][metric_name] = [s]
        return False

    def after_iteration(self, model, epoch, dtrain, evals):
        '''Function called after training iteration.'''
        if self.is_cv:
            scores = model.eval(epoch, self.metric)
            scores = _aggcv(scores)
            self.aggregated_cv = scores
            self._update_history(scores, epoch)
        else:
            evals = [] if evals is None else evals
            for _, name in evals:
                assert name.find('-') == -1, 'Dataset name should not contain `-`'
            score = model.eval_set(evals, epoch, self.metric)
            score = score.split()[1:]  # into datasets
            # split up `test-error:0.1234`
            score = [tuple(s.split(':')) for s in score]
            self._update_history(score, epoch)
        ret = any(c.after_iteration(model, epoch, self.history)
                  for c in self.callbacks)
        return ret


class LearningRateScheduler(TrainingCallback):
    '''Callback function for scheduling learning rate.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    learning_rates : callable/collections.Sequence
        If it's a callable object, then it should accept an integer parameter
        `epoch` and return the corresponding learning rate.  Otherwise it
        should be a sequence like a list or tuple with the same length as the
        number of boosting rounds.
    '''
    def __init__(self, learning_rates):
        assert callable(learning_rates) or \
            isinstance(learning_rates, collections.abc.Sequence)
        if callable(learning_rates):
            self.learning_rates = learning_rates
        else:
            self.learning_rates = lambda epoch: learning_rates[epoch]
        super().__init__()

    def after_iteration(self, model, epoch, evals_log):
        model.set_param('learning_rate', self.learning_rates(epoch))


# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
    '''Callback function for early stopping

    .. versionadded:: 1.3.0

    Parameters
    ----------
    rounds : int
        Early stopping rounds.
    metric_name : str
        Name of metric that is used for early stopping.
    data_name : str
        Name of dataset that is used for early stopping.
    maximize : bool
        Whether to maximize evaluation metric.  None means auto (discouraged).
    save_best : bool
        Placeholder, the feature is not yet supported.
    '''
    def __init__(self,
                 rounds,
                 metric_name=None,
                 data_name=None,
                 maximize=None,
                 save_best=False):
        self.data = data_name
        self.metric_name = metric_name
        self.rounds = rounds
        self.save_best = save_best
        # https://github.com/dmlc/xgboost/issues/5531
        assert self.save_best is False, 'save best is not yet supported.'

        self.maximize = maximize
        self.stopping_history = {}

        if self.maximize is not None:
            if self.maximize:
                self.improve_op = lambda x, y: x > y
            else:
                self.improve_op = lambda x, y: x < y

        self.current_rounds = 0
        self.best_scores = {}
        super().__init__()

    def _update_rounds(self, score, name, metric, model, epoch):
        # Just to be compatible with the old behavior before 1.3.  We should
        # let the user decide.
        if self.maximize is None:
            maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
                                'aucpr@', 'map@', 'ndcg@')
            if any(metric.startswith(x) for x in maximize_metrics):
                self.improve_op = lambda x, y: x > y
                self.maximize = True
            else:
                self.improve_op = lambda x, y: x < y
                self.maximize = False

        if not self.stopping_history:  # First round
            self.current_rounds = 0
            self.stopping_history[name] = {}
            self.stopping_history[name][metric] = [score]
            self.best_scores[name] = {}
            self.best_scores[name][metric] = [score]
            model.set_attr(best_score=str(score), best_iteration=str(epoch))
        elif not self.improve_op(score, self.best_scores[name][metric][-1]):
            # Not improved
            self.stopping_history[name][metric].append(score)
            self.current_rounds += 1
        else:  # Improved
            self.stopping_history[name][metric].append(score)
            self.best_scores[name][metric].append(score)
            record = self.stopping_history[name][metric][-1]
            model.set_attr(best_score=str(record), best_iteration=str(epoch))
            self.current_rounds = 0  # reset

        if self.current_rounds >= self.rounds:
            # Should stop
            return True
        return False

    def after_iteration(self, model, epoch, evals_log):
        msg = 'Must have at least 1 validation dataset for early stopping.'
        assert len(evals_log.keys()) >= 1, msg
        data_name = ''
        if self.data:
            for d, _ in evals_log.items():
                if d == self.data:
                    data_name = d
            if not data_name:
                raise ValueError('No dataset named:', self.data)
        else:
            # Use the last one as default.
            data_name = list(evals_log.keys())[-1]
        assert isinstance(data_name, str) and data_name
        data_log = evals_log[data_name]

        # Filter out scores that can not be used for early stopping.
        if self.metric_name:
            metric_name = self.metric_name
        else:
            # Use last metric by default.
            assert isinstance(data_log, collections.OrderedDict)
            metric_name = list(data_log.keys())[-1]
        score = data_log[metric_name][-1]
        return self._update_rounds(score, data_name, metric_name, model, epoch)


class EvaluationMonitor(TrainingCallback):
    '''Print the evaluation result at each iteration.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    metric : callable
        Extra user defined metric.
    rank : int
        Which worker should be used for printing the result.
    show_stdv : bool
        Used in cv to show standard deviation.  Users should not specify it.
    '''
    def __init__(self, rank=0, show_stdv=False):
        self.printer_rank = rank
        self.show_stdv = show_stdv
        super().__init__()

    def _fmt_metric(self, data, metric, score, std):
        if std is not None and self.show_stdv:
            msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
        else:
            msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
        return msg

    def after_iteration(self, model, epoch, evals_log):
        if not evals_log:
            return False
        msg = f'[{epoch}]'
        if rabit.get_rank() == self.printer_rank:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    if isinstance(log[-1], tuple):
                        score = log[-1][0]
                        stdv = log[-1][1]
                    else:
                        score = log[-1]
                        stdv = None
                    msg += self._fmt_metric(data, metric_name, score, stdv)
            msg += '\n'
            rabit.tracker_print(msg)
        return False


class TrainingCheckPoint(TrainingCallback):
    '''Checkpointing operation.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    directory : os.PathLike
        Output model directory.
    name : str
        Pattern of output model file.  Models will be saved as name_0.json,
        name_1.json, name_2.json ....
    as_pickle : boolean
        When set to True, all training parameters will be saved in pickle
        format, instead of saving only the model.
    iterations : int
        Interval of checkpointing.  Checkpointing is slow, so setting a larger
        number can reduce the performance hit.
    '''
    def __init__(self, directory: os.PathLike, name: str = 'model',
                 as_pickle=False, iterations: int = 100):
        self._path = directory
        self._name = name
        self._as_pickle = as_pickle
        self._iterations = iterations
        self._epoch = 0
        super().__init__()

    def after_iteration(self, model, epoch, evals_log):
        if self._epoch == self._iterations:
            path = os.path.join(self._path, self._name + '_' + str(epoch) +
                                ('.pkl' if self._as_pickle else '.json'))
            self._epoch = 0
            if rabit.get_rank() == 0:
                if self._as_pickle:
                    with open(path, 'wb') as fd:
                        pickle.dump(model, fd)
                else:
                    model.save_model(path)
        self._epoch += 1


class LegacyCallbacks:
    '''Adapter for legacy callback functions.

    .. versionadded:: 1.3.0

    Parameters
    ----------

    callbacks : Sequence
        A sequence of legacy callbacks (callbacks that are not instances of
        TrainingCallback)
    start_iteration : int
        Beginning iteration.
    end_iteration : int
        End iteration, normally the number of boosting rounds.
    evals : Sequence
        Sequence of evaluation dataset tuples.
    feval : Custom evaluation metric.
    '''
    def __init__(self, callbacks, start_iteration, end_iteration,
                 feval, cvfolds=None):
        self.callbacks_before_iter = [
            cb for cb in callbacks
            if cb.__dict__.get('before_iteration', False)]
        self.callbacks_after_iter = [
            cb for cb in callbacks
            if not cb.__dict__.get('before_iteration', False)]

        self.start_iteration = start_iteration
        self.end_iteration = end_iteration
        self.cvfolds = cvfolds

        self.feval = feval
        assert self.feval is None or callable(self.feval)

        if cvfolds is not None:
            self.aggregated_cv = None

        super().__init__()

    def before_training(self, model):
        '''Nothing to do for legacy callbacks'''

    def after_training(self, model):
        '''Nothing to do for legacy callbacks'''

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Called before each iteration.'''
        for cb in self.callbacks_before_iter:
            rank = rabit.get_rank()
            cb(CallbackEnv(model=model,
                           cvfolds=self.cvfolds,
                           iteration=epoch,
                           begin_iteration=self.start_iteration,
                           end_iteration=self.end_iteration,
                           rank=rank,
                           evaluation_result_list=None))
        return False

    def after_iteration(self, model, epoch, dtrain, evals):
        '''Called after each iteration.'''
        evaluation_result_list = []
        if self.cvfolds is not None:
            scores = model.eval(epoch, self.feval)
            self.aggregated_cv = _aggcv(scores)
            evaluation_result_list = self.aggregated_cv

        if evals:
            # When cv is used, evals are embedded into folds.
            assert self.cvfolds is None
            bst_eval_set = model.eval_set(evals, epoch, self.feval)
            if isinstance(bst_eval_set, STRING_TYPES):
                msg = bst_eval_set
            else:
                msg = bst_eval_set.decode()
            res = [x.split(':') for x in msg.split()]
            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]

        try:
            for cb in self.callbacks_after_iter:
                rank = rabit.get_rank()
                cb(CallbackEnv(model=model,
                               cvfolds=self.cvfolds,
                               iteration=epoch,
                               begin_iteration=self.start_iteration,
                               end_iteration=self.end_iteration,
                               rank=rank,
                               evaluation_result_list=evaluation_result_list))
        except EarlyStopException:
            return True

        return False
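To make the scheduler above concrete, a small sketch (data and schedule values are arbitrary): ``LearningRateScheduler`` takes either a sequence with one rate per boosting round or a callable mapping the epoch to a rate.

import numpy as np
import xgboost as xgb

X = np.random.randn(128, 8)
y = np.random.randint(0, 2, size=128)
dtrain = xgb.DMatrix(X, y)

# A sequence: one learning rate per boosting round.
steps = xgb.callback.LearningRateScheduler([0.3, 0.2, 0.1, 0.05])
xgb.train({'objective': 'binary:logistic'}, dtrain,
          num_boost_round=4, callbacks=[steps])

# A callable: exponential decay indexed by epoch.
decay = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.9 ** epoch)
xgb.train({'objective': 'binary:logistic'}, dtrain,
          num_boost_round=10, callbacks=[decay])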
python-package/xgboost/dask.py
@@ -627,8 +627,8 @@ async def _get_rabit_args(worker_map, client: Client):
 # evaluation history is instead returned.
 
 
-async def _train_async(client, params, dtrain: DaskDMatrix,
-                       *args, evals=(), **kwargs):
+async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
+                       early_stopping_rounds=None, **kwargs):
     _assert_dask_support()
     client: Client = _xgb_get_client(client)
     if 'evals_result' in kwargs.keys():
@@ -675,6 +675,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix,
                 *args,
                 evals_result=local_history,
                 evals=local_evals,
+                early_stopping_rounds=early_stopping_rounds,
                 **kwargs)
         ret = {'booster': bst, 'history': local_history}
         if local_dtrain.num_row() == 0:
@@ -694,7 +695,8 @@ async def _train_async(client, params, dtrain: DaskDMatrix,
     return list(filter(lambda ret: ret is not None, results))[0]
 
 
-def train(client, params, dtrain, *args, evals=(), **kwargs):
+def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
+          **kwargs):
     '''Train XGBoost model.
 
     .. versionadded:: 1.0.0
@@ -724,8 +726,9 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    return client.sync(_train_async, client, params,
-                       dtrain=dtrain, *args, evals=evals, **kwargs)
+    return client.sync(
+        _train_async, client, params, dtrain=dtrain, *args, evals=evals,
+        early_stopping_rounds=early_stopping_rounds, **kwargs)
 
 
 async def _direct_predict_impl(client, data, predict_fn):
@@ -1005,6 +1008,7 @@ class DaskScikitLearnBase(XGBModel):
             base_margin=None,
             eval_set=None,
             sample_weight_eval_set=None,
+            early_stopping_rounds=None,
             verbose=True):
         '''Fit the regressor.
 
@@ -1045,7 +1049,7 @@ class DaskScikitLearnBase(XGBModel):
         return self.client.sync(_).__await__()
 
     @property
-    def client(self):
+    def client(self) -> Client:
         '''The dask client used in this model.'''
         client = _xgb_get_client(self._client)
         return client
@@ -1059,42 +1063,51 @@ class DaskScikitLearnBase(XGBModel):
     ['estimators', 'model'])
 class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     # pylint: disable=missing-class-docstring
-    async def _fit_async(self,
-                         X,
-                         y,
-                         sample_weights=None,
-                         base_margin=None,
-                         eval_set=None,
-                         sample_weight_eval_set=None,
-                         verbose=True):
-        dtrain = await DaskDMatrix(
-            client=self.client, data=X, label=y, weight=sample_weights,
-            base_margin=base_margin, missing=self.missing
-        )
+    async def _fit_async(self, X, y, sample_weights, base_margin, eval_set,
+                         sample_weight_eval_set, early_stopping_rounds,
+                         verbose):
+        dtrain = await DaskDMatrix(client=self.client,
+                                   data=X,
+                                   label=y,
+                                   weight=sample_weights,
+                                   base_margin=base_margin,
+                                   missing=self.missing)
         params = self.get_xgb_params()
-        evals = await _evaluation_matrices(self.client,
-                                           eval_set, sample_weight_eval_set,
-                                           self.missing)
-        results = await train(client=self.client, params=params, dtrain=dtrain,
+        evals = await _evaluation_matrices(self.client, eval_set,
+                                           sample_weight_eval_set,
+                                           self.missing)
+        results = await train(client=self.client,
+                              params=params,
+                              dtrain=dtrain,
                               num_boost_round=self.get_num_boosting_rounds(),
-                              evals=evals, verbose_eval=verbose)
+                              evals=evals,
+                              verbose_eval=verbose,
+                              early_stopping_rounds=early_stopping_rounds)
         self._Booster = results['booster']
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
 
     # pylint: disable=missing-docstring
-    def fit(self, X, y,
+    def fit(self,
+            X,
+            y,
             sample_weights=None,
             base_margin=None,
             eval_set=None,
             sample_weight_eval_set=None,
+            early_stopping_rounds=None,
             verbose=True):
         _assert_dask_support()
-        return self.client.sync(
-            self._fit_async, X, y, sample_weights, base_margin,
-            eval_set, sample_weight_eval_set, verbose
-        )
+        return self.client.sync(self._fit_async,
+                                X=X,
+                                y=y,
+                                sample_weights=sample_weights,
+                                base_margin=base_margin,
+                                eval_set=eval_set,
+                                sample_weight_eval_set=sample_weight_eval_set,
+                                early_stopping_rounds=early_stopping_rounds,
+                                verbose=verbose)
 
     async def _predict_async(
             self, data, output_margin=False, base_margin=None):
@@ -1114,20 +1127,17 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
             output_margin=output_margin,
             base_margin=base_margin)
 
 
 @xgboost_model_doc(
     'Implementation of the scikit-learn API for XGBoost classification.',
-    ['estimators', 'model']
-)
+    ['estimators', 'model'])
 class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
-    async def _fit_async(self, X, y,
-                         sample_weights=None,
-                         base_margin=None,
-                         eval_set=None,
-                         sample_weight_eval_set=None,
-                         verbose=True):
+    async def _fit_async(self, X, y, sample_weights, base_margin, eval_set,
+                         sample_weight_eval_set, early_stopping_rounds,
+                         verbose):
         dtrain = await DaskDMatrix(client=self.client,
-                                   data=X, label=y, weight=sample_weights,
+                                   data=X,
+                                   label=y,
+                                   weight=sample_weights,
                                    base_margin=base_margin,
                                    missing=self.missing)
         params = self.get_xgb_params()
@@ -1145,28 +1155,40 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         else:
             params["objective"] = "binary:logistic"
 
-        evals = await _evaluation_matrices(self.client,
-                                           eval_set, sample_weight_eval_set,
+        evals = await _evaluation_matrices(self.client, eval_set,
+                                           sample_weight_eval_set,
                                            self.missing)
-        results = await train(client=self.client, params=params, dtrain=dtrain,
+        results = await train(client=self.client,
+                              params=params,
+                              dtrain=dtrain,
                               num_boost_round=self.get_num_boosting_rounds(),
-                              evals=evals, verbose_eval=verbose)
+                              evals=evals,
+                              early_stopping_rounds=early_stopping_rounds,
+                              verbose_eval=verbose)
         self._Booster = results['booster']
         # pylint: disable=attribute-defined-outside-init
         self.evals_result_ = results['history']
         return self
 
-    def fit(self, X, y,
+    def fit(self,
+            X,
+            y,
             sample_weights=None,
             base_margin=None,
             eval_set=None,
             sample_weight_eval_set=None,
+            early_stopping_rounds=None,
             verbose=True):
         _assert_dask_support()
-        return self.client.sync(
-            self._fit_async, X, y, sample_weights, base_margin, eval_set,
-            sample_weight_eval_set, verbose
-        )
+        return self.client.sync(self._fit_async,
+                                X=X,
+                                y=y,
+                                sample_weights=sample_weights,
+                                base_margin=base_margin,
+                                eval_set=eval_set,
+                                sample_weight_eval_set=sample_weight_eval_set,
+                                early_stopping_rounds=early_stopping_rounds,
+                                verbose=verbose)
 
     async def _predict_proba_async(self, data, output_margin=False,
                                    base_margin=None):
|
|||||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||||
# pylint: disable=too-many-branches, too-many-statements
|
# pylint: disable=too-many-branches, too-many-statements
|
||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
|
from .core import Booster, XGBoostError
|
||||||
from .core import EarlyStopException
|
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||||
from . import rabit
|
from . import rabit
|
||||||
from . import callback
|
from . import callback
|
||||||
|
|
||||||
|
|
||||||
def _train_internal(params, dtrain,
|
def _configure_deprecated_callbacks(
|
||||||
num_boost_round=10, evals=(),
|
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||||
obj=None, feval=None,
|
num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds):
|
||||||
xgb_model=None, callbacks=None):
|
link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html'
|
||||||
"""internal training function"""
|
warnings.warn(f'Old style callback is deprecated. See: {link}', DeprecationWarning)
|
||||||
callbacks = [] if callbacks is None else callbacks
|
# Most of legacy advanced options becomes callbacks
|
||||||
evals = list(evals)
|
if early_stopping_rounds is not None:
|
||||||
params = params.copy()
|
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||||
if isinstance(params, dict) \
|
maximize=maximize,
|
||||||
and 'eval_metric' in params \
|
verbose=bool(verbose_eval)))
|
||||||
and isinstance(params['eval_metric'], list):
|
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||||
|
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||||
|
else:
|
||||||
|
if isinstance(verbose_eval, int):
|
||||||
|
callbacks.append(callback.print_evaluation(verbose_eval,
|
||||||
|
show_stdv=show_stdv))
|
||||||
|
if evals_result is not None:
|
||||||
|
callbacks.append(callback.record_evaluation(evals_result))
|
||||||
|
callbacks = callback.LegacyCallbacks(
|
||||||
|
callbacks, start_iteration, num_boost_round, feval, cvfolds=cvfolds)
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def _is_new_callback(callbacks):
|
||||||
|
return any(isinstance(c, callback.TrainingCallback)
|
||||||
|
for c in callbacks) or not callbacks
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_metrics(params):
|
||||||
|
if isinstance(params, dict) and 'eval_metric' in params \
|
||||||
|
and isinstance(params['eval_metric'], list):
|
||||||
params = dict((k, v) for k, v in params.items())
|
params = dict((k, v) for k, v in params.items())
|
||||||
eval_metrics = params['eval_metric']
|
eval_metrics = params['eval_metric']
|
||||||
params.pop("eval_metric", None)
|
params.pop("eval_metric", None)
|
||||||
params = list(params.items())
|
params = list(params.items())
|
||||||
for eval_metric in eval_metrics:
|
for eval_metric in eval_metrics:
|
||||||
params += [('eval_metric', eval_metric)]
|
params += [('eval_metric', eval_metric)]
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _train_internal(params, dtrain,
|
||||||
|
num_boost_round=10, evals=(),
|
||||||
|
obj=None, feval=None,
|
||||||
|
xgb_model=None, callbacks=None,
|
||||||
|
evals_result=None, maximize=None,
|
||||||
|
verbose_eval=None, early_stopping_rounds=None):
|
||||||
|
"""internal training function"""
|
||||||
|
callbacks = [] if callbacks is None else callbacks
|
||||||
|
evals = list(evals)
|
||||||
|
params = _configure_metrics(params.copy())
|
||||||
|
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
nboost = 0
|
nboost = 0
|
||||||
@ -49,26 +82,29 @@ def _train_internal(params, dtrain,
|
|||||||
# Distributed code: Load the checkpoint from rabit.
|
# Distributed code: Load the checkpoint from rabit.
|
||||||
version = bst.load_rabit_checkpoint()
|
version = bst.load_rabit_checkpoint()
|
||||||
assert rabit.get_world_size() != 1 or version == 0
|
assert rabit.get_world_size() != 1 or version == 0
|
||||||
rank = rabit.get_rank()
|
|
||||||
start_iteration = int(version / 2)
|
start_iteration = int(version / 2)
|
||||||
nboost += start_iteration
|
nboost += start_iteration
|
||||||
|
|
||||||
callbacks_before_iter = [
|
is_new_callback = _is_new_callback(callbacks)
|
||||||
cb for cb in callbacks
|
if is_new_callback:
|
||||||
if cb.__dict__.get('before_iteration', False)]
|
assert all(isinstance(c, callback.TrainingCallback)
|
||||||
callbacks_after_iter = [
|
for c in callbacks), "You can't mix new and old callback styles."
|
||||||
cb for cb in callbacks
|
if verbose_eval:
|
||||||
if not cb.__dict__.get('before_iteration', False)]
|
callbacks.append(callback.EvaluationMonitor())
|
||||||
|
if early_stopping_rounds:
|
||||||
|
callbacks.append(callback.EarlyStopping(
|
||||||
|
rounds=early_stopping_rounds, maximize=maximize))
|
||||||
|
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
||||||
|
else:
|
||||||
|
callbacks = _configure_deprecated_callbacks(
|
||||||
|
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||||
|
num_boost_round, feval, evals_result, callbacks,
|
||||||
|
show_stdv=False, cvfolds=None)
|
||||||
|
|
||||||
|
callbacks.before_training(bst)
|
||||||
for i in range(start_iteration, num_boost_round):
|
for i in range(start_iteration, num_boost_round):
|
||||||
for cb in callbacks_before_iter:
|
if callbacks.before_iteration(bst, i, dtrain, evals):
|
||||||
cb(CallbackEnv(model=bst,
|
break
|
||||||
cvfolds=None,
|
|
||||||
iteration=i,
|
|
||||||
begin_iteration=start_iteration,
|
|
||||||
end_iteration=num_boost_round,
|
|
||||||
rank=rank,
|
|
||||||
evaluation_result_list=None))
|
|
||||||
# Distributed code: need to resume to this point.
|
# Distributed code: need to resume to this point.
|
||||||
# Skip the first update if it is a recovery step.
|
# Skip the first update if it is a recovery step.
|
||||||
if version % 2 == 0:
|
if version % 2 == 0:
|
||||||
@ -79,44 +115,32 @@ def _train_internal(params, dtrain,
|
|||||||
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
||||||
|
|
||||||
nboost += 1
|
nboost += 1
|
||||||
evaluation_result_list = []
|
|
||||||
# check evaluation result.
|
# check evaluation result.
|
||||||
if evals:
|
if callbacks.after_iteration(bst, i, dtrain, evals):
|
||||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
|
||||||
if isinstance(bst_eval_set, STRING_TYPES):
|
|
||||||
msg = bst_eval_set
|
|
||||||
else:
|
|
||||||
msg = bst_eval_set.decode()
|
|
||||||
res = [x.split(':') for x in msg.split()]
|
|
||||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
|
||||||
try:
|
|
||||||
for cb in callbacks_after_iter:
|
|
||||||
cb(CallbackEnv(model=bst,
|
|
||||||
cvfolds=None,
|
|
||||||
iteration=i,
|
|
||||||
begin_iteration=start_iteration,
|
|
||||||
end_iteration=num_boost_round,
|
|
||||||
rank=rank,
|
|
||||||
evaluation_result_list=evaluation_result_list))
|
|
||||||
except EarlyStopException:
|
|
||||||
break
|
break
|
||||||
# do checkpoint after evaluation, in case evaluation also updates booster.
|
# do checkpoint after evaluation, in case evaluation also updates
|
||||||
|
# booster.
|
||||||
bst.save_rabit_checkpoint()
|
bst.save_rabit_checkpoint()
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
|
callbacks.after_training(bst)
|
||||||
|
|
||||||
|
if evals_result is not None and is_new_callback:
|
||||||
|
evals_result.update(callbacks.history)
|
||||||
|
|
||||||
if bst.attr('best_score') is not None:
|
if bst.attr('best_score') is not None:
|
||||||
bst.best_score = float(bst.attr('best_score'))
|
bst.best_score = float(bst.attr('best_score'))
|
||||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||||
else:
|
else:
|
||||||
bst.best_iteration = nboost - 1
|
bst.best_iteration = nboost - 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||||
|
# Copy to serialise and unserialise booster to reset state and free
|
||||||
# Copy to serialise and unserialise booster to reset state and free training memory
|
# training memory
|
||||||
return bst.copy()
|
return bst.copy()
|
||||||
|
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
maximize=None, early_stopping_rounds=None, evals_result=None,
|
||||||
verbose_eval=True, xgb_model=None, callbacks=None):
|
verbose_eval=True, xgb_model=None, callbacks=None):
|
||||||
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
||||||
"""Train a booster with given parameters.
|
"""Train a booster with given parameters.
|
||||||
@ -189,27 +213,16 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
-------
|
-------
|
||||||
Booster : a trained booster model
|
Booster : a trained booster model
|
||||||
"""
|
"""
|
||||||
callbacks = [] if callbacks is None else callbacks
|
bst = _train_internal(params, dtrain,
|
||||||
|
num_boost_round=num_boost_round,
|
||||||
# Most of legacy advanced options becomes callbacks
|
evals=evals,
|
||||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
obj=obj, feval=feval,
|
||||||
callbacks.append(callback.print_evaluation())
|
xgb_model=xgb_model, callbacks=callbacks,
|
||||||
else:
|
verbose_eval=verbose_eval,
|
||||||
if isinstance(verbose_eval, int):
|
evals_result=evals_result,
|
||||||
callbacks.append(callback.print_evaluation(verbose_eval))
|
maximize=maximize,
|
||||||
|
early_stopping_rounds=early_stopping_rounds)
|
||||||
if early_stopping_rounds is not None:
|
return bst
|
||||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
|
||||||
maximize=maximize,
|
|
||||||
verbose=bool(verbose_eval)))
|
|
||||||
if evals_result is not None:
|
|
||||||
callbacks.append(callback.record_evaluation(evals_result))
|
|
||||||
|
|
||||||
return _train_internal(params, dtrain,
|
|
||||||
num_boost_round=num_boost_round,
|
|
||||||
evals=evals,
|
|
||||||
obj=obj, feval=feval,
|
|
||||||
xgb_model=xgb_model, callbacks=callbacks)
|
|
||||||
|
|
||||||
|
|
||||||
class CVPack(object):
|
class CVPack(object):
|
||||||
@ -230,6 +243,36 @@ class CVPack(object):
|
|||||||
return self.bst.eval_set(self.watchlist, iteration, feval)
|
return self.bst.eval_set(self.watchlist, iteration, feval)
|
||||||
|
|
||||||
|
|
||||||
|
class _PackedBooster:
|
||||||
|
def __init__(self, cvfolds):
|
||||||
|
self.cvfolds = cvfolds
|
||||||
|
|
||||||
|
def update(self, iteration, obj):
|
||||||
|
'''Iterate through folds for update'''
|
||||||
|
for fold in self.cvfolds:
|
||||||
|
fold.update(iteration, obj)
|
||||||
|
|
||||||
|
def eval(self, iteration, feval):
|
||||||
|
'''Iterate through folds for eval'''
|
||||||
|
result = [f.eval(iteration, feval) for f in self.cvfolds]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def set_attr(self, **kwargs):
|
||||||
|
'''Iterate through folds for setting attributes'''
|
||||||
|
for f in self.cvfolds:
|
||||||
|
f.bst.set_attr(**kwargs)
|
||||||
|
|
||||||
|
def attr(self, key):
|
||||||
|
'''Redirect to booster attr.'''
|
||||||
|
return self.cvfolds[0].bst.attr(key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def best_iteration(self):
|
||||||
|
'''Get best_iteration'''
|
||||||
|
ret = self.cvfolds[0].bst.attr('best_iteration')
|
||||||
|
return int(ret)
|
||||||
|
|
||||||
|
|
||||||
def groups_to_rows(groups, boundaries):
|
def groups_to_rows(groups, boundaries):
|
||||||
"""
|
"""
|
||||||
Given group row boundaries, convert ground indexes to row indexes
|
Given group row boundaries, convert ground indexes to row indexes
|
||||||
@ -334,40 +377,8 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def aggcv(rlist):
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
"""
|
|
||||||
Aggregate cross-validation results.
|
|
||||||
|
|
||||||
If verbose_eval is true, progress is displayed in every call. If
|
|
||||||
verbose_eval is an integer, progress will only be displayed every
|
|
||||||
`verbose_eval` trees, tracked via trial.
|
|
||||||
"""
|
|
||||||
cvmap = {}
|
|
||||||
idx = rlist[0].split()[0]
|
|
||||||
for line in rlist:
|
|
||||||
arr = line.split()
|
|
||||||
assert idx == arr[0]
|
|
||||||
for metric_idx, it in enumerate(arr[1:]):
|
|
||||||
if not isinstance(it, STRING_TYPES):
|
|
||||||
it = it.decode()
|
|
||||||
k, v = it.split(':')
|
|
||||||
if (metric_idx, k) not in cvmap:
|
|
||||||
cvmap[(metric_idx, k)] = []
|
|
||||||
cvmap[(metric_idx, k)].append(float(v))
|
|
||||||
msg = idx
|
|
||||||
results = []
|
|
||||||
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
|
||||||
v = np.array(v)
|
|
||||||
if not isinstance(msg, STRING_TYPES):
|
|
||||||
msg = msg.decode()
|
|
||||||
mean, std = np.mean(v), np.std(v)
|
|
||||||
results.extend([(k, mean, std)])
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
||||||
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
metrics=(), obj=None, feval=None, maximize=None, early_stopping_rounds=None,
|
||||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
||||||
       seed=0, callbacks=None, shuffle=True):
    # pylint: disable = invalid-name
@@ -467,37 +478,32 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
     # setup callbacks
     callbacks = [] if callbacks is None else callbacks
-    if early_stopping_rounds is not None:
-        callbacks.append(callback.early_stop(early_stopping_rounds,
-                                             maximize=maximize,
-                                             verbose=False))
-
-    if isinstance(verbose_eval, bool) and verbose_eval:
-        callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
+    is_new_callback = _is_new_callback(callbacks)
+    if is_new_callback:
+        assert all(isinstance(c, callback.TrainingCallback)
+                   for c in callbacks), "You can't mix new and old callback styles."
+        if isinstance(verbose_eval, bool) and verbose_eval:
+            callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
+        if early_stopping_rounds:
+            callbacks.append(callback.EarlyStopping(
+                rounds=early_stopping_rounds, maximize=maximize))
+        callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
     else:
-        if isinstance(verbose_eval, int):
-            callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
-
-    callbacks_before_iter = [
-        cb for cb in callbacks if
-        cb.__dict__.get('before_iteration', False)]
-    callbacks_after_iter = [
-        cb for cb in callbacks if
-        not cb.__dict__.get('before_iteration', False)]
+        callbacks = _configure_deprecated_callbacks(
+            verbose_eval, early_stopping_rounds, maximize, 0,
+            num_boost_round, feval, None, callbacks,
+            show_stdv=show_stdv, cvfolds=cvfolds)
+    callbacks.before_training(cvfolds)
 
+    booster = _PackedBooster(cvfolds)
 
     for i in range(num_boost_round):
-        for cb in callbacks_before_iter:
-            cb(CallbackEnv(model=None,
-                           cvfolds=cvfolds,
-                           iteration=i,
-                           begin_iteration=0,
-                           end_iteration=num_boost_round,
-                           rank=0,
-                           evaluation_result_list=None))
-        for fold in cvfolds:
-            fold.update(i, obj)
-        res = aggcv([f.eval(i, feval) for f in cvfolds])
+        if callbacks.before_iteration(booster, i, dtrain, None):
+            break
+        booster.update(i, obj)
 
+        should_break = callbacks.after_iteration(booster, i, dtrain, None)
+        res = callbacks.aggregated_cv
         for key, mean, std in res:
             if key + '-mean' not in results:
                 results[key + '-mean'] = []
@@ -505,18 +511,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
                 results[key + '-std'] = []
             results[key + '-mean'].append(mean)
             results[key + '-std'].append(std)
-        try:
-            for cb in callbacks_after_iter:
-                cb(CallbackEnv(model=None,
-                               cvfolds=cvfolds,
-                               iteration=i,
-                               begin_iteration=0,
-                               end_iteration=num_boost_round,
-                               rank=0,
-                               evaluation_result_list=res))
-        except EarlyStopException as e:
+        if should_break:
             for k in results:
-                results[k] = results[k][:(e.best_iteration + 1)]
+                results[k] = results[k][:(booster.best_iteration + 1)]
             break
     if as_pandas:
         try:
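In short, `cv` now funnels every fold through one `CallbackContainer` and wraps the folds in a `_PackedBooster` so callbacks see a single booster-like object; aggregated cross-validation results are read from `callbacks.aggregated_cv` instead of being threaded through by hand. A minimal sketch of the new-style path from user code, on synthetic data (the `Monitor` class is illustrative, not part of this commit):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.randn(256, 8)
y = (X[:, 0] > 0).astype(np.float64)
dtrain = xgb.DMatrix(X, y)

class Monitor(xgb.callback.TrainingCallback):
    '''Illustrative new-style callback; its presence makes _is_new_callback()
    select the CallbackContainer path shown above.'''
    def after_iteration(self, model, epoch, evals_log):
        # evals_log maps data name -> metric name -> per-round history.
        return False  # False: keep training.

xgb.cv({'objective': 'binary:logistic'}, dtrain,
       num_boost_round=10, nfold=3, callbacks=[Monitor()])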
tests/python-gpu/test_gpu_basic_models.py
@@ -5,12 +5,12 @@ import numpy as np
 import xgboost as xgb
 sys.path.append("tests/python")
 # Don't import the test class, otherwise they will run twice.
-import test_basic_models as test_bm  # noqa
+import test_callback as test_cb  # noqa
 rng = np.random.RandomState(1994)
 
 
 class TestGPUBasicModels(unittest.TestCase):
-    cputest = test_bm.TestModels()
+    cputest = test_cb.TestCallbacks()
 
     def run_cls(self, X, y, deterministic):
         cls = xgb.XGBClassifier(tree_method='gpu_hist',
@@ -36,7 +36,8 @@ class TestGPUBasicModels(unittest.TestCase):
         return hash(model_0), hash(model_1)
 
     def test_eta_decay_gpu_hist(self):
-        self.cputest.run_eta_decay('gpu_hist')
+        self.cputest.run_eta_decay('gpu_hist', True)
+        self.cputest.run_eta_decay('gpu_hist', False)
 
     def test_deterministic_gpu_hist(self):
         kRows = 1000
tests/python/test_basic_models.py
@@ -8,7 +8,7 @@ import pytest
 import locale
 import tempfile
 
-dpath = 'demo/data/'
+dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
 dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
 dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
 
@@ -110,84 +110,6 @@ class TestModels(unittest.TestCase):
             for jj in range(ii + 1, len(preds_list)):
                 assert np.sum(np.abs(preds_list[ii] - preds_list[jj])) > 0
 
-    def run_eta_decay(self, tree_method):
-        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
-        num_round = 4
-
-        # learning_rates as a list
-        # init eta with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
-                 'objective': 'binary:logistic', 'eval_metric': 'error',
-                 'tree_method': tree_method}
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate([
-                            0.8, 0.7, 0.6, 0.5
-                        ])],
-                        evals_result=evals_result)
-        eval_errors_0 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should decrease, if eta > 0
-        assert eval_errors_0[0] > eval_errors_0[-1]
-
-        # init learning_rate with 0 to check whether learning_rates work
-        param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
-                 'objective': 'binary:logistic', 'eval_metric': 'error',
-                 'tree_method': tree_method}
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate(
-                            [0.8, 0.7, 0.6, 0.5])],
-                        evals_result=evals_result)
-        eval_errors_1 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should decrease, if learning_rate > 0
-        assert eval_errors_1[0] > eval_errors_1[-1]
-
-        # check if learning_rates override default value of eta/learning_rate
-        param = {
-            'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic',
-            'eval_metric': 'error', 'tree_method': tree_method
-        }
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[xgb.callback.reset_learning_rate(
-                            [0, 0, 0, 0]
-                        )],
-                        evals_result=evals_result)
-        eval_errors_2 = list(map(float, evals_result['eval']['error']))
-        assert isinstance(bst, xgb.core.Booster)
-        # validation error should not decrease, if eta/learning_rate = 0
-        assert eval_errors_2[0] == eval_errors_2[-1]
-
-        # learning_rates as a customized decay function
-        def eta_decay(ithround, num_boost_round):
-            return num_boost_round / (ithround + 1)
-
-        evals_result = {}
-        bst = xgb.train(param, dtrain, num_round, watchlist,
-                        callbacks=[
-                            xgb.callback.reset_learning_rate(eta_decay)
-                        ],
-                        evals_result=evals_result)
-        eval_errors_3 = list(map(float, evals_result['eval']['error']))
-
-        assert isinstance(bst, xgb.core.Booster)
-
-        assert eval_errors_3[0] == eval_errors_2[0]
-
-        for i in range(1, len(eval_errors_0)):
-            assert eval_errors_3[i] != eval_errors_2[i]
-
-    def test_eta_decay_hist(self):
-        self.run_eta_decay('hist')
-
-    def test_eta_decay_approx(self):
-        self.run_eta_decay('approx')
-
-    def test_eta_decay_exact(self):
-        self.run_eta_decay('exact')
-
     def test_boost_from_prediction(self):
         # Re-construct dtrain here to avoid modification
         margined = xgb.DMatrix(dpath + 'agaricus.txt.train')
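The eta-decay cases removed here reappear in the new tests/python/test_callback.py below, parameterized over both the deprecated `reset_learning_rate` and the `LearningRateScheduler` that replaces it. For orientation, a minimal sketch of the new-style equivalent on synthetic data (data and rates here are illustrative, not from the test suite):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(1994)
X = rng.randn(128, 4)
y = (X[:, 0] > 0).astype(np.float64)
dtrain = xgb.DMatrix(X, y)

# One learning rate per boosting round; a callable f(epoch) -> float also works.
scheduler = xgb.callback.LearningRateScheduler([0.8, 0.7, 0.6, 0.5])
xgb.train({'objective': 'binary:logistic'}, dtrain,
          num_boost_round=4, callbacks=[scheduler])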
234
tests/python/test_callback.py
Normal file
@@ -0,0 +1,234 @@
import xgboost as xgb
import unittest
import pytest
import os
import testing as tm
import tempfile

# We use the dataset for tests.
pytestmark = pytest.mark.skipif(**tm.no_sklearn())


class TestCallbacks(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
        cls.X = X
        cls.y = y

        split = int(X.shape[0]*0.8)
        cls.X_train = X[: split, ...]
        cls.y_train = y[: split, ...]
        cls.X_valid = X[split:, ...]
        cls.y_valid = y[split:, ...]

    def test_evaluation_monitor(self):
        D_train = xgb.DMatrix(self.X_train, self.y_train)
        D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
        evals_result = {}
        rounds = 10
        xgb.train({'objective': 'binary:logistic',
                   'eval_metric': 'error'}, D_train,
                  evals=[(D_train, 'Train'), (D_valid, 'Valid')],
                  num_boost_round=rounds,
                  evals_result=evals_result,
                  verbose_eval=True)
        print('evals_result:', evals_result)
        assert len(evals_result['Train']['error']) == rounds
        assert len(evals_result['Valid']['error']) == rounds

    def test_early_stopping(self):
        D_train = xgb.DMatrix(self.X_train, self.y_train)
        D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
        evals_result = {}
        rounds = 30
        early_stopping_rounds = 5
        booster = xgb.train({'objective': 'binary:logistic',
                             'eval_metric': 'error'}, D_train,
                            evals=[(D_train, 'Train'), (D_valid, 'Valid')],
                            num_boost_round=rounds,
                            evals_result=evals_result,
                            verbose_eval=True,
                            early_stopping_rounds=early_stopping_rounds)
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

    def test_early_stopping_custom_eval(self):
        D_train = xgb.DMatrix(self.X_train, self.y_train)
        D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
        early_stopping_rounds = 5
        booster = xgb.train({'objective': 'binary:logistic',
                             'eval_metric': 'error',
                             'tree_method': 'hist'}, D_train,
                            evals=[(D_train, 'Train'), (D_valid, 'Valid')],
                            feval=tm.eval_error_metric,
                            num_boost_round=1000,
                            early_stopping_rounds=early_stopping_rounds,
                            verbose_eval=False)
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

    def test_early_stopping_customize(self):
        D_train = xgb.DMatrix(self.X_train, self.y_train)
        D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
        early_stopping_rounds = 5
        early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                                metric_name='CustomErr',
                                                data_name='Train')
        # Specify which dataset and which metric should be used for early stopping.
        booster = xgb.train(
            {'objective': 'binary:logistic',
             'eval_metric': ['error', 'rmse'],
             'tree_method': 'hist'}, D_train,
            evals=[(D_train, 'Train'), (D_valid, 'Valid')],
            feval=tm.eval_error_metric,
            num_boost_round=1000,
            callbacks=[early_stop],
            verbose_eval=False)
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
        assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)

    def test_early_stopping_skl(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
        cls = xgb.XGBClassifier()
        early_stopping_rounds = 5
        cls.fit(X, y, eval_set=[(X, y)],
                early_stopping_rounds=early_stopping_rounds, eval_metric='error')
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

    def test_early_stopping_custom_eval_skl(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
        cls = xgb.XGBClassifier()
        early_stopping_rounds = 5
        cls.fit(X, y, eval_set=[(X, y)],
                early_stopping_rounds=early_stopping_rounds,
                eval_metric=tm.eval_error_metric)
        booster = cls.get_booster()
        dump = booster.get_dump(dump_format='json')
        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1

    def run_eta_decay(self, tree_method, deprecated_callback):
        if deprecated_callback:
            scheduler = xgb.callback.reset_learning_rate
        else:
            scheduler = xgb.callback.LearningRateScheduler

        dpath = os.path.join(tm.PROJECT_ROOT, 'demo/data/')
        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
        num_round = 4

        # learning_rates as a list
        # init eta with 0 to check whether learning_rates work
        param = {'max_depth': 2, 'eta': 0, 'verbosity': 0,
                 'objective': 'binary:logistic', 'eval_metric': 'error',
                 'tree_method': tree_method}
        evals_result = {}
        bst = xgb.train(param, dtrain, num_round, watchlist,
                        callbacks=[scheduler([
                            0.8, 0.7, 0.6, 0.5
                        ])],
                        evals_result=evals_result)
        eval_errors_0 = list(map(float, evals_result['eval']['error']))
        assert isinstance(bst, xgb.core.Booster)
        # validation error should decrease, if eta > 0
        assert eval_errors_0[0] > eval_errors_0[-1]

        # init learning_rate with 0 to check whether learning_rates work
        param = {'max_depth': 2, 'learning_rate': 0, 'verbosity': 0,
                 'objective': 'binary:logistic', 'eval_metric': 'error',
                 'tree_method': tree_method}
        evals_result = {}
        bst = xgb.train(param, dtrain, num_round, watchlist,
                        callbacks=[scheduler(
                            [0.8, 0.7, 0.6, 0.5])],
                        evals_result=evals_result)
        eval_errors_1 = list(map(float, evals_result['eval']['error']))
        assert isinstance(bst, xgb.core.Booster)
        # validation error should decrease, if learning_rate > 0
        assert eval_errors_1[0] > eval_errors_1[-1]

        # check if learning_rates override default value of eta/learning_rate
        param = {
            'max_depth': 2, 'verbosity': 0, 'objective': 'binary:logistic',
            'eval_metric': 'error', 'tree_method': tree_method
        }
        evals_result = {}
        bst = xgb.train(param, dtrain, num_round, watchlist,
                        callbacks=[scheduler(
                            [0, 0, 0, 0]
                        )],
                        evals_result=evals_result)
        eval_errors_2 = list(map(float, evals_result['eval']['error']))
        assert isinstance(bst, xgb.core.Booster)
        # validation error should not decrease, if eta/learning_rate = 0
        assert eval_errors_2[0] == eval_errors_2[-1]

        # learning_rates as a customized decay function
        def eta_decay(ithround, num_boost_round=num_round):
            return num_boost_round / (ithround + 1)

        evals_result = {}
        bst = xgb.train(param, dtrain, num_round, watchlist,
                        callbacks=[
                            scheduler(eta_decay)
                        ],
                        evals_result=evals_result)
        eval_errors_3 = list(map(float, evals_result['eval']['error']))

        assert isinstance(bst, xgb.core.Booster)

        assert eval_errors_3[0] == eval_errors_2[0]

        for i in range(1, len(eval_errors_0)):
            assert eval_errors_3[i] != eval_errors_2[i]

    def test_eta_decay_hist(self):
        with pytest.deprecated_call():
            self.run_eta_decay('hist', True)
        self.run_eta_decay('hist', False)

    def test_eta_decay_approx(self):
        with pytest.deprecated_call():
            self.run_eta_decay('approx', True)
        self.run_eta_decay('approx', False)

    def test_eta_decay_exact(self):
        with pytest.deprecated_call():
            self.run_eta_decay('exact', True)
        self.run_eta_decay('exact', False)

    def test_check_point(self):
        from sklearn.datasets import load_breast_cancer
        X, y = load_breast_cancer(return_X_y=True)
        m = xgb.DMatrix(X, y)
        with tempfile.TemporaryDirectory() as tmpdir:
            check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                          iterations=1,
                                                          name='model')
            xgb.train({'objective': 'binary:logistic'}, m,
                      num_boost_round=10,
                      verbose_eval=False,
                      callbacks=[check_point])
            for i in range(1, 10):
                assert os.path.exists(
                    os.path.join(tmpdir, 'model_' + str(i) + '.json'))

            check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
                                                          iterations=1,
                                                          as_pickle=True,
                                                          name='model')
            xgb.train({'objective': 'binary:logistic'}, m,
                      num_boost_round=10,
                      verbose_eval=False,
                      callbacks=[check_point])
            for i in range(1, 10):
                assert os.path.exists(
                    os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))
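The new test file exercises the built-in callbacks added by this PR (`EvaluationMonitor`, `EarlyStopping`, `LearningRateScheduler`, `TrainingCheckPoint`). A user-defined callback only needs to subclass `xgb.callback.TrainingCallback` and return True from a hook to stop training; a minimal sketch (the `StopAtThreshold` class and its threshold are illustrative, not part of this commit):

import numpy as np
import xgboost as xgb

class StopAtThreshold(xgb.callback.TrainingCallback):
    '''Illustrative: stop once training error drops below a threshold.'''
    def __init__(self, threshold):
        self.threshold = threshold

    def after_iteration(self, model, epoch, evals_log):
        latest = evals_log['Train']['error'][-1]
        return latest < self.threshold  # True asks the container to stop.

rng = np.random.RandomState(0)
X = rng.randn(512, 10)
y = (X[:, 0] + X[:, 1] > 0).astype(np.float64)
dtrain = xgb.DMatrix(X, y)
xgb.train({'objective': 'binary:logistic', 'eval_metric': 'error'},
          dtrain, evals=[(dtrain, 'Train')],
          num_boost_round=100, verbose_eval=False,
          callbacks=[StopAtThreshold(0.05)])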
tests/python/test_demos.py
@@ -119,6 +119,12 @@ def test_aft_demo():
     os.remove('aft_model.json')
 
 
+def test_callbacks_demo():
+    script = os.path.join(PYTHON_DEMO_DIR, 'callbacks.py')
+    cmd = ['python', script, '--plot=0']
+    subprocess.check_call(cmd)
+
+
 # gpu_acceleration is not tested due to covertype dataset is being too huge.
 # gamma regression is not tested as it requires running a R script first.
 # aft viz is not tested due to ploting is not controled
tests/python/test_with_dask.py
@@ -328,7 +328,7 @@ def test_sklearn_grid_search():
     reg.client = client
     model = GridSearchCV(reg, {'max_depth': [2, 4],
                                'n_estimators': [5, 10]},
-                         cv=2, verbose=1, iid=True)
+                         cv=2, verbose=1)
     model.fit(X, y)
     # Expect unique results for each parameter value This confirms
     # sklearn is able to successfully update the parameter
@@ -705,3 +705,42 @@ class TestWithDask:
     @pytest.mark.gtest
     def test_quantile_same_on_all_workers(self):
         self.run_quantile('SameOnAllWorkers')
+
+
+class TestDaskCallbacks:
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_early_stopping(self, client):
+        from sklearn.datasets import load_breast_cancer
+        X, y = load_breast_cancer(return_X_y=True)
+        X, y = da.from_array(X), da.from_array(y)
+        m = xgb.dask.DaskDMatrix(client, X, y)
+        early_stopping_rounds = 5
+        booster = xgb.dask.train(client, {'objective': 'binary:logistic',
+                                          'eval_metric': 'error',
+                                          'tree_method': 'hist'}, m,
+                                 evals=[(m, 'Train')],
+                                 num_boost_round=1000,
+                                 early_stopping_rounds=early_stopping_rounds)['booster']
+        assert hasattr(booster, 'best_score')
+        assert booster.best_iteration == 10
+        dump = booster.get_dump(dump_format='json')
+        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_early_stopping_custom_eval(self, client):
+        from sklearn.datasets import load_breast_cancer
+        X, y = load_breast_cancer(return_X_y=True)
+        X, y = da.from_array(X), da.from_array(y)
+        m = xgb.dask.DaskDMatrix(client, X, y)
+        early_stopping_rounds = 5
+        booster = xgb.dask.train(
+            client, {'objective': 'binary:logistic',
+                     'eval_metric': 'error',
+                     'tree_method': 'hist'}, m,
+            evals=[(m, 'Train')],
+            feval=tm.eval_error_metric,
+            num_boost_round=1000,
+            early_stopping_rounds=early_stopping_rounds)['booster']
+        assert hasattr(booster, 'best_score')
+        dump = booster.get_dump(dump_format='json')
+        assert len(dump) - booster.best_iteration == early_stopping_rounds + 1
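This is the dask half of the commit message's "Enable early stopping on dask": `xgb.dask.train` now accepts `early_stopping_rounds`, and the returned booster carries `best_score` and `best_iteration`. A hedged usage sketch outside the test harness (the local-cluster setup is illustrative, not part of this commit):

import dask.array as da
import xgboost as xgb
from distributed import Client, LocalCluster
from sklearn.datasets import load_breast_cancer

with Client(LocalCluster(n_workers=2)) as client:
    X, y = load_breast_cancer(return_X_y=True)
    dX, dy = da.from_array(X), da.from_array(y)
    m = xgb.dask.DaskDMatrix(client, dX, dy)
    output = xgb.dask.train(client,
                            {'objective': 'binary:logistic',
                             'eval_metric': 'error'}, m,
                            evals=[(m, 'Train')],
                            num_boost_round=1000,
                            early_stopping_rounds=5)
    booster = output['booster']
    # Training stops well before 1000 rounds; the booster records where.
    print(booster.best_iteration, booster.best_score)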
tests/python/testing.py
@@ -240,6 +240,16 @@ def non_increasing(L, tolerance=1e-4):
     return all((y - x) < tolerance for x, y in zip(L, L[1:]))
 
 
+def eval_error_metric(predt, dtrain: xgb.DMatrix):
+    label = dtrain.get_label()
+    r = np.zeros(predt.shape)
+    gt = predt > 0.5
+    r[gt] = 1 - label[gt]
+    le = predt <= 0.5
+    r[le] = label[le]
+    return 'CustomErr', np.sum(r)
+
+
 CURDIR = os.path.normpath(os.path.abspath(os.path.dirname(__file__)))
 PROJECT_ROOT = os.path.normpath(
     os.path.join(CURDIR, os.path.pardir, os.path.pardir))
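`eval_error_metric` returns the count of threshold-0.5 misclassifications: predictions above 0.5 contribute `1 - label` (wrong when the label is 0), the rest contribute `label` (wrong when the label is 1). A quick worked check of that arithmetic on a toy array (illustrative only, not part of the test suite):

import numpy as np

predt = np.array([0.9, 0.2, 0.7, 0.4])
label = np.array([1.0, 1.0, 0.0, 0.0])

r = np.zeros(predt.shape)
gt = predt > 0.5
r[gt] = 1 - label[gt]    # 0.9 vs label 1 -> 0; 0.7 vs label 0 -> 1
le = predt <= 0.5
r[le] = label[le]        # 0.2 vs label 1 -> 1; 0.4 vs label 0 -> 0
assert np.sum(r) == 2.0  # two misclassified examples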