[python-package] Provide a learning_rates parameter to xgb.cv() (#1770)

* Allow using learning_rates parameter when doing CV

- Create a new `callback_cv` method working when called from `xgb.cv()`
- Rename existing `callback` into `callback_train` and make it the default callback
- Get the logic out of the callbacks and place it into a common helper

* Add a learning_rates parameter to cv()

* lint

* remove caller explicit reference

* callback is aware of its calling context

* remove caller argument

* remove learning_rates param

* restore learning_rates for training, but deprecated

* lint

* lint line too long

* quick example for predefined callbacks
This commit is contained in:
Jivan Roquet 2016-11-24 19:49:07 +02:00 committed by Tianqi Chen
parent 80e70c56b9
commit 0c19d4b029
2 changed files with 45 additions and 19 deletions

View File

@@ -7,6 +7,15 @@ from . import rabit
from .core import EarlyStopException from .core import EarlyStopException
def _get_callback_context(env):
"""return whether the current callback context is cv or train"""
if env.model is not None and env.cvfolds is None:
context = 'train'
elif env.model is None and env.cvfolds is not None:
context = 'cv'
return context
def _fmt_metric(value, show_stdv=True): def _fmt_metric(value, show_stdv=True):
"""format metric string""" """format metric string"""
if len(value) == 2: if len(value) == 2:
@@ -103,16 +112,29 @@ def reset_learning_rate(learning_rates):
callback : function callback : function
The requested callback function. The requested callback function.
""" """
def get_learning_rate(i, n, learning_rates):
    """Return the learning rate to use for boosting round *i* of *n*.

    ``learning_rates`` is either a list of per-round rates (whose length
    must equal the total number of boosting rounds) or a callable taking
    ``(current_round, total_rounds)`` and returning the rate.
    """
    if isinstance(learning_rates, list):
        if len(learning_rates) != n:
            raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
        return learning_rates[i]
    # not a list: assume a callable rate schedule
    return learning_rates(i, n)
def callback(env):
    """internal function"""
    # NOTE(review): this span was a side-by-side diff collapsed into single
    # lines; below is the NEW-side column reconstructed with normal layout.
    context = _get_callback_context(env)
    rate = get_learning_rate(env.iteration, env.end_iteration, learning_rates)
    if context == 'train':
        # single Booster being trained
        env.model.set_param('learning_rate', rate)
    elif context == 'cv':
        # cross-validation: apply the same rate to every fold's booster
        for cvpack in env.cvfolds:
            cvpack.bst.set_param('learning_rate', rate)
callback.before_iteration = True callback.before_iteration = True
return callback return callback

View File

@@ -4,7 +4,7 @@
"""Training Library containing training routines.""" """Training Library containing training routines."""
from __future__ import absolute_import from __future__ import absolute_import
import warnings
import numpy as np import numpy as np
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
@@ -114,7 +114,7 @@ def _train_internal(params, dtrain,
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=False, early_stopping_rounds=None, evals_result=None, maximize=False, early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None): verbose_eval=True, xgb_model=None, callbacks=None, learning_rates=None):
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
"""Train a booster with given parameters. """Train a booster with given parameters.
@@ -160,18 +160,17 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
/ the boosting stage found by using `early_stopping_rounds` is also printed. / the boosting stage found by using `early_stopping_rounds` is also printed.
Example: with verbose_eval=4 and at least one item in evals, an evaluation metric Example: with verbose_eval=4 and at least one item in evals, an evaluation metric
is printed every 4 boosting stages, instead of every boosting stage. is printed every 4 boosting stages, instead of every boosting stage.
learning_rates: list or function learning_rates: list or function (deprecated - use callback API instead)
List of learning rate for each boosting round List of learning rate for each boosting round
or a customized function that calculates eta in terms of or a customized function that calculates eta in terms of
current number of round and the total number of boosting round (e.g. yields current number of round and the total number of boosting round (e.g. yields
learning rate decay) learning rate decay)
- list l: eta = l[boosting round]
- function f: eta = f(boosting round, num_boost_round)
xgb_model : file name of stored xgb model or 'Booster' instance xgb_model : file name of stored xgb model or 'Booster' instance
Xgb model to be loaded before training (allows training continuation). Xgb model to be loaded before training (allows training continuation).
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example: [xgb.callback.reset_learning_rate(custom_rates)]
Returns Returns
------- -------
@@ -190,12 +189,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
callbacks.append(callback.early_stop(early_stopping_rounds, callbacks.append(callback.early_stop(early_stopping_rounds,
maximize=maximize, maximize=maximize,
verbose=bool(verbose_eval))) verbose=bool(verbose_eval)))
if learning_rates is not None:
callbacks.append(callback.reset_learning_rate(learning_rates))
if evals_result is not None: if evals_result is not None:
callbacks.append(callback.record_evaluation(evals_result)) callbacks.append(callback.record_evaluation(evals_result))
if learning_rates is not None:
warnings.warn("learning_rates parameter is deprecated - use callback API instead",
DeprecationWarning)
callbacks.append(callback.reset_learning_rate(learning_rates))
return _train_internal(params, dtrain, return _train_internal(params, dtrain,
num_boost_round=num_boost_round, num_boost_round=num_boost_round,
evals=evals, evals=evals,
@@ -287,8 +288,8 @@ def aggcv(rlist):
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None, def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None, metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0, fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
callbacks=None): seed=0, callbacks=None):
# pylint: disable = invalid-name # pylint: disable = invalid-name
"""Cross-validation with given paramaters. """Cross-validation with given paramaters.
@@ -336,6 +337,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
Seed used to generate the folds (passed to numpy.random.seed). Seed used to generate the folds (passed to numpy.random.seed).
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at end of each iteration. List of callback functions that are applied at end of each iteration.
It is possible to use predefined callbacks by using xgb.callback module.
Example: [xgb.callback.reset_learning_rate(custom_rates)]
Returns Returns
------- -------
@@ -372,6 +375,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
callbacks.append(callback.early_stop(early_stopping_rounds, callbacks.append(callback.early_stop(early_stopping_rounds,
maximize=maximize, maximize=maximize,
verbose=False)) verbose=False))
if isinstance(verbose_eval, bool) and verbose_eval: if isinstance(verbose_eval, bool) and verbose_eval:
callbacks.append(callback.print_evaluation(show_stdv=show_stdv)) callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
else: else: