Rework Python callback functions. (#6199)
* Define a new callback interface for Python. * Deprecate the old callbacks. * Enable early stopping on dask.
This commit is contained in:
@@ -1,9 +1,17 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=invalid-name, too-many-statements
|
||||
# pylint: disable=invalid-name, too-many-statements, no-self-use
|
||||
# pylint: disable=too-many-arguments
|
||||
"""Training Library containing training routines."""
|
||||
from abc import ABC
|
||||
import collections
|
||||
import os
|
||||
import pickle
|
||||
from typing import Callable, List
|
||||
import numpy
|
||||
|
||||
from . import rabit
|
||||
from .core import EarlyStopException
|
||||
from .core import EarlyStopException, CallbackEnv
|
||||
from .compat import STRING_TYPES
|
||||
|
||||
|
||||
def _get_callback_context(env):
|
||||
@@ -21,9 +29,9 @@ def _fmt_metric(value, show_stdv=True):
|
||||
return '{0}:{1:.5f}'.format(value[0], value[1])
|
||||
if len(value) == 3:
|
||||
if show_stdv:
|
||||
return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
|
||||
return '{0}:{1:.5f}+{2:.5f}'.format(value[0], value[1], value[2])
|
||||
return '{0}:{1:.5f}'.format(value[0], value[1])
|
||||
raise ValueError("wrong metric value")
|
||||
raise ValueError("wrong metric value", value)
|
||||
|
||||
|
||||
def print_evaluation(period=1, show_stdv=True):
|
||||
@@ -253,3 +261,476 @@ def early_stop(stopping_rounds, maximize=False, verbose=True):
|
||||
rabit.tracker_print(msg.format(best_msg))
|
||||
raise EarlyStopException(best_iteration)
|
||||
return callback
|
||||
|
||||
|
||||
# The new implementation of callback functions.
|
||||
# Breaking:
|
||||
# - reset learning rate no longer accepts total boosting rounds
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
class TrainingCallback(ABC):
|
||||
'''Interface for training callback.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def before_training(self, model):
|
||||
'''Run before training starts.'''
|
||||
|
||||
def after_training(self, model):
|
||||
'''Run after training is finished.'''
|
||||
|
||||
def before_iteration(self, model, epoch, evals_log):
|
||||
'''Run before each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
'''Run after each iteration. Return True when training should stop.'''
|
||||
return False
|
||||
|
||||
|
||||
def _aggcv(rlist):
|
||||
# pylint: disable=invalid-name
|
||||
"""Aggregate cross-validation results.
|
||||
|
||||
"""
|
||||
cvmap = {}
|
||||
idx = rlist[0].split()[0]
|
||||
for line in rlist:
|
||||
arr = line.split()
|
||||
assert idx == arr[0]
|
||||
for metric_idx, it in enumerate(arr[1:]):
|
||||
if not isinstance(it, STRING_TYPES):
|
||||
it = it.decode()
|
||||
k, v = it.split(':')
|
||||
if (metric_idx, k) not in cvmap:
|
||||
cvmap[(metric_idx, k)] = []
|
||||
cvmap[(metric_idx, k)].append(float(v))
|
||||
msg = idx
|
||||
results = []
|
||||
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
||||
v = numpy.array(v)
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
mean, std = numpy.mean(v), numpy.std(v)
|
||||
results.extend([(k, mean, std)])
|
||||
return results
|
||||
|
||||
|
||||
def _allreduce_metric(score):
|
||||
'''Helper function for computing customized metric in distributed
|
||||
environment. Not strictly correct as many functions don't use mean value
|
||||
as final result.
|
||||
|
||||
'''
|
||||
world = rabit.get_world_size()
|
||||
assert world != 0
|
||||
if world == 1:
|
||||
return score
|
||||
if isinstance(score, tuple): # has mean and stdv
|
||||
raise ValueError(
|
||||
'xgboost.cv function should not be used in distributed environment.')
|
||||
score = numpy.array([score])
|
||||
score = rabit.allreduce(score, rabit.Op.SUM) / world
|
||||
return score[0]
|
||||
|
||||
|
||||
class CallbackContainer:
|
||||
'''A special callback for invoking a list of other callbacks.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
'''
|
||||
def __init__(self, callbacks: List[TrainingCallback],
|
||||
metric: Callable = None, is_cv: bool = False):
|
||||
self.callbacks = set(callbacks)
|
||||
if metric is not None:
|
||||
msg = 'metric must be callable object for monitoring. For ' + \
|
||||
'builtin metrics, passing them in training parameter' + \
|
||||
' will invoke monitor automatically.'
|
||||
assert callable(metric), msg
|
||||
self.metric = metric
|
||||
self.history = collections.OrderedDict()
|
||||
self.is_cv = is_cv
|
||||
|
||||
if self.is_cv:
|
||||
self.aggregated_cv = None
|
||||
|
||||
def before_training(self, model):
|
||||
'''Function called before training.'''
|
||||
for c in self.callbacks:
|
||||
c.before_training(model=model)
|
||||
|
||||
def after_training(self, model):
|
||||
'''Function called after training.'''
|
||||
for c in self.callbacks:
|
||||
c.after_training(model)
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Function called before training iteration.'''
|
||||
return any(c.before_iteration(model, epoch, self.history)
|
||||
for c in self.callbacks)
|
||||
|
||||
def _update_history(self, score, epoch):
|
||||
for d in score:
|
||||
name, s = d[0], float(d[1])
|
||||
if self.is_cv:
|
||||
std = float(d[2])
|
||||
s = (s, std)
|
||||
splited_names = name.split('-')
|
||||
data_name = splited_names[0]
|
||||
metric_name = '-'.join(splited_names[1:])
|
||||
s = _allreduce_metric(s)
|
||||
if data_name in self.history:
|
||||
data_history = self.history[data_name]
|
||||
if metric_name in data_history:
|
||||
data_history[metric_name].append(s)
|
||||
else:
|
||||
data_history[metric_name] = [s]
|
||||
else:
|
||||
self.history[data_name] = collections.OrderedDict()
|
||||
self.history[data_name][metric_name] = [s]
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Function called after training iteration.'''
|
||||
if self.is_cv:
|
||||
scores = model.eval(epoch, self.metric)
|
||||
scores = _aggcv(scores)
|
||||
self.aggregated_cv = scores
|
||||
self._update_history(scores, epoch)
|
||||
else:
|
||||
evals = [] if evals is None else evals
|
||||
for _, name in evals:
|
||||
assert name.find('-') == -1, 'Dataset name should not contain `-`'
|
||||
score = model.eval_set(evals, epoch, self.metric)
|
||||
score = score.split()[1:] # into datasets
|
||||
# split up `test-error:0.1234`
|
||||
score = [tuple(s.split(':')) for s in score]
|
||||
self._update_history(score, epoch)
|
||||
ret = any(c.after_iteration(model, epoch, self.history)
|
||||
for c in self.callbacks)
|
||||
return ret
|
||||
|
||||
|
||||
class LearningRateScheduler(TrainingCallback):
|
||||
'''Callback function for scheduling learning rate.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
learning_rates : callable/collections.Sequence
|
||||
If it's a callable object, then it should accept an integer parameter
|
||||
`epoch` and returns the corresponding learning rate. Otherwise it
|
||||
should be a sequence like list or tuple with the same size of boosting
|
||||
rounds.
|
||||
|
||||
'''
|
||||
def __init__(self, learning_rates):
|
||||
assert callable(learning_rates) or \
|
||||
isinstance(learning_rates, collections.abc.Sequence)
|
||||
if callable(learning_rates):
|
||||
self.learning_rates = learning_rates
|
||||
else:
|
||||
self.learning_rates = lambda epoch: learning_rates[epoch]
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
model.set_param('learning_rate', self.learning_rates(epoch))
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class EarlyStopping(TrainingCallback):
|
||||
''' Callback function for early stopping
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rounds : int
|
||||
Early stopping rounds.
|
||||
metric_name : str
|
||||
Name of metric that is used for early stopping.
|
||||
data_name: str
|
||||
Name of dataset that is used for early stopping.
|
||||
maximize : bool
|
||||
Whether to maximize evaluation metric. None means auto (discouraged).
|
||||
save_best : bool
|
||||
Placeholder, the feature is not yet supported.
|
||||
'''
|
||||
def __init__(self,
|
||||
rounds,
|
||||
metric_name=None,
|
||||
data_name=None,
|
||||
maximize=None,
|
||||
save_best=False):
|
||||
self.data = data_name
|
||||
self.metric_name = metric_name
|
||||
self.rounds = rounds
|
||||
self.save_best = save_best
|
||||
# https://github.com/dmlc/xgboost/issues/5531
|
||||
assert self.save_best is False, 'save best is not yet supported.'
|
||||
|
||||
self.maximize = maximize
|
||||
self.stopping_history = {}
|
||||
|
||||
if self.maximize is not None:
|
||||
if self.maximize:
|
||||
self.improve_op = lambda x, y: x > y
|
||||
else:
|
||||
self.improve_op = lambda x, y: x < y
|
||||
|
||||
self.current_rounds = 0
|
||||
self.best_scores = {}
|
||||
super().__init__()
|
||||
|
||||
def _update_rounds(self, score, name, metric, model, epoch):
|
||||
# Just to be compatibility with old behavior before 1.3. We should let
|
||||
# user to decide.
|
||||
if self.maximize is None:
|
||||
maximize_metrics = ('auc', 'aucpr', 'map', 'ndcg', 'auc@',
|
||||
'aucpr@', 'map@', 'ndcg@')
|
||||
if any(metric.startswith(x) for x in maximize_metrics):
|
||||
self.improve_op = lambda x, y: x > y
|
||||
self.maximize = True
|
||||
else:
|
||||
self.improve_op = lambda x, y: x < y
|
||||
self.maximize = False
|
||||
|
||||
if not self.stopping_history: # First round
|
||||
self.current_rounds = 0
|
||||
self.stopping_history[name] = {}
|
||||
self.stopping_history[name][metric] = [score]
|
||||
self.best_scores[name] = {}
|
||||
self.best_scores[name][metric] = [score]
|
||||
model.set_attr(best_score=str(score), best_iteration=str(epoch))
|
||||
elif not self.improve_op(score, self.best_scores[name][metric][-1]):
|
||||
# Not improved
|
||||
self.stopping_history[name][metric].append(score)
|
||||
self.current_rounds += 1
|
||||
else: # Improved
|
||||
self.stopping_history[name][metric].append(score)
|
||||
self.best_scores[name][metric].append(score)
|
||||
record = self.stopping_history[name][metric][-1]
|
||||
model.set_attr(best_score=str(record), best_iteration=str(epoch))
|
||||
self.current_rounds = 0 # reset
|
||||
|
||||
if self.current_rounds >= self.rounds:
|
||||
# Should stop
|
||||
return True
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
msg = 'Must have at least 1 validation dataset for early stopping.'
|
||||
assert len(evals_log.keys()) >= 1, msg
|
||||
data_name = ''
|
||||
if self.data:
|
||||
for d, _ in evals_log.items():
|
||||
if d == self.data:
|
||||
data_name = d
|
||||
if not data_name:
|
||||
raise ValueError('No dataset named:', self.data)
|
||||
else:
|
||||
# Use the last one as default.
|
||||
data_name = list(evals_log.keys())[-1]
|
||||
assert isinstance(data_name, str) and data_name
|
||||
data_log = evals_log[data_name]
|
||||
|
||||
# Filter out scores that can not be used for early stopping.
|
||||
if self.metric_name:
|
||||
metric_name = self.metric_name
|
||||
else:
|
||||
# Use last metric by default.
|
||||
assert isinstance(data_log, collections.OrderedDict)
|
||||
metric_name = list(data_log.keys())[-1]
|
||||
score = data_log[metric_name][-1]
|
||||
return self._update_rounds(score, data_name, metric_name, model, epoch)
|
||||
|
||||
|
||||
class EvaluationMonitor(TrainingCallback):
|
||||
'''Print the evaluation result at each iteration.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
metric : callable
|
||||
Extra user defined metric.
|
||||
rank : int
|
||||
Which worker should be used for printing the result.
|
||||
show_stdv : bool
|
||||
Used in cv to show standard deviation. Users should not specify it.
|
||||
'''
|
||||
def __init__(self, rank=0, show_stdv=False):
|
||||
self.printer_rank = rank
|
||||
self.show_stdv = show_stdv
|
||||
super().__init__()
|
||||
|
||||
def _fmt_metric(self, data, metric, score, std):
|
||||
if std is not None and self.show_stdv:
|
||||
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
|
||||
else:
|
||||
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
|
||||
return msg
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
if not evals_log:
|
||||
return False
|
||||
msg = f'[{epoch}]'
|
||||
if rabit.get_rank() == self.printer_rank:
|
||||
for data, metric in evals_log.items():
|
||||
for metric_name, log in metric.items():
|
||||
if isinstance(log[-1], tuple):
|
||||
score = log[-1][0]
|
||||
stdv = log[-1][1]
|
||||
else:
|
||||
score = log[-1]
|
||||
stdv = None
|
||||
msg += self._fmt_metric(data, metric_name, score, stdv)
|
||||
msg += '\n'
|
||||
rabit.tracker_print(msg)
|
||||
return False
|
||||
|
||||
|
||||
class TrainingCheckPoint(TrainingCallback):
|
||||
'''Checkpointing operation.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
directory : os.PathLike
|
||||
Output model directory.
|
||||
name : str
|
||||
pattern of output model file. Models will be saved as name_0.json, name_1.json,
|
||||
name_2.json ....
|
||||
as_pickle : boolean
|
||||
When set to Ture, all training parameters will be saved in pickle format, instead
|
||||
of saving only the model.
|
||||
iterations : int
|
||||
Interval of checkpointing. Checkpointing is slow so setting a larger number can
|
||||
reduce performance hit.
|
||||
|
||||
'''
|
||||
def __init__(self, directory: os.PathLike, name: str = 'model',
|
||||
as_pickle=False, iterations: int = 100):
|
||||
self._path = directory
|
||||
self._name = name
|
||||
self._as_pickle = as_pickle
|
||||
self._iterations = iterations
|
||||
self._epoch = 0
|
||||
super().__init__()
|
||||
|
||||
def after_iteration(self, model, epoch, evals_log):
|
||||
if self._epoch == self._iterations:
|
||||
path = os.path.join(self._path, self._name + '_' + str(epoch) +
|
||||
('.pkl' if self._as_pickle else '.json'))
|
||||
self._epoch = 0
|
||||
if rabit.get_rank() == 0:
|
||||
if self._as_pickle:
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(model, fd)
|
||||
else:
|
||||
model.save_model(path)
|
||||
self._epoch += 1
|
||||
|
||||
|
||||
class LegacyCallbacks:
|
||||
'''Adapter for legacy callback functions.
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
callbacks : Sequence
|
||||
A sequence of legacy callbacks (callbacks that are not instance of
|
||||
TrainingCallback)
|
||||
start_iteration : int
|
||||
Begining iteration.
|
||||
end_iteration : int
|
||||
End iteration, normally is the number of boosting rounds.
|
||||
evals : Sequence
|
||||
Sequence of evaluation dataset tuples.
|
||||
feval : Custom evaluation metric.
|
||||
'''
|
||||
def __init__(self, callbacks, start_iteration, end_iteration,
|
||||
feval, cvfolds=None):
|
||||
self.callbacks_before_iter = [
|
||||
cb for cb in callbacks
|
||||
if cb.__dict__.get('before_iteration', False)]
|
||||
self.callbacks_after_iter = [
|
||||
cb for cb in callbacks
|
||||
if not cb.__dict__.get('before_iteration', False)]
|
||||
|
||||
self.start_iteration = start_iteration
|
||||
self.end_iteration = end_iteration
|
||||
self.cvfolds = cvfolds
|
||||
|
||||
self.feval = feval
|
||||
assert self.feval is None or callable(self.feval)
|
||||
|
||||
if cvfolds is not None:
|
||||
self.aggregated_cv = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
def before_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
|
||||
def after_training(self, model):
|
||||
'''Nothing to do for legacy callbacks'''
|
||||
|
||||
def before_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Called before each iteration.'''
|
||||
for cb in self.callbacks_before_iter:
|
||||
rank = rabit.get_rank()
|
||||
cb(CallbackEnv(model=model,
|
||||
cvfolds=self.cvfolds,
|
||||
iteration=epoch,
|
||||
begin_iteration=self.start_iteration,
|
||||
end_iteration=self.end_iteration,
|
||||
rank=rank,
|
||||
evaluation_result_list=None))
|
||||
return False
|
||||
|
||||
def after_iteration(self, model, epoch, dtrain, evals):
|
||||
'''Called after each iteration.'''
|
||||
evaluation_result_list = []
|
||||
if self.cvfolds is not None:
|
||||
scores = model.eval(epoch, self.feval)
|
||||
self.aggregated_cv = _aggcv(scores)
|
||||
evaluation_result_list = self.aggregated_cv
|
||||
|
||||
if evals:
|
||||
# When cv is used, evals are embedded into folds.
|
||||
assert self.cvfolds is None
|
||||
bst_eval_set = model.eval_set(evals, epoch, self.feval)
|
||||
if isinstance(bst_eval_set, STRING_TYPES):
|
||||
msg = bst_eval_set
|
||||
else:
|
||||
msg = bst_eval_set.decode()
|
||||
res = [x.split(':') for x in msg.split()]
|
||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
||||
|
||||
try:
|
||||
for cb in self.callbacks_after_iter:
|
||||
rank = rabit.get_rank()
|
||||
cb(CallbackEnv(model=model,
|
||||
cvfolds=self.cvfolds,
|
||||
iteration=epoch,
|
||||
begin_iteration=self.start_iteration,
|
||||
end_iteration=self.end_iteration,
|
||||
rank=rank,
|
||||
evaluation_result_list=evaluation_result_list))
|
||||
except EarlyStopException:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -627,8 +627,8 @@ async def _get_rabit_args(worker_map, client: Client):
|
||||
# evaluation history is instead returned.
|
||||
|
||||
|
||||
async def _train_async(client, params, dtrain: DaskDMatrix,
|
||||
*args, evals=(), **kwargs):
|
||||
async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
||||
early_stopping_rounds=None, **kwargs):
|
||||
_assert_dask_support()
|
||||
client: Client = _xgb_get_client(client)
|
||||
if 'evals_result' in kwargs.keys():
|
||||
@@ -675,6 +675,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix,
|
||||
*args,
|
||||
evals_result=local_history,
|
||||
evals=local_evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
**kwargs)
|
||||
ret = {'booster': bst, 'history': local_history}
|
||||
if local_dtrain.num_row() == 0:
|
||||
@@ -694,7 +695,8 @@ async def _train_async(client, params, dtrain: DaskDMatrix,
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
|
||||
def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
def train(client, params, dtrain, *args, evals=(), early_stopping_rounds=None,
|
||||
**kwargs):
|
||||
'''Train XGBoost model.
|
||||
|
||||
.. versionadded:: 1.0.0
|
||||
@@ -724,8 +726,9 @@ def train(client, params, dtrain, *args, evals=(), **kwargs):
|
||||
'''
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
return client.sync(_train_async, client, params,
|
||||
dtrain=dtrain, *args, evals=evals, **kwargs)
|
||||
return client.sync(
|
||||
_train_async, client, params, dtrain=dtrain, *args, evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds, **kwargs)
|
||||
|
||||
|
||||
async def _direct_predict_impl(client, data, predict_fn):
|
||||
@@ -1005,6 +1008,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
'''Fit the regressor.
|
||||
|
||||
@@ -1045,7 +1049,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
return self.client.sync(_).__await__()
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
def client(self) -> Client:
|
||||
'''The dask client used in this model.'''
|
||||
client = _xgb_get_client(self._client)
|
||||
return client
|
||||
@@ -1059,42 +1063,51 @@ class DaskScikitLearnBase(XGBModel):
|
||||
['estimators', 'model'])
|
||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
# pylint: disable=missing-class-docstring
|
||||
async def _fit_async(self,
|
||||
X,
|
||||
y,
|
||||
sample_weights=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
verbose=True):
|
||||
dtrain = await DaskDMatrix(
|
||||
client=self.client, data=X, label=y, weight=sample_weights,
|
||||
base_margin=base_margin, missing=self.missing
|
||||
)
|
||||
async def _fit_async(self, X, y, sample_weights, base_margin, eval_set,
|
||||
sample_weight_eval_set, early_stopping_rounds,
|
||||
verbose):
|
||||
dtrain = await DaskDMatrix(client=self.client,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weights,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
evals = await _evaluation_matrices(self.client,
|
||||
eval_set, sample_weight_eval_set,
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
self.missing)
|
||||
results = await train(client=self.client, params=params, dtrain=dtrain,
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals, verbose_eval=verbose)
|
||||
evals=evals,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds)
|
||||
self._Booster = results['booster']
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
return self
|
||||
|
||||
# pylint: disable=missing-docstring
|
||||
def fit(self, X, y,
|
||||
def fit(self,
|
||||
X,
|
||||
y,
|
||||
sample_weights=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
_assert_dask_support()
|
||||
return self.client.sync(
|
||||
self._fit_async, X, y, sample_weights, base_margin,
|
||||
eval_set, sample_weight_eval_set, verbose
|
||||
)
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weights=sample_weights,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose)
|
||||
|
||||
async def _predict_async(
|
||||
self, data, output_margin=False, base_margin=None):
|
||||
@@ -1114,20 +1127,17 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
output_margin=output_margin,
|
||||
base_margin=base_margin)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
'Implementation of the scikit-learn API for XGBoost classification.',
|
||||
['estimators', 'model']
|
||||
)
|
||||
['estimators', 'model'])
|
||||
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
async def _fit_async(self, X, y,
|
||||
sample_weights=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
verbose=True):
|
||||
async def _fit_async(self, X, y, sample_weights, base_margin, eval_set,
|
||||
sample_weight_eval_set, early_stopping_rounds,
|
||||
verbose):
|
||||
dtrain = await DaskDMatrix(client=self.client,
|
||||
data=X, label=y, weight=sample_weights,
|
||||
data=X,
|
||||
label=y,
|
||||
weight=sample_weights,
|
||||
base_margin=base_margin,
|
||||
missing=self.missing)
|
||||
params = self.get_xgb_params()
|
||||
@@ -1145,28 +1155,40 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
else:
|
||||
params["objective"] = "binary:logistic"
|
||||
|
||||
evals = await _evaluation_matrices(self.client,
|
||||
eval_set, sample_weight_eval_set,
|
||||
evals = await _evaluation_matrices(self.client, eval_set,
|
||||
sample_weight_eval_set,
|
||||
self.missing)
|
||||
results = await train(client=self.client, params=params, dtrain=dtrain,
|
||||
results = await train(client=self.client,
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals, verbose_eval=verbose)
|
||||
evals=evals,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose)
|
||||
self._Booster = results['booster']
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
return self
|
||||
|
||||
def fit(self, X, y,
|
||||
def fit(self,
|
||||
X,
|
||||
y,
|
||||
sample_weights=None,
|
||||
base_margin=None,
|
||||
eval_set=None,
|
||||
sample_weight_eval_set=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose=True):
|
||||
_assert_dask_support()
|
||||
return self.client.sync(
|
||||
self._fit_async, X, y, sample_weights, base_margin, eval_set,
|
||||
sample_weight_eval_set, verbose
|
||||
)
|
||||
return self.client.sync(self._fit_async,
|
||||
X=X,
|
||||
y=y,
|
||||
sample_weights=sample_weights,
|
||||
base_margin=base_margin,
|
||||
eval_set=eval_set,
|
||||
sample_weight_eval_set=sample_weight_eval_set,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose=verbose)
|
||||
|
||||
async def _predict_proba_async(self, data, output_margin=False,
|
||||
base_margin=None):
|
||||
|
||||
@@ -2,31 +2,64 @@
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""Training Library containing training routines."""
|
||||
import warnings
|
||||
import numpy as np
|
||||
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv
|
||||
from .core import EarlyStopException
|
||||
from .core import Booster, XGBoostError
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||
from . import rabit
|
||||
from . import callback
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
num_boost_round=10, evals=(),
|
||||
obj=None, feval=None,
|
||||
xgb_model=None, callbacks=None):
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
evals = list(evals)
|
||||
params = params.copy()
|
||||
if isinstance(params, dict) \
|
||||
and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
def _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||
num_boost_round, feval, evals_result, callbacks, show_stdv, cvfolds):
|
||||
link = 'https://xgboost.readthedocs.io/en/latest/python/callbacks.html'
|
||||
warnings.warn(f'Old style callback is deprecated. See: {link}', DeprecationWarning)
|
||||
# Most of legacy advanced options becomes callbacks
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=bool(verbose_eval)))
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
callbacks.append(callback.print_evaluation(verbose_eval,
|
||||
show_stdv=show_stdv))
|
||||
if evals_result is not None:
|
||||
callbacks.append(callback.record_evaluation(evals_result))
|
||||
callbacks = callback.LegacyCallbacks(
|
||||
callbacks, start_iteration, num_boost_round, feval, cvfolds=cvfolds)
|
||||
return callbacks
|
||||
|
||||
|
||||
def _is_new_callback(callbacks):
|
||||
return any(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks) or not callbacks
|
||||
|
||||
|
||||
def _configure_metrics(params):
|
||||
if isinstance(params, dict) and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
return params
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
num_boost_round=10, evals=(),
|
||||
obj=None, feval=None,
|
||||
xgb_model=None, callbacks=None,
|
||||
evals_result=None, maximize=None,
|
||||
verbose_eval=None, early_stopping_rounds=None):
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
evals = list(evals)
|
||||
params = _configure_metrics(params.copy())
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
@@ -49,26 +82,29 @@ def _train_internal(params, dtrain,
|
||||
# Distributed code: Load the checkpoint from rabit.
|
||||
version = bst.load_rabit_checkpoint()
|
||||
assert rabit.get_world_size() != 1 or version == 0
|
||||
rank = rabit.get_rank()
|
||||
start_iteration = int(version / 2)
|
||||
nboost += start_iteration
|
||||
|
||||
callbacks_before_iter = [
|
||||
cb for cb in callbacks
|
||||
if cb.__dict__.get('before_iteration', False)]
|
||||
callbacks_after_iter = [
|
||||
cb for cb in callbacks
|
||||
if not cb.__dict__.get('before_iteration', False)]
|
||||
is_new_callback = _is_new_callback(callbacks)
|
||||
if is_new_callback:
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if verbose_eval:
|
||||
callbacks.append(callback.EvaluationMonitor())
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(callback.EarlyStopping(
|
||||
rounds=early_stopping_rounds, maximize=maximize))
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval)
|
||||
else:
|
||||
callbacks = _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, start_iteration,
|
||||
num_boost_round, feval, evals_result, callbacks,
|
||||
show_stdv=False, cvfolds=None)
|
||||
|
||||
callbacks.before_training(bst)
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
for cb in callbacks_before_iter:
|
||||
cb(CallbackEnv(model=bst,
|
||||
cvfolds=None,
|
||||
iteration=i,
|
||||
begin_iteration=start_iteration,
|
||||
end_iteration=num_boost_round,
|
||||
rank=rank,
|
||||
evaluation_result_list=None))
|
||||
if callbacks.before_iteration(bst, i, dtrain, evals):
|
||||
break
|
||||
# Distributed code: need to resume to this point.
|
||||
# Skip the first update if it is a recovery step.
|
||||
if version % 2 == 0:
|
||||
@@ -79,44 +115,32 @@ def _train_internal(params, dtrain,
|
||||
assert rabit.get_world_size() == 1 or version == rabit.version_number()
|
||||
|
||||
nboost += 1
|
||||
evaluation_result_list = []
|
||||
# check evaluation result.
|
||||
if evals:
|
||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||
if isinstance(bst_eval_set, STRING_TYPES):
|
||||
msg = bst_eval_set
|
||||
else:
|
||||
msg = bst_eval_set.decode()
|
||||
res = [x.split(':') for x in msg.split()]
|
||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
||||
try:
|
||||
for cb in callbacks_after_iter:
|
||||
cb(CallbackEnv(model=bst,
|
||||
cvfolds=None,
|
||||
iteration=i,
|
||||
begin_iteration=start_iteration,
|
||||
end_iteration=num_boost_round,
|
||||
rank=rank,
|
||||
evaluation_result_list=evaluation_result_list))
|
||||
except EarlyStopException:
|
||||
if callbacks.after_iteration(bst, i, dtrain, evals):
|
||||
break
|
||||
# do checkpoint after evaluation, in case evaluation also updates booster.
|
||||
# do checkpoint after evaluation, in case evaluation also updates
|
||||
# booster.
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
callbacks.after_training(bst)
|
||||
|
||||
if evals_result is not None and is_new_callback:
|
||||
evals_result.update(callbacks.history)
|
||||
|
||||
if bst.attr('best_score') is not None:
|
||||
bst.best_score = float(bst.attr('best_score'))
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = nboost - 1
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||
|
||||
# Copy to serialise and unserialise booster to reset state and free training memory
|
||||
# Copy to serialise and unserialise booster to reset state and free
|
||||
# training memory
|
||||
return bst.copy()
|
||||
|
||||
|
||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
||||
maximize=None, early_stopping_rounds=None, evals_result=None,
|
||||
verbose_eval=True, xgb_model=None, callbacks=None):
|
||||
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
||||
"""Train a booster with given parameters.
|
||||
@@ -189,27 +213,16 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
-------
|
||||
Booster : a trained booster model
|
||||
"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
|
||||
# Most of legacy advanced options becomes callbacks
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation())
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
callbacks.append(callback.print_evaluation(verbose_eval))
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=bool(verbose_eval)))
|
||||
if evals_result is not None:
|
||||
callbacks.append(callback.record_evaluation(evals_result))
|
||||
|
||||
return _train_internal(params, dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
evals=evals,
|
||||
obj=obj, feval=feval,
|
||||
xgb_model=xgb_model, callbacks=callbacks)
|
||||
bst = _train_internal(params, dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
evals=evals,
|
||||
obj=obj, feval=feval,
|
||||
xgb_model=xgb_model, callbacks=callbacks,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
maximize=maximize,
|
||||
early_stopping_rounds=early_stopping_rounds)
|
||||
return bst
|
||||
|
||||
|
||||
class CVPack(object):
|
||||
@@ -230,6 +243,36 @@ class CVPack(object):
|
||||
return self.bst.eval_set(self.watchlist, iteration, feval)
|
||||
|
||||
|
||||
class _PackedBooster:
|
||||
def __init__(self, cvfolds):
|
||||
self.cvfolds = cvfolds
|
||||
|
||||
def update(self, iteration, obj):
|
||||
'''Iterate through folds for update'''
|
||||
for fold in self.cvfolds:
|
||||
fold.update(iteration, obj)
|
||||
|
||||
def eval(self, iteration, feval):
|
||||
'''Iterate through folds for eval'''
|
||||
result = [f.eval(iteration, feval) for f in self.cvfolds]
|
||||
return result
|
||||
|
||||
def set_attr(self, **kwargs):
|
||||
'''Iterate through folds for setting attributes'''
|
||||
for f in self.cvfolds:
|
||||
f.bst.set_attr(**kwargs)
|
||||
|
||||
def attr(self, key):
|
||||
'''Redirect to booster attr.'''
|
||||
return self.cvfolds[0].bst.attr(key)
|
||||
|
||||
@property
|
||||
def best_iteration(self):
|
||||
'''Get best_iteration'''
|
||||
ret = self.cvfolds[0].bst.attr('best_iteration')
|
||||
return int(ret)
|
||||
|
||||
|
||||
def groups_to_rows(groups, boundaries):
|
||||
"""
|
||||
Given group row boundaries, convert ground indexes to row indexes
|
||||
@@ -334,40 +377,8 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
return ret
|
||||
|
||||
|
||||
def aggcv(rlist):
|
||||
# pylint: disable=invalid-name
|
||||
"""
|
||||
Aggregate cross-validation results.
|
||||
|
||||
If verbose_eval is true, progress is displayed in every call. If
|
||||
verbose_eval is an integer, progress will only be displayed every
|
||||
`verbose_eval` trees, tracked via trial.
|
||||
"""
|
||||
cvmap = {}
|
||||
idx = rlist[0].split()[0]
|
||||
for line in rlist:
|
||||
arr = line.split()
|
||||
assert idx == arr[0]
|
||||
for metric_idx, it in enumerate(arr[1:]):
|
||||
if not isinstance(it, STRING_TYPES):
|
||||
it = it.decode()
|
||||
k, v = it.split(':')
|
||||
if (metric_idx, k) not in cvmap:
|
||||
cvmap[(metric_idx, k)] = []
|
||||
cvmap[(metric_idx, k)].append(float(v))
|
||||
msg = idx
|
||||
results = []
|
||||
for (metric_idx, k), v in sorted(cvmap.items(), key=lambda x: x[0][0]):
|
||||
v = np.array(v)
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
mean, std = np.mean(v), np.std(v)
|
||||
results.extend([(k, mean, std)])
|
||||
return results
|
||||
|
||||
|
||||
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
||||
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
||||
metrics=(), obj=None, feval=None, maximize=None, early_stopping_rounds=None,
|
||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
||||
seed=0, callbacks=None, shuffle=True):
|
||||
# pylint: disable = invalid-name
|
||||
@@ -467,37 +478,32 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
|
||||
# setup callbacks
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=False))
|
||||
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||
is_new_callback = _is_new_callback(callbacks)
|
||||
if is_new_callback:
|
||||
assert all(isinstance(c, callback.TrainingCallback)
|
||||
for c in callbacks), "You can't mix new and old callback styles."
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.EvaluationMonitor(show_stdv=show_stdv))
|
||||
if early_stopping_rounds:
|
||||
callbacks.append(callback.EarlyStopping(
|
||||
rounds=early_stopping_rounds, maximize=maximize))
|
||||
callbacks = callback.CallbackContainer(callbacks, metric=feval, is_cv=True)
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
|
||||
callbacks = _configure_deprecated_callbacks(
|
||||
verbose_eval, early_stopping_rounds, maximize, 0,
|
||||
num_boost_round, feval, None, callbacks,
|
||||
show_stdv=show_stdv, cvfolds=cvfolds)
|
||||
callbacks.before_training(cvfolds)
|
||||
|
||||
callbacks_before_iter = [
|
||||
cb for cb in callbacks if
|
||||
cb.__dict__.get('before_iteration', False)]
|
||||
callbacks_after_iter = [
|
||||
cb for cb in callbacks if
|
||||
not cb.__dict__.get('before_iteration', False)]
|
||||
booster = _PackedBooster(cvfolds)
|
||||
|
||||
for i in range(num_boost_round):
|
||||
for cb in callbacks_before_iter:
|
||||
cb(CallbackEnv(model=None,
|
||||
cvfolds=cvfolds,
|
||||
iteration=i,
|
||||
begin_iteration=0,
|
||||
end_iteration=num_boost_round,
|
||||
rank=0,
|
||||
evaluation_result_list=None))
|
||||
for fold in cvfolds:
|
||||
fold.update(i, obj)
|
||||
res = aggcv([f.eval(i, feval) for f in cvfolds])
|
||||
if callbacks.before_iteration(booster, i, dtrain, None):
|
||||
break
|
||||
booster.update(i, obj)
|
||||
|
||||
should_break = callbacks.after_iteration(booster, i, dtrain, None)
|
||||
res = callbacks.aggregated_cv
|
||||
for key, mean, std in res:
|
||||
if key + '-mean' not in results:
|
||||
results[key + '-mean'] = []
|
||||
@@ -505,18 +511,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
results[key + '-std'] = []
|
||||
results[key + '-mean'].append(mean)
|
||||
results[key + '-std'].append(std)
|
||||
try:
|
||||
for cb in callbacks_after_iter:
|
||||
cb(CallbackEnv(model=None,
|
||||
cvfolds=cvfolds,
|
||||
iteration=i,
|
||||
begin_iteration=0,
|
||||
end_iteration=num_boost_round,
|
||||
rank=0,
|
||||
evaluation_result_list=res))
|
||||
except EarlyStopException as e:
|
||||
|
||||
if should_break:
|
||||
for k in results:
|
||||
results[k] = results[k][:(e.best_iteration + 1)]
|
||||
results[k] = results[k][:(booster.best_iteration + 1)]
|
||||
break
|
||||
if as_pandas:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user