[PYTHON] Refactor training API to use callbacks

tqchen 2016-05-19 17:47:11 -07:00
parent 03996dd4e8
commit 149589c583
18 changed files with 492 additions and 278 deletions
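In short: the advanced options of train() and cv() (evaluation printing, early stopping, learning-rate schedules, recording results) are moved out of the training loop and reimplemented as callback functions in the new xgboost.callback module; both train() and cv() gain a callbacks argument. A minimal sketch of the refactored API, based on the demo and test changes below (data paths and parameter values are illustrative only):

import xgboost as xgb

# example data shipped with the repository; any DMatrix works here
dtrain = xgb.DMatrix('demo/data/agaricus.txt.train')
dtest = xgb.DMatrix('demo/data/agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}

history = {}  # filled in place by record_evaluation
bst = xgb.train(param, dtrain, num_boost_round=10,
                evals=[(dtest, 'eval'), (dtrain, 'train')],
                callbacks=[xgb.callback.print_evaluation(period=1),
                           xgb.callback.early_stop(3),
                           xgb.callback.record_evaluation(history)])

# cv() accepts the same callbacks argument
res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5, metrics={'error'}, seed=0,
             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
                        xgb.callback.early_stop(3)])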

View File

@@ -73,7 +73,7 @@ endif
# specify tensor path
-.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
+.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint
all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
@@ -131,8 +131,11 @@ rcpplint:
python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src
lint: rcpplint
-python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
+python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package
+pylint:
+flake8 --ignore E501 python-package
+flake8 --ignore E501 tests/python
clean:
$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost
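With the new target in place, the Python lint can be run on its own (assuming flake8 is installed locally):

make pylint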

View File

@@ -12,15 +12,18 @@ print ('running cross validation')
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0)
+       metrics={'error'}, seed = 0,
+       callbacks=[xgb.callback.print_evaluation(show_stdv=True)])
print ('running cross validation, disable standard deviation display')
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0, show_stdv = False)
+res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
+             metrics={'error'}, seed = 0,
+             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
+                        xgb.callback.early_stop(3)])
+print (res)
print ('running cross validation, with preprocessing function')
# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ param = {'max_depth':2, 'eta':1, 'silent':1}
# train with customized objective
xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
       obj = logregobj, feval=evalerror)

View File

@@ -2,7 +2,7 @@
ignore=tests
-unexpected-special-method-signature,too-many-nested-blocks
+disiable=unexpected-special-method-signature,too-many-nested-blocks
dummy-variables-rgx=(unused|)_.*

View File

@@ -0,0 +1,217 @@
# coding: utf-8
# pylint: disable= invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import
from . import rabit
from .core import EarlyStopException
def _fmt_metric(value, show_stdv=True):
"""format metric string"""
if len(value) == 2:
return '%s:%g' % (value[0], value[1])
elif len(value) == 3:
if show_stdv:
return '%s:%g+%g' % (value[0], value[1], value[2])
else:
return '%s:%g' % (value[0], value[1])
else:
raise ValueError("wrong metric value")
def print_evaluation(period=1, show_stdv=True):
"""Create a callback that print evaluation result.
Parameters
----------
period : int
The period to log the evaluation results
show_stdv : bool, optional
Whether show stdv if provided
Returns
-------
callback : function
A callback that prints the evaluation results every ``period`` iterations.
"""
def callback(env):
"""internal function"""
if env.rank != 0 or len(env.evaluation_result_list) == 0:
return
i = env.iteration
if (i % period == 0 or i + 1 == env.begin_iteration):
msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
rabit.tracker_print('[%d]\t%s\n' % (i, msg))
return callback
def record_evaluation(eval_result):
"""Create a call back that records the evaluation history into eval_result.
Parameters
----------
eval_result : dict
A dictionary to store the evaluation results.
Returns
-------
callback : function
The requested callback function.
"""
if not isinstance(eval_result, dict):
raise TypeError('eval_result has to be a dictionary')
eval_result.clear()
def init(env):
"""internal function"""
for k, _ in env.evaluation_result_list:
key, metric = k.split('-')
if key not in eval_result:
eval_result[key] = {}
if metric not in eval_result[key]:
eval_result[key][metric] = []
def callback(env):
"""internal function"""
if len(eval_result) == 0:
init(env)
for k, v in env.evaluation_result_list:
key, metric = k.split('-')
eval_result[key][metric].append(v)
return callback
def reset_learning_rate(learning_rates):
"""Reset learning rate after iteration 1
NOTE: the initial learning rate will still take in-effect on first iteration.
Parameters
----------
learning_rates: list or function
List of learning rates for each boosting round,
or a customized function that calculates eta in terms of the
current boosting round and the total number of boosting rounds
(e.g. to implement learning rate decay)
- list l: eta = l[boosting round]
- function f: eta = f(boosting round, num_boost_round)
Returns
-------
callback : function
The requested callback function.
"""
def callback(env):
"""internal function"""
bst = env.model
i = env.iteration
if isinstance(learning_rates, list):
if len(learning_rates) != env.end_iteration:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
bst.set_param('learning_rate', learning_rates[i])
else:
bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
callback.before_iteration = True
return callback
def early_stop(stopping_rounds, maximize=False, verbose=True):
"""Create a callback that activates early stoppping.
Validation error needs to decrease at least
every <stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have three additional fields:
bst.best_score, bst.best_iteration and bst.best_ntree_limit.
(Use bst.best_ntree_limit to get the correct value if num_parallel_tree
and/or num_class appears in the parameters)
Parameters
----------
stopping_rounds : int
The number of rounds without improvement after which training will stop.
maximize : bool
Whether to maximize evaluation metric.
verbose : optional, bool
Whether to print message about early stopping information.
Returns
-------
callback : function
The requested callback function.
"""
state = {}
def init(env):
"""internal function"""
bst = env.model
if len(env.evaluation_result_list) == 0:
raise ValueError('For early stopping you need at least one set in evals.')
if len(env.evaluation_result_list) > 1 and verbose:
msg = ("Multiple eval metrics have been passed: "
"'{0}' will be used for early stopping.\n\n")
rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
maximize_metrics = ('auc', 'map', 'ndcg')
maximize_score = maximize
metric = env.evaluation_result_list[-1][0]
if any(env.evaluation_result_list[-1][0].split('-')[1].startswith(x)
for x in maximize_metrics):
maximize_score = True
if verbose and env.rank == 0:
msg = "Will train until {} hasn't improved in {} rounds.\n"
rabit.tracker_print(msg.format(metric, stopping_rounds))
state['maximize_score'] = maximize_score
state['best_iteration'] = 0
if maximize_score:
state['best_score'] = float('-inf')
else:
state['best_score'] = float('inf')
if bst is not None:
if bst.attr('best_score') is not None:
state['best_score'] = float(bst.attr('best_score'))
state['best_iteration'] = int(bst.attr('best_iteration'))
state['best_msg'] = bst.attr('best_msg')
else:
bst.set_attr(best_iteration=str(state['best_iteration']))
bst.set_attr(best_score=str(state['best_score']))
else:
assert env.cvfolds is not None
def callback(env):
"""internal function"""
score = env.evaluation_result_list[-1][1]
if len(state) == 0:
init(env)
best_score = state['best_score']
best_iteration = state['best_iteration']
maximize_score = state['maximize_score']
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
msg = '[%d]\t%s' % (
env.iteration,
'\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
state['best_msg'] = msg
state['best_score'] = score
state['best_iteration'] = env.iteration
# save the property to attributes, so they will occur in checkpoint.
if env.model is not None:
env.model.set_attr(best_score=str(state['best_score']),
best_iteration=str(state['best_iteration']),
best_msg=state['best_msg'])
elif env.iteration - best_iteration >= stopping_rounds:
best_msg = state['best_msg']
if verbose and env.rank == 0:
msg = "Stopping. Best iteration:\n{}\n\n"
rabit.tracker_print(msg.format(best_msg))
raise EarlyStopException(best_iteration)
return callback
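The callbacks above cover the built-in behaviour; anything else can be expressed as a plain callable that receives the CallbackEnv namedtuple added to core.py below. A hypothetical sketch (my_logger and its output format are not part of this commit; it reuses param and dtrain from the first sketch):

def my_logger(env):
    """Example user callback: env is the XGBoostCallbackEnv namedtuple."""
    if env.evaluation_result_list:
        msg = ', '.join('%s=%g' % (name, value)
                        for name, value in env.evaluation_result_list)
        print('round %d of %d: %s' % (env.iteration + 1, env.end_iteration, msg))

# Setting my_logger.before_iteration = True would make the trainer call it
# before each boosting update instead of after it.
bst = xgb.train(param, dtrain, num_boost_round=5,
                evals=[(dtrain, 'train')], callbacks=[my_logger])

Note that train() passes (name, value) pairs in evaluation_result_list while cv() passes (name, mean, std) triples, so a callback meant for both should handle either shape.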

View File

@@ -1,5 +1,5 @@
# coding: utf-8
-# pylint: disable=unused-import, invalid-name, wrong-import-position
+# pylint: disable= invalid-name, unused-import
"""For compatibility"""
from __future__ import absolute_import
@@ -14,12 +14,14 @@ if PY3:
    STRING_TYPES = str,
    def py_str(x):
+        """convert c string back to python string"""
        return x.decode('utf-8')
else:
    # pylint: disable=invalid-name
    STRING_TYPES = basestring,
    def py_str(x):
+        """convert c string back to python string"""
        return x
try:

View File

@@ -1,5 +1,6 @@
# coding: utf-8
-# pylint: disable=too-many-arguments, too-many-branches
+# pylint: disable=too-many-arguments, too-many-branches, invalid-name
+# pylint: disable=too-many-branches, too-many-lines, W0141
"""Core XGBoost Library."""
from __future__ import absolute_import
@@ -22,6 +23,31 @@ class XGBoostError(Exception):
    pass
class EarlyStopException(Exception):
"""Exception to signal early stopping.
Parameters
----------
best_iteration : int
The best iteration stopped.
"""
def __init__(self, best_iteration):
super(EarlyStopException, self).__init__()
self.best_iteration = best_iteration
# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"XGBoostCallbackEnv",
["model",
"cvfolds",
"iteration",
"begin_iteration",
"end_iteration",
"rank",
"evaluation_result_list"])
def from_pystr_to_cstr(data):
    """Convert a list of Python str to C pointer
@@ -657,7 +683,7 @@ class Booster(object):
    def __copy__(self):
        return self.__deepcopy__(None)
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, _):
        return Booster(model_file=self.save_raw())
    def copy(self):
@@ -975,7 +1001,6 @@ class Booster(object):
        _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))
    def dump_model(self, fout, fmap='', with_stats=False):
-        # pylint: disable=consider-using-enumerate
        """
        Dump model into a text file.
@@ -1143,10 +1168,12 @@ class Booster(object):
            msg = 'feature_names mismatch: {0} {1}'
            if dat_missing:
-                msg += '\nexpected ' + ', '.join(str(s) for s in dat_missing) + ' in input data'
+                msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
+                        ' in input data')
            if my_missing:
-                msg += '\ntraining data did not have the following fields: ' + ', '.join(str(s) for s in my_missing)
+                msg += ('\ntraining data did not have the following fields: ' +
+                        ', '.join(str(s) for s in my_missing))
            raise ValueError(msg.format(self.feature_names,
                                        data.feature_names))
@@ -1161,23 +1188,25 @@ class Booster(object):
            The name of feature map file.
        bin: int, default None
            The maximum number of bins.
-            Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
+            Number of bins equals number of unique split values n_unique,
+            if bins == None or bins > n_unique.
        as_pandas : bool, default True
            Return pd.DataFrame when pandas is installed.
            If False or pandas is not installed, return numpy ndarray.
        Returns
        -------
-        a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
+        a histogram of used splitting values for the specified feature
+        either as numpy array or pandas DataFrame.
        """
        xgdump = self.get_dump(fmap=fmap)
        values = []
-        regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
+        regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
        for i in range(len(xgdump)):
            m = re.findall(regexp, xgdump[i])
            values.extend(map(float, m))
-        n_unique = np.unique(values).shape[0]
+        n_unique = len(np.unique(values))
        bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)
        nph = np.histogram(values, bins=bins)
@@ -1187,7 +1216,8 @@ class Booster(object):
        if as_pandas and PANDAS_INSTALLED:
            return DataFrame(nph, columns=['SplitValue', 'Count'])
        elif as_pandas and not PANDAS_INSTALLED:
-            sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
+            sys.stderr.write(
+                "Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
            return nph
        else:
            return nph

View File

@@ -1,3 +1,6 @@
+# coding: utf-8
+# pylint: disable= invalid-name
"""Distributed XGBoost Rabit related API."""
from __future__ import absolute_import
import sys
@@ -179,7 +182,7 @@ def allreduce(data, op, prepare_fun=None):
    else:
        func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
-        def pfunc(args):
+        def pfunc(_):
            """prepare function."""
            prepare_fun(data)
        _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),

View File

@@ -1,5 +1,5 @@
# coding: utf-8
-# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
+# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
"""Scikit-Learn Wrapper interface for XGBoost."""
from __future__ import absolute_import
@@ -42,6 +42,7 @@ def _objective_decorator(func):
        ``dmatrix.get_label()``
    """
    def inner(preds, dmatrix):
+        """internal function"""
        labels = dmatrix.get_label()
        return func(labels, preds)
    return inner
@@ -183,7 +184,7 @@ class XGBModel(XGBModelBase):
    def fit(self, X, y, eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True):
-        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init, redefined-variable-type
+        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
        """
        Fit the gradient boosting model
@@ -351,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
    def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
            early_stopping_rounds=None, verbose=True):
-        # pylint: disable = attribute-defined-outside-init,arguments-differ, redefined-variable-type
+        # pylint: disable = attribute-defined-outside-init,arguments-differ
        """
        Fit gradient boosting classifier

View File

@@ -1,20 +1,122 @@
# coding: utf-8
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
-# pylint: disable=too-many-branches
+# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
from __future__ import absolute_import
-import sys
-import re
import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
from . import rabit
+from . import callback
def _train_internal(params, dtrain,
num_boost_round=10, evals=(),
obj=None, feval=None,
xgb_model=None, callbacks=None):
"""internal training function"""
callbacks = [] if callbacks is None else callbacks
evals = list(evals)
if isinstance(params, dict) \
and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
bst = Booster(params, [dtrain] + [d[0] for d in evals])
nboost = 0
num_parallel_tree = 1
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
xgb_model = xgb_model.save_raw()
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
nboost = len(bst.get_dump())
else:
bst = Booster(params, [dtrain] + [d[0] for d in evals])
_params = dict(params) if isinstance(params, list) else params
if 'num_parallel_tree' in _params:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params:
nboost //= _params['num_class']
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert(rabit.get_world_size() != 1 or version == 0)
rank = rabit.get_rank()
start_iteration = int(version / 2)
nboost += start_iteration
callbacks_before_iter = [
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
callbacks_after_iter = [
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
for i in range(start_iteration, num_boost_round):
for cb in callbacks_before_iter:
cb(CallbackEnv(model=bst,
cvfolds=None,
iteration=i,
begin_iteration=start_iteration,
end_iteration=num_boost_round,
rank=rank,
evaluation_result_list=None))
# Distributed code: need to resume to this point.
# Skip the first update if it is a recovery step.
if version % 2 == 0:
bst.update(dtrain, i, obj)
bst.save_rabit_checkpoint()
version += 1
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
nboost += 1
evaluation_result_list = []
# check evaluation result.
if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set
else:
msg = bst_eval_set.decode()
res = [x.split(':') for x in msg.split()]
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
try:
for cb in callbacks_after_iter:
cb(CallbackEnv(model=bst,
cvfolds=None,
iteration=i,
begin_iteration=start_iteration,
end_iteration=num_boost_round,
rank=rank,
evaluation_result_list=evaluation_result_list))
except EarlyStopException:
break
# do checkpoint after evaluation, in case evaluation also updates booster.
bst.save_rabit_checkpoint()
version += 1
if bst.attr('best_score') is not None:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = nboost - 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
return bst
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
          maximize=False, early_stopping_rounds=None, evals_result=None,
-          verbose_eval=True, learning_rates=None, xgb_model=None):
+          verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
    # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
    """Train a booster with given parameters.
@@ -70,176 +172,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
    xgb_model : file name of stored xgb model or 'Booster' instance
        Xgb model to be loaded before training (allows training continuation).
+    callbacks : list of callback functions
+        List of callback functions that are applied at end of each iteration.
    Returns
    -------
    booster : a trained booster model
    """
evals = list(evals) callbacks = [] if callbacks is None else callbacks
if isinstance(params, dict) \
and 'eval_metric' in params \
and isinstance(params['eval_metric'], list):
params = dict((k, v) for k, v in params.items())
eval_metrics = params['eval_metric']
params.pop("eval_metric", None)
params = list(params.items())
for eval_metric in eval_metrics:
params += [('eval_metric', eval_metric)]
bst = Booster(params, [dtrain] + [d[0] for d in evals]) # Most of legacy advanced options becomes callbacks
nboost = 0 if isinstance(verbose_eval, bool) and verbose_eval:
num_parallel_tree = 1 callbacks.append(callback.print_evaluation())
if isinstance(verbose_eval, bool):
verbose_eval_every_line = False
else: else:
if isinstance(verbose_eval, int): if isinstance(verbose_eval, int):
verbose_eval_every_line = verbose_eval callbacks.append(callback.print_evaluation(verbose_eval))
verbose_eval = True if verbose_eval_every_line > 0 else False
if rabit.get_rank() != 0: if early_stopping_rounds is not None:
verbose_eval = False callbacks.append(callback.early_stop(early_stopping_rounds,
maximize=maximize,
if xgb_model is not None: verbose=bool(verbose_eval)))
if not isinstance(xgb_model, STRING_TYPES): if learning_rates is not None:
xgb_model = xgb_model.save_raw() callbacks.append(callback.reset_learning_rate(learning_rates))
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
nboost = len(bst.get_dump())
else:
bst = Booster(params, [dtrain] + [d[0] for d in evals])
_params = dict(params) if isinstance(params, list) else params
_eta_param_name = 'eta' if 'eta' in _params else 'learning_rate'
if 'num_parallel_tree' in _params:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params:
nboost //= _params['num_class']
if evals_result is not None: if evals_result is not None:
if not isinstance(evals_result, dict): callbacks.append(callback.record_evaluation(evals_result))
raise TypeError('evals_result has to be a dictionary')
else:
evals_name = [d[1] for d in evals]
evals_result.clear()
evals_result.update(dict([(key, {}) for key in evals_name]))
# early stopping return _train_internal(params, dtrain,
if early_stopping_rounds is not None: num_boost_round=num_boost_round,
if len(evals) < 1: evals=evals,
raise ValueError('For early stopping you need at least one set in evals.') obj=obj, feval=feval,
xgb_model=xgb_model, callbacks=callbacks)
if verbose_eval:
rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics?
if isinstance(params, list):
if len(params) != len(dict(params).items()):
params = dict(params)
msg = ("Multiple eval metrics have been passed: "
"'{0}' will be used for early stopping.\n\n")
rabit.tracker_print(msg.format(params['eval_metric']))
else:
params = dict(params)
# either minimize loss or maximize AUC/MAP/NDCG
maximize_score = False
if 'eval_metric' in params:
maximize_metrics = ('auc', 'map', 'ndcg')
if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
maximize_score = True
if feval is not None:
maximize_score = maximize
if maximize_score:
bst.set_attr(best_score='0.0')
else:
bst.set_attr(best_score='inf')
bst.set_attr(best_iteration='0')
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
# Distributed code: Load the checkpoint from rabit.
version = bst.load_rabit_checkpoint()
assert(rabit.get_world_size() != 1 or version == 0)
start_iteration = int(version / 2)
nboost += start_iteration
for i in range(start_iteration, num_boost_round):
if learning_rates is not None:
if isinstance(learning_rates, list):
bst.set_param(_eta_param_name, learning_rates[i])
else:
bst.set_param(_eta_param_name, learning_rates(i, num_boost_round))
# Distributed code: need to resume to this point.
# Skip the first update if it is a recovery step.
if version % 2 == 0:
bst.update(dtrain, i, obj)
bst.save_rabit_checkpoint()
version += 1
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
nboost += 1
# check evaluation result.
if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
msg = bst_eval_set
else:
msg = bst_eval_set.decode()
if verbose_eval:
if verbose_eval_every_line:
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
rabit.tracker_print(msg + '\n')
else:
rabit.tracker_print(msg + '\n')
if evals_result is not None:
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
for key in evals_name:
evals_idx = evals_name.index(key)
res_per_eval = len(res) // len(evals_name)
for r in range(res_per_eval):
res_item = res[(evals_idx * res_per_eval) + r]
res_key = res_item[0]
res_val = res_item[1]
if res_key in evals_result[key]:
evals_result[key][res_key].append(res_val)
else:
evals_result[key][res_key] = [res_val]
if early_stopping_rounds:
score = float(msg.rsplit(':', 1)[1])
best_score = float(bst.attr('best_score'))
best_iteration = int(bst.attr('best_iteration'))
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
# save the property to attributes, so they will occur in checkpoint.
bst.set_attr(best_score=str(score),
best_iteration=str(nboost - 1),
best_msg=msg)
elif i - best_iteration >= early_stopping_rounds:
best_msg = bst.attr('best_msg')
if verbose_eval:
msg = "Stopping. Best iteration:\n{}\n\n"
rabit.tracker_print(msg.format(best_msg))
break
# do checkpoint after evaluation, in case evaluation also updates booster.
bst.save_rabit_checkpoint()
version += 1
if early_stopping_rounds:
bst.best_score = float(bst.attr('best_score'))
bst.best_iteration = int(bst.attr('best_iteration'))
else:
bst.best_iteration = nboost - 1
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
return bst
class CVPack(object): class CVPack(object):
@@ -294,7 +257,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
    return ret
-def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
+def aggcv(rlist):
    # pylint: disable=invalid-name
    """
    Aggregate cross-validation results.
@@ -315,50 +278,21 @@ def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
            if k not in cvmap:
                cvmap[k] = []
            cvmap[k].append(float(v))
    msg = idx
-    if show_stdv:
-        fmt = '\tcv-{0}:{1}+{2}'
-    else:
-        fmt = '\tcv-{0}:{1}'
-    index = []
    results = []
-    for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
+    for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
        v = np.array(v)
        if not isinstance(msg, STRING_TYPES):
            msg = msg.decode()
        mean, std = np.mean(v), np.std(v)
-        msg += fmt.format(k, mean, std)
-        index.extend([k + '-mean', k + '-std'])
-        results.extend([mean, std])
+        results.extend([(k, mean, std)])
-    if as_pandas:
-        try:
-            import pandas as pd
-            results = pd.Series(results, index=index)
-        except ImportError:
-            if verbose_eval is None:
-                verbose_eval = True
-    else:
-        # if verbose_eval is default (None),
-        # result will be np.ndarray as it can't hold column name
-        if verbose_eval is None:
-            verbose_eval = True
-    if (isinstance(verbose_eval, int) and verbose_eval > 0 and trial % verbose_eval == 0) or \
-            (isinstance(verbose_eval, bool) and verbose_eval):
-        sys.stderr.write(msg + '\n')
-        sys.stderr.flush()
    return results
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
       metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
-       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0):
+       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
+       callbacks=None):
    # pylint: disable = invalid-name
    """Cross-validation with given parameters.
@@ -404,6 +338,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
        Results are not affected, and always contains std.
    seed : int
        Seed used to generate the folds (passed to numpy.random.seed).
+    callbacks : list of callback functions
+        List of callback functions that are applied at end of each iteration.
    Returns
    -------
@@ -431,59 +367,63 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
params.pop("eval_metric", None) params.pop("eval_metric", None)
if early_stopping_rounds is not None: results = {}
if len(metrics) > 1:
msg = ('Check your params. '
'Early stopping works with single eval metric only.')
raise ValueError(msg)
if verbose_eval:
msg = "Will train until cv error hasn't decreased in {} rounds.\n"
sys.stderr.write(msg.format(early_stopping_rounds))
maximize_score = False
if len(metrics) == 1:
maximize_metrics = ('auc', 'map', 'ndcg')
if any(metrics[0].startswith(x) for x in maximize_metrics):
maximize_score = True
if feval is not None:
maximize_score = maximize
if maximize_score:
best_score = 0.0
else:
best_score = float('inf')
best_score_i = 0
results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds) cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
# setup callbacks
callbacks = [] if callbacks is None else callbacks
if early_stopping_rounds is not None:
callbacks.append(callback.early_stop(early_stopping_rounds,
maximize=maximize,
verbose=False))
if isinstance(verbose_eval, bool) and verbose_eval:
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
else:
if isinstance(verbose_eval, int):
callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
callbacks_before_iter = [
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
callbacks_after_iter = [
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
for i in range(num_boost_round): for i in range(num_boost_round):
for cb in callbacks_before_iter:
cb(CallbackEnv(model=None,
cvfolds=cvfolds,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
rank=0,
evaluation_result_list=None))
for fold in cvfolds: for fold in cvfolds:
fold.update(i, obj) fold.update(i, obj)
res = aggcv([f.eval(i, feval) for f in cvfolds], res = aggcv([f.eval(i, feval) for f in cvfolds])
show_stdv=show_stdv, verbose_eval=verbose_eval,
as_pandas=as_pandas, trial=i)
results.append(res)
if early_stopping_rounds is not None: for key, mean, std in res:
score = res[0] if key + '-mean' not in results:
if (maximize_score and score > best_score) or \ results[key + '-mean'] = []
(not maximize_score and score < best_score): if key + '-std' not in results:
best_score = score results[key + '-std'] = []
best_score_i = i results[key + '-mean'].append(mean)
elif i - best_score_i >= early_stopping_rounds: results[key + '-std'].append(std)
results = results[:best_score_i + 1] try:
if verbose_eval: for cb in callbacks_after_iter:
msg = "Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n" cb(CallbackEnv(model=None,
sys.stderr.write(msg.format(best_score_i, results[-1][0], results[-1][1])) cvfolds=cvfolds,
break iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
rank=0,
evaluation_result_list=res))
except EarlyStopException as e:
for k in results.keys():
results[k] = results[k][:(e.best_iteration + 1)]
break
    if as_pandas:
        try:
            import pandas as pd
-            results = pd.DataFrame(results)
+            results = pd.DataFrame.from_dict(results)
        except ImportError:
-            results = np.array(results)
+            pass
-    else:
-        results = np.array(results)
    return results
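With this refactor cv() accumulates per-round results into a dictionary keyed by '<set>-<metric>-mean' and '<set>-<metric>-std' (converted to a pandas DataFrame when as_pandas=True and pandas is available), which is what the updated test below asserts. A small sketch of consuming the new return value, reusing param and dtrain from the first sketch and assuming a single 'error' metric:

res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
             metrics={'error'}, seed=0, as_pandas=False)
# res is a dict of per-round lists, e.g.
# {'train-error-mean': [...], 'train-error-std': [...],
#  'test-error-mean': [...], 'test-error-std': [...]}
print(len(res['test-error-mean']))  # number of completed boosting rounds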

rabit

@@ -1 +1 @@
-Subproject commit e19fced5cbd4e41b10099facae7caa5cd3e6ada3
+Subproject commit 8f61535b83e650331459d7f33a1615fa7d27b7bd

View File

@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
        # assert they are the same
        assert np.sum(np.abs(preds2 - preds)) == 0
def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
# specify validations set to watch performance
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 2
result = {}
res2 = {}
xgb.train(param, dtrain, num_round, watchlist,
callbacks=[xgb.callback.record_evaluation(result)])
xgb.train(param, dtrain, num_round, watchlist,
evals_result=res2)
assert result['train']['error'][0] < 0.1
assert res2 == result
    def test_multiclass(self):
        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):
        # return np.ndarray
        cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
-        assert isinstance(cv, np.ndarray)
-        assert cv.shape == (10, 4)
+        assert isinstance(cv, dict)
+        assert len(cv) == (4)

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import unittest
try:

View File

@@ -1,5 +1,5 @@
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import numpy as np
import unittest

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
import unittest
try:

View File

@@ -1,7 +1,7 @@
import numpy as np
import random
import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
rng = np.random.RandomState(1994)

View File

@@ -17,6 +17,6 @@ def _skip_if_no_pandas():
def _skip_if_no_matplotlib():
    try:
-        import matplotlib.pyplot as plt # noqa
+        import matplotlib.pyplot as _ # noqa
    except ImportError:
        raise nose.SkipTest()