[PYTHON] Refactor training API to use callbacks
parent 03996dd4e8
commit 149589c583

Makefile | 7
@@ -73,7 +73,7 @@ endif


 # specify tensor path
-.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java
+.PHONY: clean all lint clean_all doxygen rcpplint pypack Rpack Rbuild Rcheck java pylint


 all: lib/libxgboost.a $(XGBOOST_DYLIB) xgboost
@@ -131,8 +131,11 @@ rcpplint:
 	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} R-package/src

 lint: rcpplint
-	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin
+	python2 dmlc-core/scripts/lint.py xgboost ${LINT_LANG} include src plugin python-package

+pylint:
+	flake8 --ignore E501 python-package
+	flake8 --ignore E501 tests/python
 clean:
 	$(RM) -rf build build_plugin lib bin *~ */*~ */*/*~ */*/*/*~ */*.o */*/*.o */*/*/*.o xgboost

@@ -12,15 +12,18 @@ print ('running cross validation')
 # [iteration]  metric_name:mean_value+std_value
 # std_value is standard deviation of the metric
 xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0)
+       metrics={'error'}, seed = 0,
+       callbacks=[xgb.callback.print_evaluation(show_stdv=True)])

 print ('running cross validation, disable standard deviation display')
 # do cross validation, this will print result out as
 # [iteration]  metric_name:mean_value+std_value
 # std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed = 0, show_stdv = False)
+res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
+             metrics={'error'}, seed = 0,
+             callbacks=[xgb.callback.print_evaluation(show_stdv=False),
+                        xgb.callback.early_stop(3)])
+print (res)
 print ('running cross validation, with preprocessing function')
 # define the preprocessing function
 # used to return the preprocessed training, test data, and parameter
@@ -58,4 +61,3 @@ param = {'max_depth':2, 'eta':1, 'silent':1}
 # train with customized objective
 xgb.cv(param, dtrain, num_round, nfold = 5, seed = 0,
        obj = logregobj, feval=evalerror)
-
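The demo changes above show the migration pattern this commit establishes: display options that used to be keyword flags on `xgb.cv` become composable callback objects. A minimal before/after sketch (assuming `param`, `dtrain`, and `num_round` are defined as in the demo):

    # old style: printing controlled by a keyword flag
    xgb.cv(param, dtrain, num_round, nfold=5,
           metrics={'error'}, seed=0, show_stdv=False)

    # new style: the same behaviour as an explicit callback
    xgb.cv(param, dtrain, num_round, nfold=5,
           metrics={'error'}, seed=0,
           callbacks=[xgb.callback.print_evaluation(show_stdv=False)])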
@@ -2,8 +2,8 @@

 ignore=tests

-unexpected-special-method-signature,too-many-nested-blocks
+disable=unexpected-special-method-signature,too-many-nested-blocks

 dummy-variables-rgx=(unused|)_.*

 reports=no
python-package/xgboost/callback.py | 217 (new file)
@@ -0,0 +1,217 @@
+# coding: utf-8
+# pylint: disable= invalid-name
+"""Training Library containing training routines."""
+from __future__ import absolute_import
+
+from . import rabit
+from .core import EarlyStopException
+
+
+def _fmt_metric(value, show_stdv=True):
+    """format metric string"""
+    if len(value) == 2:
+        return '%s:%g' % (value[0], value[1])
+    elif len(value) == 3:
+        if show_stdv:
+            return '%s:%g+%g' % (value[0], value[1], value[2])
+        else:
+            return '%s:%g' % (value[0], value[1])
+    else:
+        raise ValueError("wrong metric value")
+
+
+def print_evaluation(period=1, show_stdv=True):
+    """Create a callback that prints the evaluation result.
+
+    Parameters
+    ----------
+    period : int
+        The period to log the evaluation results.
+
+    show_stdv : bool, optional
+        Whether to show stdv if provided.
+
+    Returns
+    -------
+    callback : function
+        A callback that prints evaluation every period iterations.
+    """
+    def callback(env):
+        """internal function"""
+        if env.rank != 0 or len(env.evaluation_result_list) == 0:
+            return
+        i = env.iteration
+        if (i % period == 0 or i + 1 == env.begin_iteration):
+            msg = '\t'.join([_fmt_metric(x, show_stdv) for x in env.evaluation_result_list])
+            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
+    return callback
+
+
+def record_evaluation(eval_result):
+    """Create a callback that records the evaluation history into eval_result.
+
+    Parameters
+    ----------
+    eval_result : dict
+        A dictionary to store the evaluation results.
+
+    Returns
+    -------
+    callback : function
+        The requested callback function.
+    """
+    if not isinstance(eval_result, dict):
+        raise TypeError('eval_result has to be a dictionary')
+    eval_result.clear()
+
+    def init(env):
+        """internal function"""
+        for k, _ in env.evaluation_result_list:
+            key, metric = k.split('-')
+            if key not in eval_result:
+                eval_result[key] = {}
+            if metric not in eval_result[key]:
+                eval_result[key][metric] = []
+
+    def callback(env):
+        """internal function"""
+        if len(eval_result) == 0:
+            init(env)
+        for k, v in env.evaluation_result_list:
+            key, metric = k.split('-')
+            eval_result[key][metric].append(v)
+    return callback
+
+
+def reset_learning_rate(learning_rates):
+    """Reset learning rate after iteration 1
+
+    NOTE: the initial learning rate will still take effect on the first iteration.
+
+    Parameters
+    ----------
+    learning_rates: list or function
+        List of learning rates for each boosting round
+        or a customized function that calculates eta in terms of
+        the current round number and the total number of boosting rounds (e.g. yields
+        learning rate decay)
+        - list l: eta = l[boosting round]
+        - function f: eta = f(boosting round, num_boost_round)
+
+    Returns
+    -------
+    callback : function
+        The requested callback function.
+    """
+    def callback(env):
+        """internal function"""
+        bst = env.model
+        i = env.iteration
+        if isinstance(learning_rates, list):
+            if len(learning_rates) != env.end_iteration:
+                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
+            bst.set_param('learning_rate', learning_rates[i])
+        else:
+            bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
+    callback.before_iteration = True
+    return callback
+
+
+def early_stop(stopping_rounds, maximize=False, verbose=True):
+    """Create a callback that activates early stopping.
+
+    Validation error needs to decrease at least
+    every stopping_rounds round(s) to continue training.
+    Requires at least one item in evals.
+    If there's more than one, will use the last.
+    Returns the model from the last iteration (not the best one).
+    If early stopping occurs, the model will have three additional fields:
+    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
+    (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
+    and/or num_class appears in the parameters)
+
+    Parameters
+    ----------
+    stopping_rounds : int
+        Number of rounds with no improvement after which training stops.
+
+    maximize : bool
+        Whether to maximize evaluation metric.
+
+    verbose : optional, bool
+        Whether to print messages about early stopping.
+
+    Returns
+    -------
+    callback : function
+        The requested callback function.
+    """
+    state = {}
+
+    def init(env):
+        """internal function"""
+        bst = env.model
+
+        if len(env.evaluation_result_list) == 0:
+            raise ValueError('For early stopping you need at least one set in evals.')
+        if len(env.evaluation_result_list) > 1 and verbose:
+            msg = ("Multiple eval metrics have been passed: "
+                   "'{0}' will be used for early stopping.\n\n")
+            rabit.tracker_print(msg.format(env.evaluation_result_list[-1][0]))
+        maximize_metrics = ('auc', 'map', 'ndcg')
+        maximize_score = maximize
+        metric = env.evaluation_result_list[-1][0]
+        if any(env.evaluation_result_list[-1][0].split('-')[1].startswith(x)
+               for x in maximize_metrics):
+            maximize_score = True
+
+        if verbose and env.rank == 0:
+            msg = "Will train until {} hasn't improved in {} rounds.\n"
+            rabit.tracker_print(msg.format(metric, stopping_rounds))
+
+        state['maximize_score'] = maximize_score
+        state['best_iteration'] = 0
+        if maximize_score:
+            state['best_score'] = float('-inf')
+        else:
+            state['best_score'] = float('inf')
+
+        if bst is not None:
+            if bst.attr('best_score') is not None:
+                state['best_score'] = float(bst.attr('best_score'))
+                state['best_iteration'] = int(bst.attr('best_iteration'))
+                state['best_msg'] = bst.attr('best_msg')
+            else:
+                bst.set_attr(best_iteration=str(state['best_iteration']))
+                bst.set_attr(best_score=str(state['best_score']))
+        else:
+            assert env.cvfolds is not None
+
+    def callback(env):
+        """internal function"""
+        score = env.evaluation_result_list[-1][1]
+        if len(state) == 0:
+            init(env)
+        best_score = state['best_score']
+        best_iteration = state['best_iteration']
+        maximize_score = state['maximize_score']
+        if (maximize_score and score > best_score) or \
+            (not maximize_score and score < best_score):
+            msg = '[%d]\t%s' % (
+                env.iteration,
+                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
+            state['best_msg'] = msg
+            state['best_score'] = score
+            state['best_iteration'] = env.iteration
+            # save the property to attributes, so they will occur in checkpoint.
+            if env.model is not None:
+                env.model.set_attr(best_score=str(state['best_score']),
+                                   best_iteration=str(state['best_iteration']),
+                                   best_msg=state['best_msg'])
+        elif env.iteration - best_iteration >= stopping_rounds:
+            best_msg = state['best_msg']
+            if verbose and env.rank == 0:
+                msg = "Stopping. Best iteration:\n{}\n\n"
+                rabit.tracker_print(msg.format(best_msg))
+            raise EarlyStopException(best_iteration)
+    return callback
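Since a callback is just a function of one argument, user code can plug in alongside the built-ins above. A minimal sketch of a custom callback (the `history` accumulator and the surrounding `param`/`dtrain`/`watchlist` names are illustrative, not part of this commit); it receives the `CallbackEnv` named tuple that this commit adds to core.py:

    import xgboost as xgb

    history = []  # caller-owned accumulator, illustrative only

    def log_first_metric(env):
        """Custom callback: env is the CallbackEnv named tuple."""
        if env.evaluation_result_list:
            name, value = env.evaluation_result_list[0][0], env.evaluation_result_list[0][1]
            history.append((env.iteration, name, value))

    bst = xgb.train(param, dtrain, 10, watchlist,
                    callbacks=[log_first_metric,
                               xgb.callback.print_evaluation()])

Setting `log_first_metric.before_iteration = True` would run it before each boosting update instead of after, the same mechanism `reset_learning_rate` uses.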
python-package/xgboost/compat.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=unused-import, invalid-name, wrong-import-position
+# pylint: disable= invalid-name, unused-import
 """For compatibility"""

 from __future__ import absolute_import
@@ -14,12 +14,14 @@ if PY3:
     STRING_TYPES = str,

     def py_str(x):
+        """convert c string back to python string"""
         return x.decode('utf-8')
 else:
     # pylint: disable=invalid-name
     STRING_TYPES = basestring,

     def py_str(x):
+        """convert c string back to python string"""
         return x

 try:
python-package/xgboost/core.py
@@ -1,5 +1,6 @@
 # coding: utf-8
-# pylint: disable=too-many-arguments, too-many-branches
+# pylint: disable=too-many-arguments, too-many-branches, invalid-name
+# pylint: disable=too-many-branches, too-many-lines, W0141
 """Core XGBoost Library."""
 from __future__ import absolute_import

@@ -22,6 +23,31 @@ class XGBoostError(Exception):
     pass


+class EarlyStopException(Exception):
+    """Exception to signal early stopping.
+
+    Parameters
+    ----------
+    best_iteration : int
+        The best iteration at which training stopped.
+    """
+    def __init__(self, best_iteration):
+        super(EarlyStopException, self).__init__()
+        self.best_iteration = best_iteration
+
+
+# Callback environment used by callbacks
+CallbackEnv = collections.namedtuple(
+    "XGBoostCallbackEnv",
+    ["model",
+     "cvfolds",
+     "iteration",
+     "begin_iteration",
+     "end_iteration",
+     "rank",
+     "evaluation_result_list"])
+
+
 def from_pystr_to_cstr(data):
     """Convert a list of Python str to C pointer

@@ -657,7 +683,7 @@ class Booster(object):
     def __copy__(self):
         return self.__deepcopy__(None)

-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, _):
         return Booster(model_file=self.save_raw())

     def copy(self):
@@ -975,7 +1001,6 @@ class Booster(object):
         _check_call(_LIB.XGBoosterLoadModelFromBuffer(self.handle, ptr, length))

     def dump_model(self, fout, fmap='', with_stats=False):
-        # pylint: disable=consider-using-enumerate
         """
         Dump model into a text file.

@@ -1143,10 +1168,12 @@ class Booster(object):
             msg = 'feature_names mismatch: {0} {1}'

             if dat_missing:
-                msg += '\nexpected ' + ', '.join(str(s) for s in dat_missing) + ' in input data'
+                msg += ('\nexpected ' + ', '.join(str(s) for s in dat_missing) +
+                        ' in input data')

             if my_missing:
-                msg += '\ntraining data did not have the following fields: ' + ', '.join(str(s) for s in my_missing)
+                msg += ('\ntraining data did not have the following fields: ' +
+                        ', '.join(str(s) for s in my_missing))

             raise ValueError(msg.format(self.feature_names,
                                         data.feature_names))
@@ -1161,23 +1188,25 @@ class Booster(object):
            The name of feature map file.
         bin: int, default None
            The maximum number of bins.
-           Number of bins equals number of unique split values n_unique, if bins == None or bins > n_unique.
+           Number of bins equals number of unique split values n_unique,
+           if bins == None or bins > n_unique.
         as_pandas : bool, default True
            Return pd.DataFrame when pandas is installed.
            If False or pandas is not installed, return numpy ndarray.

        Returns
        -------
-       a histogram of used splitting values for the specified feature either as numpy array or pandas DataFrame.
+       a histogram of used splitting values for the specified feature
+       either as numpy array or pandas DataFrame.
        """
        xgdump = self.get_dump(fmap=fmap)
        values = []
-       regexp = re.compile("\[{0}<([\d.Ee+-]+)\]".format(feature))
+       regexp = re.compile(r"\[{0}<([\d.Ee+-]+)\]".format(feature))
        for i in range(len(xgdump)):
            m = re.findall(regexp, xgdump[i])
            values.extend(map(float, m))

-       n_unique = np.unique(values).shape[0]
+       n_unique = len(np.unique(values))
        bins = max(min(n_unique, bins) if bins is not None else n_unique, 1)

        nph = np.histogram(values, bins=bins)
@@ -1187,7 +1216,8 @@ class Booster(object):
        if as_pandas and PANDAS_INSTALLED:
            return DataFrame(nph, columns=['SplitValue', 'Count'])
        elif as_pandas and not PANDAS_INSTALLED:
-           sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
+           sys.stderr.write(
+               "Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
            return nph
        else:
            return nph
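`EarlyStopException` is the control-flow channel between callbacks and the training loop: a callback raises it, and the loop in training.py (shown further down) catches it, breaks, and trims results to `best_iteration`. A sketch of a hand-rolled stopping rule built on the same mechanism (the helper name is hypothetical, not part of this commit):

    from xgboost.core import EarlyStopException

    def stop_at(n):
        """Hypothetical callback factory: stop unconditionally after n rounds."""
        def callback(env):
            if env.iteration >= n:
                raise EarlyStopException(env.iteration)
        return callback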
python-package/xgboost/rabit.py
@@ -1,3 +1,6 @@
+# coding: utf-8
+# pylint: disable= invalid-name
+
 """Distributed XGBoost Rabit related API."""
 from __future__ import absolute_import
 import sys
@@ -179,7 +182,7 @@ def allreduce(data, op, prepare_fun=None):
     else:
         func_ptr = ctypes.CFUNCTYPE(None, ctypes.c_void_p)

-        def pfunc(args):
+        def pfunc(_):
             """prepare function."""
             prepare_fun(data)
         _LIB.RabitAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
python-package/xgboost/sklearn.py
@@ -1,5 +1,5 @@
 # coding: utf-8
-# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme
+# pylint: disable=too-many-arguments, too-many-locals, invalid-name, fixme, E0012, R0912
 """Scikit-Learn Wrapper interface for XGBoost."""
 from __future__ import absolute_import

@@ -42,6 +42,7 @@ def _objective_decorator(func):
         ``dmatrix.get_label()``
     """
     def inner(preds, dmatrix):
+        """internal function"""
         labels = dmatrix.get_label()
         return func(labels, preds)
     return inner
@@ -183,7 +184,7 @@ class XGBModel(XGBModelBase):

     def fit(self, X, y, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True):
-        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init, redefined-variable-type
+        # pylint: disable=missing-docstring,invalid-name,attribute-defined-outside-init
         """
         Fit the gradient boosting model

@@ -351,7 +352,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):

     def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
             early_stopping_rounds=None, verbose=True):
-        # pylint: disable = attribute-defined-outside-init,arguments-differ, redefined-variable-type
+        # pylint: disable = attribute-defined-outside-init,arguments-differ
         """
         Fit gradient boosting classifier

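`_objective_decorator` is what lets the sklearn wrapper accept objectives written in scikit-learn's `(y_true, y_pred)` order and adapt them to the internal `(preds, dmatrix)` convention. A sketch of the kind of function it wraps (illustrative; the gradient/hessian return pair is XGBoost's custom-objective contract, and the function name is made up):

    import numpy as np

    def squared_error(y_true, y_pred):
        """sklearn-style objective: per-sample gradient and hessian."""
        grad = 2.0 * (y_pred - y_true)
        hess = np.full_like(y_pred, 2.0)
        return grad, hess

    # Passing objective=squared_error to an XGBModel routes it through
    # _objective_decorator, so the booster sees inner(preds, dmatrix).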
python-package/xgboost/training.py
@@ -1,20 +1,122 @@
 # coding: utf-8
 # pylint: disable=too-many-locals, too-many-arguments, invalid-name
-# pylint: disable=too-many-branches
+# pylint: disable=too-many-branches, too-many-statements
 """Training Library containing training routines."""
 from __future__ import absolute_import

-import sys
-import re
 import numpy as np
-from .core import Booster, STRING_TYPES, XGBoostError
+from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
 from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
 from . import rabit
+from . import callback
+
+
+def _train_internal(params, dtrain,
+                    num_boost_round=10, evals=(),
+                    obj=None, feval=None,
+                    xgb_model=None, callbacks=None):
+    """internal training function"""
+    callbacks = [] if callbacks is None else callbacks
+    evals = list(evals)
+    if isinstance(params, dict) \
+            and 'eval_metric' in params \
+            and isinstance(params['eval_metric'], list):
+        params = dict((k, v) for k, v in params.items())
+        eval_metrics = params['eval_metric']
+        params.pop("eval_metric", None)
+        params = list(params.items())
+        for eval_metric in eval_metrics:
+            params += [('eval_metric', eval_metric)]
+
+    bst = Booster(params, [dtrain] + [d[0] for d in evals])
+    nboost = 0
+    num_parallel_tree = 1
+
+    if xgb_model is not None:
+        if not isinstance(xgb_model, STRING_TYPES):
+            xgb_model = xgb_model.save_raw()
+        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+        nboost = len(bst.get_dump())
+    else:
+        bst = Booster(params, [dtrain] + [d[0] for d in evals])
+
+    _params = dict(params) if isinstance(params, list) else params
+
+    if 'num_parallel_tree' in _params:
+        num_parallel_tree = _params['num_parallel_tree']
+        nboost //= num_parallel_tree
+    if 'num_class' in _params:
+        nboost //= _params['num_class']
+
+    # Distributed code: Load the checkpoint from rabit.
+    version = bst.load_rabit_checkpoint()
+    assert(rabit.get_world_size() != 1 or version == 0)
+    rank = rabit.get_rank()
+    start_iteration = int(version / 2)
+    nboost += start_iteration
+
+    callbacks_before_iter = [
+        cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
+    callbacks_after_iter = [
+        cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
+
+    for i in range(start_iteration, num_boost_round):
+        for cb in callbacks_before_iter:
+            cb(CallbackEnv(model=bst,
+                           cvfolds=None,
+                           iteration=i,
+                           begin_iteration=start_iteration,
+                           end_iteration=num_boost_round,
+                           rank=rank,
+                           evaluation_result_list=None))
+        # Distributed code: need to resume to this point.
+        # Skip the first update if it is a recovery step.
+        if version % 2 == 0:
+            bst.update(dtrain, i, obj)
+            bst.save_rabit_checkpoint()
+            version += 1
+
+        assert(rabit.get_world_size() == 1 or version == rabit.version_number())
+
+        nboost += 1
+        evaluation_result_list = []
+        # check evaluation result.
+        if len(evals) != 0:
+            bst_eval_set = bst.eval_set(evals, i, feval)
+            if isinstance(bst_eval_set, STRING_TYPES):
+                msg = bst_eval_set
+            else:
+                msg = bst_eval_set.decode()
+            res = [x.split(':') for x in msg.split()]
+            evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
+        try:
+            for cb in callbacks_after_iter:
+                cb(CallbackEnv(model=bst,
+                               cvfolds=None,
+                               iteration=i,
+                               begin_iteration=start_iteration,
+                               end_iteration=num_boost_round,
+                               rank=rank,
+                               evaluation_result_list=evaluation_result_list))
+        except EarlyStopException:
+            break
+        # do checkpoint after evaluation, in case evaluation also updates booster.
+        bst.save_rabit_checkpoint()
+        version += 1
+
+    if bst.attr('best_score') is not None:
+        bst.best_score = float(bst.attr('best_score'))
+        bst.best_iteration = int(bst.attr('best_iteration'))
+    else:
+        bst.best_iteration = nboost - 1
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
+    return bst
+

 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
           maximize=False, early_stopping_rounds=None, evals_result=None,
-          verbose_eval=True, learning_rates=None, xgb_model=None):
+          verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
     # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
     """Train a booster with given parameters.

@@ -70,176 +172,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     xgb_model : file name of stored xgb model or 'Booster' instance
         Xgb model to be loaded before training (allows training continuation).

+    callbacks : list of callback functions
+        List of callback functions that are applied at end of each iteration.
+
     Returns
     -------
     booster : a trained booster model
     """
-    evals = list(evals)
-    if isinstance(params, dict) \
-            and 'eval_metric' in params \
-            and isinstance(params['eval_metric'], list):
-        params = dict((k, v) for k, v in params.items())
-        eval_metrics = params['eval_metric']
-        params.pop("eval_metric", None)
-        params = list(params.items())
-        for eval_metric in eval_metrics:
-            params += [('eval_metric', eval_metric)]
-
-    bst = Booster(params, [dtrain] + [d[0] for d in evals])
-    nboost = 0
-    num_parallel_tree = 1
-
-    if isinstance(verbose_eval, bool):
-        verbose_eval_every_line = False
-    else:
-        if isinstance(verbose_eval, int):
-            verbose_eval_every_line = verbose_eval
-            verbose_eval = True if verbose_eval_every_line > 0 else False
-
-    if rabit.get_rank() != 0:
-        verbose_eval = False
-
-    if xgb_model is not None:
-        if not isinstance(xgb_model, STRING_TYPES):
-            xgb_model = xgb_model.save_raw()
-        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
-        nboost = len(bst.get_dump())
-    else:
-        bst = Booster(params, [dtrain] + [d[0] for d in evals])
-
-    _params = dict(params) if isinstance(params, list) else params
-    _eta_param_name = 'eta' if 'eta' in _params else 'learning_rate'
-    if 'num_parallel_tree' in _params:
-        num_parallel_tree = _params['num_parallel_tree']
-        nboost //= num_parallel_tree
-    if 'num_class' in _params:
-        nboost //= _params['num_class']
-
-    if evals_result is not None:
-        if not isinstance(evals_result, dict):
-            raise TypeError('evals_result has to be a dictionary')
-        else:
-            evals_name = [d[1] for d in evals]
-            evals_result.clear()
-            evals_result.update(dict([(key, {}) for key in evals_name]))
-
-    # early stopping
-    if early_stopping_rounds is not None:
-        if len(evals) < 1:
-            raise ValueError('For early stopping you need at least one set in evals.')
-        if verbose_eval:
-            rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
-                evals[-1][1], early_stopping_rounds))
-
-        # is params a list of tuples? are we using multiple eval metrics?
-        if isinstance(params, list):
-            if len(params) != len(dict(params).items()):
-                params = dict(params)
-                msg = ("Multiple eval metrics have been passed: "
-                       "'{0}' will be used for early stopping.\n\n")
-                rabit.tracker_print(msg.format(params['eval_metric']))
-            else:
-                params = dict(params)
-
-        # either minimize loss or maximize AUC/MAP/NDCG
-        maximize_score = False
-        if 'eval_metric' in params:
-            maximize_metrics = ('auc', 'map', 'ndcg')
-            if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
-                maximize_score = True
-        if feval is not None:
-            maximize_score = maximize
-
-        if maximize_score:
-            bst.set_attr(best_score='0.0')
-        else:
-            bst.set_attr(best_score='inf')
-        bst.set_attr(best_iteration='0')
-
-    if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
-        raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
-
-    # Distributed code: Load the checkpoint from rabit.
-    version = bst.load_rabit_checkpoint()
-    assert(rabit.get_world_size() != 1 or version == 0)
-    start_iteration = int(version / 2)
-    nboost += start_iteration
-
-    for i in range(start_iteration, num_boost_round):
-        if learning_rates is not None:
-            if isinstance(learning_rates, list):
-                bst.set_param(_eta_param_name, learning_rates[i])
-            else:
-                bst.set_param(_eta_param_name, learning_rates(i, num_boost_round))
-
-        # Distributed code: need to resume to this point.
-        # Skip the first update if it is a recovery step.
-        if version % 2 == 0:
-            bst.update(dtrain, i, obj)
-            bst.save_rabit_checkpoint()
-            version += 1
-
-        assert(rabit.get_world_size() == 1 or version == rabit.version_number())
-
-        nboost += 1
-        # check evaluation result.
-        if len(evals) != 0:
-            bst_eval_set = bst.eval_set(evals, i, feval)
-
-            if isinstance(bst_eval_set, STRING_TYPES):
-                msg = bst_eval_set
-            else:
-                msg = bst_eval_set.decode()
-
-            if verbose_eval:
-                if verbose_eval_every_line:
-                    if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
-                        rabit.tracker_print(msg + '\n')
-                else:
-                    rabit.tracker_print(msg + '\n')
-
-            if evals_result is not None:
-                res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
-                for key in evals_name:
-                    evals_idx = evals_name.index(key)
-                    res_per_eval = len(res) // len(evals_name)
-                    for r in range(res_per_eval):
-                        res_item = res[(evals_idx * res_per_eval) + r]
-                        res_key = res_item[0]
-                        res_val = res_item[1]
-                        if res_key in evals_result[key]:
-                            evals_result[key][res_key].append(res_val)
-                        else:
-                            evals_result[key][res_key] = [res_val]
-
-            if early_stopping_rounds:
-                score = float(msg.rsplit(':', 1)[1])
-                best_score = float(bst.attr('best_score'))
-                best_iteration = int(bst.attr('best_iteration'))
-                if (maximize_score and score > best_score) or \
-                        (not maximize_score and score < best_score):
-                    # save the property to attributes, so they will occur in checkpoint.
-                    bst.set_attr(best_score=str(score),
-                                 best_iteration=str(nboost - 1),
-                                 best_msg=msg)
-                elif i - best_iteration >= early_stopping_rounds:
-                    best_msg = bst.attr('best_msg')
-                    if verbose_eval:
-                        msg = "Stopping. Best iteration:\n{}\n\n"
-                        rabit.tracker_print(msg.format(best_msg))
-                    break
-        # do checkpoint after evaluation, in case evaluation also updates booster.
-        bst.save_rabit_checkpoint()
-        version += 1
-
-    if early_stopping_rounds:
-        bst.best_score = float(bst.attr('best_score'))
-        bst.best_iteration = int(bst.attr('best_iteration'))
-    else:
-        bst.best_iteration = nboost - 1
-    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-    return bst
+    callbacks = [] if callbacks is None else callbacks
+
+    # Most of legacy advanced options becomes callbacks
+    if isinstance(verbose_eval, bool) and verbose_eval:
+        callbacks.append(callback.print_evaluation())
+    else:
+        if isinstance(verbose_eval, int):
+            callbacks.append(callback.print_evaluation(verbose_eval))
+
+    if early_stopping_rounds is not None:
+        callbacks.append(callback.early_stop(early_stopping_rounds,
+                                             maximize=maximize,
+                                             verbose=bool(verbose_eval)))
+    if learning_rates is not None:
+        callbacks.append(callback.reset_learning_rate(learning_rates))
+
+    if evals_result is not None:
+        callbacks.append(callback.record_evaluation(evals_result))
+
+    return _train_internal(params, dtrain,
+                           num_boost_round=num_boost_round,
+                           evals=evals,
+                           obj=obj, feval=feval,
+                           xgb_model=xgb_model, callbacks=callbacks)


 class CVPack(object):
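After this refactor `train` is a thin shim: each legacy keyword is translated into the equivalent callback and everything is forwarded to `_train_internal`. The two calls below should therefore behave the same (a sketch, assuming `param`, `dtrain`, and `watchlist` as usual):

    # legacy keyword arguments
    bst = xgb.train(param, dtrain, 10, watchlist,
                    early_stopping_rounds=5, verbose_eval=True)

    # their explicit callback translation, as done inside train()
    bst = xgb.train(param, dtrain, 10, watchlist,
                    callbacks=[xgb.callback.print_evaluation(),
                               xgb.callback.early_stop(5, maximize=False,
                                                       verbose=True)])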
@@ -294,7 +257,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
     return ret


-def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
+def aggcv(rlist):
     # pylint: disable=invalid-name
     """
     Aggregate cross-validation results.
@@ -315,50 +278,21 @@ def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
             if k not in cvmap:
                 cvmap[k] = []
             cvmap[k].append(float(v))

     msg = idx

-    if show_stdv:
-        fmt = '\tcv-{0}:{1}+{2}'
-    else:
-        fmt = '\tcv-{0}:{1}'
-
-    index = []
     results = []
-    for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
+    for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
         v = np.array(v)
         if not isinstance(msg, STRING_TYPES):
             msg = msg.decode()
         mean, std = np.mean(v), np.std(v)
-        msg += fmt.format(k, mean, std)
-
-        index.extend([k + '-mean', k + '-std'])
-        results.extend([mean, std])
-
-    if as_pandas:
-        try:
-            import pandas as pd
-            results = pd.Series(results, index=index)
-        except ImportError:
-            if verbose_eval is None:
-                verbose_eval = True
-    else:
-        # if verbose_eval is default (None),
-        # result will be np.ndarray as it can't hold column name
-        if verbose_eval is None:
-            verbose_eval = True
-
-    if (isinstance(verbose_eval, int) and verbose_eval > 0 and trial % verbose_eval == 0) or \
-            (isinstance(verbose_eval, bool) and verbose_eval):
-        sys.stderr.write(msg + '\n')
-        sys.stderr.flush()
-
+        results.extend([(k, mean, std)])
     return results


 def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
        metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
-       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0):
+       fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
+       callbacks=None):
     # pylint: disable = invalid-name
     """Cross-validation with given parameters.

@@ -404,6 +338,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
         Results are not affected, and always contains std.
     seed : int
         Seed used to generate the folds (passed to numpy.random.seed).
+    callbacks : list of callback functions
+        List of callback functions that are applied at end of each iteration.

     Returns
    -------
@@ -431,59 +367,63 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None

     params.pop("eval_metric", None)

-    if early_stopping_rounds is not None:
-        if len(metrics) > 1:
-            msg = ('Check your params. '
-                   'Early stopping works with single eval metric only.')
-            raise ValueError(msg)
-        if verbose_eval:
-            msg = "Will train until cv error hasn't decreased in {} rounds.\n"
-            sys.stderr.write(msg.format(early_stopping_rounds))
-
-        maximize_score = False
-        if len(metrics) == 1:
-            maximize_metrics = ('auc', 'map', 'ndcg')
-            if any(metrics[0].startswith(x) for x in maximize_metrics):
-                maximize_score = True
-        if feval is not None:
-            maximize_score = maximize
-
-        if maximize_score:
-            best_score = 0.0
-        else:
-            best_score = float('inf')
-
-        best_score_i = 0
-    results = []
+    results = {}
     cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
+
+    # setup callbacks
+    callbacks = [] if callbacks is None else callbacks
+    if early_stopping_rounds is not None:
+        callbacks.append(callback.early_stop(early_stopping_rounds,
+                                             maximize=maximize,
+                                             verbose=False))
+    if isinstance(verbose_eval, bool) and verbose_eval:
+        callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
+    else:
+        if isinstance(verbose_eval, int):
+            callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
+
+    callbacks_before_iter = [
+        cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
+    callbacks_after_iter = [
+        cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
+
     for i in range(num_boost_round):
+        for cb in callbacks_before_iter:
+            cb(CallbackEnv(model=None,
+                           cvfolds=cvfolds,
+                           iteration=i,
+                           begin_iteration=0,
+                           end_iteration=num_boost_round,
+                           rank=0,
+                           evaluation_result_list=None))
         for fold in cvfolds:
             fold.update(i, obj)
-        res = aggcv([f.eval(i, feval) for f in cvfolds],
-                    show_stdv=show_stdv, verbose_eval=verbose_eval,
-                    as_pandas=as_pandas, trial=i)
-        results.append(res)
-
-        if early_stopping_rounds is not None:
-            score = res[0]
-            if (maximize_score and score > best_score) or \
-                    (not maximize_score and score < best_score):
-                best_score = score
-                best_score_i = i
-            elif i - best_score_i >= early_stopping_rounds:
-                results = results[:best_score_i + 1]
-                if verbose_eval:
-                    msg = "Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n"
-                    sys.stderr.write(msg.format(best_score_i, results[-1][0], results[-1][1]))
-                break
+        res = aggcv([f.eval(i, feval) for f in cvfolds])
+        for key, mean, std in res:
+            if key + '-mean' not in results:
+                results[key + '-mean'] = []
+            if key + '-std' not in results:
+                results[key + '-std'] = []
+            results[key + '-mean'].append(mean)
+            results[key + '-std'].append(std)
+        try:
+            for cb in callbacks_after_iter:
+                cb(CallbackEnv(model=None,
+                               cvfolds=cvfolds,
+                               iteration=i,
+                               begin_iteration=0,
+                               end_iteration=num_boost_round,
+                               rank=0,
+                               evaluation_result_list=res))
+        except EarlyStopException as e:
+            for k in results.keys():
+                results[k] = results[k][:(e.best_iteration + 1)]
+            break
     if as_pandas:
         try:
             import pandas as pd
-            results = pd.DataFrame(results)
+            results = pd.DataFrame.from_dict(results)
         except ImportError:
-            results = np.array(results)
-    else:
-        results = np.array(results)
+            pass

     return results
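Note the return-type change in `cv`: results now accumulate in a dict keyed `<name>-mean` / `<name>-std` and come back as a pandas `DataFrame` built with `from_dict` (or as the raw dict when pandas is unavailable), rather than the old `np.ndarray` of rows. A sketch of consuming the new shape (column names follow the keys built in the loop above; exact names depend on the metric):

    res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
                 metrics={'error'}, seed=0)
    # with pandas installed, res is a DataFrame with columns such as
    # 'test-error-mean', 'test-error-std', 'train-error-mean', 'train-error-std'
    print(res['test-error-mean'].iloc[-1])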
rabit | 2
@@ -1 +1 @@
-Subproject commit e19fced5cbd4e41b10099facae7caa5cd3e6ada3
+Subproject commit 8f61535b83e650331459d7f33a1615fa7d27b7bd
tests/python/test_basic.py
@@ -35,6 +35,22 @@ class TestBasic(unittest.TestCase):
         # assert they are the same
         assert np.sum(np.abs(preds2 - preds)) == 0

+    def test_record_results(self):
+        dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
+        dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
+        param = {'max_depth': 2, 'eta': 1, 'silent': 1, 'objective': 'binary:logistic'}
+        # specify validations set to watch performance
+        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+        num_round = 2
+        result = {}
+        res2 = {}
+        xgb.train(param, dtrain, num_round, watchlist,
+                  callbacks=[xgb.callback.record_evaluation(result)])
+        xgb.train(param, dtrain, num_round, watchlist,
+                  evals_result=res2)
+        assert result['train']['error'][0] < 0.1
+        assert res2 == result
+
     def test_multiclass(self):
         dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
         dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
@@ -189,5 +205,5 @@ class TestBasic(unittest.TestCase):

         # return np.ndarray
         cv = xgb.cv(params, dm, num_boost_round=10, nfold=10, as_pandas=False)
-        assert isinstance(cv, np.ndarray)
-        assert cv.shape == (10, 4)
+        assert isinstance(cv, dict)
+        assert len(cv) == (4)
@@ -1,5 +1,5 @@
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
 import numpy as np
 import unittest

@@ -1,5 +1,5 @@
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
 import numpy as np
 import unittest

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
 import unittest

 try:
@@ -1,5 +1,5 @@
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
 import numpy as np
 import unittest

@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import numpy as np
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm
 import unittest

 try:
@@ -1,7 +1,7 @@
 import numpy as np
 import random
 import xgboost as xgb
-import xgboost.testing as tm
+import testing as tm

 rng = np.random.RandomState(1994)

@@ -17,6 +17,6 @@ def _skip_if_no_pandas():

 def _skip_if_no_matplotlib():
     try:
-        import matplotlib.pyplot as plt  # noqa
+        import matplotlib.pyplot as _  # noqa
     except ImportError:
         raise nose.SkipTest()