[PYTHON] Refactor trainnig API to use callback
This commit is contained in:
@@ -1,20 +1,122 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name
|
||||
# pylint: disable=too-many-branches
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""Training Library containing training routines."""
|
||||
from __future__ import absolute_import
|
||||
|
||||
import sys
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from .core import Booster, STRING_TYPES, XGBoostError
|
||||
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
|
||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||
from . import rabit
|
||||
from . import callback
|
||||
|
||||
|
||||
def _train_internal(params, dtrain,
|
||||
num_boost_round=10, evals=(),
|
||||
obj=None, feval=None,
|
||||
xgb_model=None, callbacks=None):
|
||||
"""internal training function"""
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
evals = list(evals)
|
||||
if isinstance(params, dict) \
|
||||
and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
num_parallel_tree = 1
|
||||
|
||||
if xgb_model is not None:
|
||||
if not isinstance(xgb_model, STRING_TYPES):
|
||||
xgb_model = xgb_model.save_raw()
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
|
||||
nboost = len(bst.get_dump())
|
||||
else:
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
|
||||
_params = dict(params) if isinstance(params, list) else params
|
||||
|
||||
if 'num_parallel_tree' in _params:
|
||||
num_parallel_tree = _params['num_parallel_tree']
|
||||
nboost //= num_parallel_tree
|
||||
if 'num_class' in _params:
|
||||
nboost //= _params['num_class']
|
||||
|
||||
# Distributed code: Load the checkpoint from rabit.
|
||||
version = bst.load_rabit_checkpoint()
|
||||
assert(rabit.get_world_size() != 1 or version == 0)
|
||||
rank = rabit.get_rank()
|
||||
start_iteration = int(version / 2)
|
||||
nboost += start_iteration
|
||||
|
||||
callbacks_before_iter = [
|
||||
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
|
||||
callbacks_after_iter = [
|
||||
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
|
||||
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
for cb in callbacks_before_iter:
|
||||
cb(CallbackEnv(model=bst,
|
||||
cvfolds=None,
|
||||
iteration=i,
|
||||
begin_iteration=start_iteration,
|
||||
end_iteration=num_boost_round,
|
||||
rank=rank,
|
||||
evaluation_result_list=None))
|
||||
# Distributed code: need to resume to this point.
|
||||
# Skip the first update if it is a recovery step.
|
||||
if version % 2 == 0:
|
||||
bst.update(dtrain, i, obj)
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
|
||||
|
||||
nboost += 1
|
||||
evaluation_result_list = []
|
||||
# check evaluation result.
|
||||
if len(evals) != 0:
|
||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||
if isinstance(bst_eval_set, STRING_TYPES):
|
||||
msg = bst_eval_set
|
||||
else:
|
||||
msg = bst_eval_set.decode()
|
||||
res = [x.split(':') for x in msg.split()]
|
||||
evaluation_result_list = [(k, float(v)) for k, v in res[1:]]
|
||||
try:
|
||||
for cb in callbacks_after_iter:
|
||||
cb(CallbackEnv(model=bst,
|
||||
cvfolds=None,
|
||||
iteration=i,
|
||||
begin_iteration=start_iteration,
|
||||
end_iteration=num_boost_round,
|
||||
rank=rank,
|
||||
evaluation_result_list=evaluation_result_list))
|
||||
except EarlyStopException:
|
||||
break
|
||||
# do checkpoint after evaluation, in case evaluation also updates booster.
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
if bst.attr('best_score') is not None:
|
||||
bst.best_score = float(bst.attr('best_score'))
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = nboost - 1
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||
return bst
|
||||
|
||||
|
||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
||||
verbose_eval=True, learning_rates=None, xgb_model=None):
|
||||
verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
|
||||
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
||||
"""Train a booster with given parameters.
|
||||
|
||||
@@ -70,176 +172,37 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||
xgb_model : file name of stored xgb model or 'Booster' instance
|
||||
Xgb model to be loaded before training (allows training continuation).
|
||||
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
booster : a trained booster model
|
||||
"""
|
||||
evals = list(evals)
|
||||
if isinstance(params, dict) \
|
||||
and 'eval_metric' in params \
|
||||
and isinstance(params['eval_metric'], list):
|
||||
params = dict((k, v) for k, v in params.items())
|
||||
eval_metrics = params['eval_metric']
|
||||
params.pop("eval_metric", None)
|
||||
params = list(params.items())
|
||||
for eval_metric in eval_metrics:
|
||||
params += [('eval_metric', eval_metric)]
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
nboost = 0
|
||||
num_parallel_tree = 1
|
||||
|
||||
if isinstance(verbose_eval, bool):
|
||||
verbose_eval_every_line = False
|
||||
# Most of legacy advanced options becomes callbacks
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation())
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
verbose_eval_every_line = verbose_eval
|
||||
verbose_eval = True if verbose_eval_every_line > 0 else False
|
||||
callbacks.append(callback.print_evaluation(verbose_eval))
|
||||
|
||||
if rabit.get_rank() != 0:
|
||||
verbose_eval = False
|
||||
|
||||
if xgb_model is not None:
|
||||
if not isinstance(xgb_model, STRING_TYPES):
|
||||
xgb_model = xgb_model.save_raw()
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
|
||||
nboost = len(bst.get_dump())
|
||||
else:
|
||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||
|
||||
_params = dict(params) if isinstance(params, list) else params
|
||||
_eta_param_name = 'eta' if 'eta' in _params else 'learning_rate'
|
||||
if 'num_parallel_tree' in _params:
|
||||
num_parallel_tree = _params['num_parallel_tree']
|
||||
nboost //= num_parallel_tree
|
||||
if 'num_class' in _params:
|
||||
nboost //= _params['num_class']
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=bool(verbose_eval)))
|
||||
if learning_rates is not None:
|
||||
callbacks.append(callback.reset_learning_rate(learning_rates))
|
||||
|
||||
if evals_result is not None:
|
||||
if not isinstance(evals_result, dict):
|
||||
raise TypeError('evals_result has to be a dictionary')
|
||||
else:
|
||||
evals_name = [d[1] for d in evals]
|
||||
evals_result.clear()
|
||||
evals_result.update(dict([(key, {}) for key in evals_name]))
|
||||
callbacks.append(callback.record_evaluation(evals_result))
|
||||
|
||||
# early stopping
|
||||
if early_stopping_rounds is not None:
|
||||
if len(evals) < 1:
|
||||
raise ValueError('For early stopping you need at least one set in evals.')
|
||||
|
||||
if verbose_eval:
|
||||
rabit.tracker_print("Will train until {} error hasn't decreased in {} rounds.\n".format(
|
||||
evals[-1][1], early_stopping_rounds))
|
||||
|
||||
# is params a list of tuples? are we using multiple eval metrics?
|
||||
if isinstance(params, list):
|
||||
if len(params) != len(dict(params).items()):
|
||||
params = dict(params)
|
||||
msg = ("Multiple eval metrics have been passed: "
|
||||
"'{0}' will be used for early stopping.\n\n")
|
||||
rabit.tracker_print(msg.format(params['eval_metric']))
|
||||
else:
|
||||
params = dict(params)
|
||||
|
||||
# either minimize loss or maximize AUC/MAP/NDCG
|
||||
maximize_score = False
|
||||
if 'eval_metric' in params:
|
||||
maximize_metrics = ('auc', 'map', 'ndcg')
|
||||
if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
|
||||
maximize_score = True
|
||||
if feval is not None:
|
||||
maximize_score = maximize
|
||||
|
||||
if maximize_score:
|
||||
bst.set_attr(best_score='0.0')
|
||||
else:
|
||||
bst.set_attr(best_score='inf')
|
||||
bst.set_attr(best_iteration='0')
|
||||
|
||||
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
|
||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
||||
|
||||
# Distributed code: Load the checkpoint from rabit.
|
||||
version = bst.load_rabit_checkpoint()
|
||||
assert(rabit.get_world_size() != 1 or version == 0)
|
||||
start_iteration = int(version / 2)
|
||||
nboost += start_iteration
|
||||
|
||||
for i in range(start_iteration, num_boost_round):
|
||||
if learning_rates is not None:
|
||||
if isinstance(learning_rates, list):
|
||||
bst.set_param(_eta_param_name, learning_rates[i])
|
||||
else:
|
||||
bst.set_param(_eta_param_name, learning_rates(i, num_boost_round))
|
||||
|
||||
# Distributed code: need to resume to this point.
|
||||
# Skip the first update if it is a recovery step.
|
||||
if version % 2 == 0:
|
||||
bst.update(dtrain, i, obj)
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
assert(rabit.get_world_size() == 1 or version == rabit.version_number())
|
||||
|
||||
nboost += 1
|
||||
# check evaluation result.
|
||||
if len(evals) != 0:
|
||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||
|
||||
if isinstance(bst_eval_set, STRING_TYPES):
|
||||
msg = bst_eval_set
|
||||
else:
|
||||
msg = bst_eval_set.decode()
|
||||
|
||||
if verbose_eval:
|
||||
if verbose_eval_every_line:
|
||||
if i % verbose_eval_every_line == 0 or i == num_boost_round - 1:
|
||||
rabit.tracker_print(msg + '\n')
|
||||
else:
|
||||
rabit.tracker_print(msg + '\n')
|
||||
|
||||
if evals_result is not None:
|
||||
res = re.findall("([0-9a-zA-Z@]+[-]*):-?([0-9.]+).", msg)
|
||||
for key in evals_name:
|
||||
evals_idx = evals_name.index(key)
|
||||
res_per_eval = len(res) // len(evals_name)
|
||||
for r in range(res_per_eval):
|
||||
res_item = res[(evals_idx * res_per_eval) + r]
|
||||
res_key = res_item[0]
|
||||
res_val = res_item[1]
|
||||
if res_key in evals_result[key]:
|
||||
evals_result[key][res_key].append(res_val)
|
||||
else:
|
||||
evals_result[key][res_key] = [res_val]
|
||||
|
||||
if early_stopping_rounds:
|
||||
score = float(msg.rsplit(':', 1)[1])
|
||||
best_score = float(bst.attr('best_score'))
|
||||
best_iteration = int(bst.attr('best_iteration'))
|
||||
if (maximize_score and score > best_score) or \
|
||||
(not maximize_score and score < best_score):
|
||||
# save the property to attributes, so they will occur in checkpoint.
|
||||
bst.set_attr(best_score=str(score),
|
||||
best_iteration=str(nboost - 1),
|
||||
best_msg=msg)
|
||||
elif i - best_iteration >= early_stopping_rounds:
|
||||
best_msg = bst.attr('best_msg')
|
||||
if verbose_eval:
|
||||
msg = "Stopping. Best iteration:\n{}\n\n"
|
||||
rabit.tracker_print(msg.format(best_msg))
|
||||
break
|
||||
# do checkpoint after evaluation, in case evaluation also updates booster.
|
||||
bst.save_rabit_checkpoint()
|
||||
version += 1
|
||||
|
||||
if early_stopping_rounds:
|
||||
bst.best_score = float(bst.attr('best_score'))
|
||||
bst.best_iteration = int(bst.attr('best_iteration'))
|
||||
else:
|
||||
bst.best_iteration = nboost - 1
|
||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||
return bst
|
||||
return _train_internal(params, dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
evals=evals,
|
||||
obj=obj, feval=feval,
|
||||
xgb_model=xgb_model, callbacks=callbacks)
|
||||
|
||||
|
||||
class CVPack(object):
|
||||
@@ -294,7 +257,7 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
|
||||
return ret
|
||||
|
||||
|
||||
def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
|
||||
def aggcv(rlist):
|
||||
# pylint: disable=invalid-name
|
||||
"""
|
||||
Aggregate cross-validation results.
|
||||
@@ -315,50 +278,21 @@ def aggcv(rlist, show_stdv=True, verbose_eval=None, as_pandas=True, trial=0):
|
||||
if k not in cvmap:
|
||||
cvmap[k] = []
|
||||
cvmap[k].append(float(v))
|
||||
|
||||
msg = idx
|
||||
|
||||
if show_stdv:
|
||||
fmt = '\tcv-{0}:{1}+{2}'
|
||||
else:
|
||||
fmt = '\tcv-{0}:{1}'
|
||||
|
||||
index = []
|
||||
results = []
|
||||
for k, v in sorted(cvmap.items(), key=lambda x: x[0]):
|
||||
for k, v in sorted(cvmap.items(), key=lambda x: (x[0].startswith('test'), x[0])):
|
||||
v = np.array(v)
|
||||
if not isinstance(msg, STRING_TYPES):
|
||||
msg = msg.decode()
|
||||
mean, std = np.mean(v), np.std(v)
|
||||
msg += fmt.format(k, mean, std)
|
||||
|
||||
index.extend([k + '-mean', k + '-std'])
|
||||
results.extend([mean, std])
|
||||
|
||||
if as_pandas:
|
||||
try:
|
||||
import pandas as pd
|
||||
results = pd.Series(results, index=index)
|
||||
except ImportError:
|
||||
if verbose_eval is None:
|
||||
verbose_eval = True
|
||||
else:
|
||||
# if verbose_eval is default (None),
|
||||
# result will be np.ndarray as it can't hold column name
|
||||
if verbose_eval is None:
|
||||
verbose_eval = True
|
||||
|
||||
if (isinstance(verbose_eval, int) and verbose_eval > 0 and trial % verbose_eval == 0) or \
|
||||
(isinstance(verbose_eval, bool) and verbose_eval):
|
||||
sys.stderr.write(msg + '\n')
|
||||
sys.stderr.flush()
|
||||
|
||||
results.extend([(k, mean, std)])
|
||||
return results
|
||||
|
||||
|
||||
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
||||
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0):
|
||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
|
||||
callbacks=None):
|
||||
# pylint: disable = invalid-name
|
||||
"""Cross-validation with given paramaters.
|
||||
|
||||
@@ -404,6 +338,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
Results are not affected, and always contains std.
|
||||
seed : int
|
||||
Seed used to generate the folds (passed to numpy.random.seed).
|
||||
callbacks : list of callback functions
|
||||
List of callback functions that are applied at end of each iteration.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -431,59 +367,63 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
||||
|
||||
params.pop("eval_metric", None)
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
|
||||
if len(metrics) > 1:
|
||||
msg = ('Check your params. '
|
||||
'Early stopping works with single eval metric only.')
|
||||
raise ValueError(msg)
|
||||
if verbose_eval:
|
||||
msg = "Will train until cv error hasn't decreased in {} rounds.\n"
|
||||
sys.stderr.write(msg.format(early_stopping_rounds))
|
||||
|
||||
maximize_score = False
|
||||
if len(metrics) == 1:
|
||||
maximize_metrics = ('auc', 'map', 'ndcg')
|
||||
if any(metrics[0].startswith(x) for x in maximize_metrics):
|
||||
maximize_score = True
|
||||
if feval is not None:
|
||||
maximize_score = maximize
|
||||
|
||||
if maximize_score:
|
||||
best_score = 0.0
|
||||
else:
|
||||
best_score = float('inf')
|
||||
|
||||
best_score_i = 0
|
||||
results = []
|
||||
results = {}
|
||||
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
|
||||
|
||||
# setup callbacks
|
||||
callbacks = [] if callbacks is None else callbacks
|
||||
if early_stopping_rounds is not None:
|
||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||
maximize=maximize,
|
||||
verbose=False))
|
||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||
else:
|
||||
if isinstance(verbose_eval, int):
|
||||
callbacks.append(callback.print_evaluation(verbose_eval, show_stdv=show_stdv))
|
||||
|
||||
callbacks_before_iter = [
|
||||
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
|
||||
callbacks_after_iter = [
|
||||
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
|
||||
|
||||
for i in range(num_boost_round):
|
||||
for cb in callbacks_before_iter:
|
||||
cb(CallbackEnv(model=None,
|
||||
cvfolds=cvfolds,
|
||||
iteration=i,
|
||||
begin_iteration=0,
|
||||
end_iteration=num_boost_round,
|
||||
rank=0,
|
||||
evaluation_result_list=None))
|
||||
for fold in cvfolds:
|
||||
fold.update(i, obj)
|
||||
res = aggcv([f.eval(i, feval) for f in cvfolds],
|
||||
show_stdv=show_stdv, verbose_eval=verbose_eval,
|
||||
as_pandas=as_pandas, trial=i)
|
||||
results.append(res)
|
||||
res = aggcv([f.eval(i, feval) for f in cvfolds])
|
||||
|
||||
if early_stopping_rounds is not None:
|
||||
score = res[0]
|
||||
if (maximize_score and score > best_score) or \
|
||||
(not maximize_score and score < best_score):
|
||||
best_score = score
|
||||
best_score_i = i
|
||||
elif i - best_score_i >= early_stopping_rounds:
|
||||
results = results[:best_score_i + 1]
|
||||
if verbose_eval:
|
||||
msg = "Stopping. Best iteration:\n[{}] cv-mean:{}\tcv-std:{}\n"
|
||||
sys.stderr.write(msg.format(best_score_i, results[-1][0], results[-1][1]))
|
||||
break
|
||||
for key, mean, std in res:
|
||||
if key + '-mean' not in results:
|
||||
results[key + '-mean'] = []
|
||||
if key + '-std' not in results:
|
||||
results[key + '-std'] = []
|
||||
results[key + '-mean'].append(mean)
|
||||
results[key + '-std'].append(std)
|
||||
try:
|
||||
for cb in callbacks_after_iter:
|
||||
cb(CallbackEnv(model=None,
|
||||
cvfolds=cvfolds,
|
||||
iteration=i,
|
||||
begin_iteration=0,
|
||||
end_iteration=num_boost_round,
|
||||
rank=0,
|
||||
evaluation_result_list=res))
|
||||
except EarlyStopException as e:
|
||||
for k in results.keys():
|
||||
results[k] = results[k][:(e.best_iteration + 1)]
|
||||
break
|
||||
if as_pandas:
|
||||
try:
|
||||
import pandas as pd
|
||||
results = pd.DataFrame(results)
|
||||
results = pd.DataFrame.from_dict(results)
|
||||
except ImportError:
|
||||
results = np.array(results)
|
||||
else:
|
||||
results = np.array(results)
|
||||
|
||||
pass
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user