[python-package] Provide a learning_rates parameter to xgb.cv() (#1770)
* Allow using learning_rates parameter when doing CV
  - Create a new `callback_cv` method working when called from `xgb.cv()`
  - Rename existing `callback` into `callback_train` and make it the default callback
  - Get the logic out of the callbacks and place it into a common helper
* Add a learning_rates parameter to cv()
* lint
* remove caller explicit reference
* callback is aware of its calling context
* remove caller argument
* remove learning_rates param
* restore learning_rates for training, but deprecated
* lint
* lint line too long
* quick example for predefined callbacks
This commit is contained in:
parent
80e70c56b9
commit
0c19d4b029
@ -7,6 +7,15 @@ from . import rabit
|
|||||||
from .core import EarlyStopException
|
from .core import EarlyStopException
|
||||||
|
|
||||||
|
|
||||||
|
def _get_callback_context(env):
    """Return whether the current callback context is ``'train'`` or ``'cv'``.

    Parameters
    ----------
    env : object
        Callback environment exposing ``model`` and ``cvfolds`` attributes.
        Exactly one of the two is expected to be set (i.e. not None).

    Returns
    -------
    context : str
        ``'train'`` when ``env.model`` is set and ``env.cvfolds`` is None;
        ``'cv'`` when ``env.cvfolds`` is set and ``env.model`` is None.

    Raises
    ------
    ValueError
        If both or neither of ``env.model`` and ``env.cvfolds`` are set, so
        the context is ambiguous.
    """
    has_model = env.model is not None
    has_folds = env.cvfolds is not None

    if has_model and not has_folds:
        return 'train'
    if has_folds and not has_model:
        return 'cv'

    # Fail with an explicit message instead of an UnboundLocalError on an
    # unassigned local when the environment matches neither context.
    raise ValueError(
        "Cannot determine callback context: exactly one of 'env.model' and "
        "'env.cvfolds' must be set.")
|
||||||
|
|
||||||
|
|
||||||
def _fmt_metric(value, show_stdv=True):
|
def _fmt_metric(value, show_stdv=True):
|
||||||
"""format metric string"""
|
"""format metric string"""
|
||||||
if len(value) == 2:
|
if len(value) == 2:
|
||||||
@ -103,16 +112,29 @@ def reset_learning_rate(learning_rates):
|
|||||||
callback : function
|
callback : function
|
||||||
The requested callback function.
|
The requested callback function.
|
||||||
"""
|
"""
|
||||||
|
def get_learning_rate(i, n, learning_rates):
    """Helper providing the learning rate for one boosting round.

    Parameters
    ----------
    i : int
        Current boosting round; used as the index into a sequence, or passed
        as the first argument to a function.
    n : int
        Total number of boosting rounds.
    learning_rates : list (or other sized, indexable sequence) or function
        - sequence ``l``: the rate is ``l[i]``; ``len(l)`` must equal ``n``.
        - function ``f``: the rate is ``f(i, n)``.

    Returns
    -------
    new_learning_rate
        The value taken from the sequence or returned by the function.

    Raises
    ------
    ValueError
        If a sequence is given whose length differs from ``n``.
    """
    # Lists always take the sequence path, exactly as before. Any other
    # non-callable (tuple, array, ...) is now treated as a sequence as well,
    # instead of failing with "object is not callable".
    if isinstance(learning_rates, list) or not callable(learning_rates):
        if len(learning_rates) != n:
            raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
        return learning_rates[i]
    return learning_rates(i, n)
|
||||||
|
|
||||||
def callback(env):
|
def callback(env):
|
||||||
"""internal function"""
|
"""internal function"""
|
||||||
bst = env.model
|
context = _get_callback_context(env)
|
||||||
i = env.iteration
|
|
||||||
if isinstance(learning_rates, list):
|
if context == 'train':
|
||||||
if len(learning_rates) != env.end_iteration:
|
bst, i, n = env.model, env.iteration, env.end_iteration
|
||||||
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
|
bst.set_param('learning_rate', get_learning_rate(i, n, learning_rates))
|
||||||
bst.set_param('learning_rate', learning_rates[i])
|
elif context == 'cv':
|
||||||
else:
|
i, n = env.iteration, env.end_iteration
|
||||||
bst.set_param('learning_rate', learning_rates(i, env.end_iteration))
|
for cvpack in env.cvfolds:
|
||||||
|
bst = cvpack.bst
|
||||||
|
bst.set_param('learning_rate', get_learning_rate(i, n, learning_rates))
|
||||||
|
|
||||||
callback.before_iteration = True
|
callback.before_iteration = True
|
||||||
return callback
|
return callback
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@
|
|||||||
"""Training Library containing training routines."""
|
"""Training Library containing training routines."""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
|
from .core import Booster, STRING_TYPES, XGBoostError, CallbackEnv, EarlyStopException
|
||||||
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
|
||||||
@ -114,7 +114,7 @@ def _train_internal(params, dtrain,
|
|||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||||
maximize=False, early_stopping_rounds=None, evals_result=None,
|
maximize=False, early_stopping_rounds=None, evals_result=None,
|
||||||
verbose_eval=True, learning_rates=None, xgb_model=None, callbacks=None):
|
verbose_eval=True, xgb_model=None, callbacks=None, learning_rates=None):
|
||||||
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
# pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
|
||||||
"""Train a booster with given parameters.
|
"""Train a booster with given parameters.
|
||||||
|
|
||||||
@ -160,18 +160,17 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
/ the boosting stage found by using `early_stopping_rounds` is also printed.
|
/ the boosting stage found by using `early_stopping_rounds` is also printed.
|
||||||
Example: with verbose_eval=4 and at least one item in evals, an evaluation metric
|
Example: with verbose_eval=4 and at least one item in evals, an evaluation metric
|
||||||
is printed every 4 boosting stages, instead of every boosting stage.
|
is printed every 4 boosting stages, instead of every boosting stage.
|
||||||
learning_rates: list or function
|
learning_rates: list or function (deprecated - use callback API instead)
|
||||||
List of learning rate for each boosting round
|
List of learning rate for each boosting round
|
||||||
or a customized function that calculates eta in terms of
|
or a customized function that calculates eta in terms of
|
||||||
current number of round and the total number of boosting round (e.g. yields
|
current number of round and the total number of boosting round (e.g. yields
|
||||||
learning rate decay)
|
learning rate decay)
|
||||||
- list l: eta = l[boosting round]
|
|
||||||
- function f: eta = f(boosting round, num_boost_round)
|
|
||||||
xgb_model : file name of stored xgb model or 'Booster' instance
|
xgb_model : file name of stored xgb model or 'Booster' instance
|
||||||
Xgb model to be loaded before training (allows training continuation).
|
Xgb model to be loaded before training (allows training continuation).
|
||||||
|
|
||||||
callbacks : list of callback functions
|
callbacks : list of callback functions
|
||||||
List of callback functions that are applied at end of each iteration.
|
List of callback functions that are applied at end of each iteration.
|
||||||
|
It is possible to use predefined callbacks by using xgb.callback module.
|
||||||
|
Example: [xgb.callback.reset_learning_rate(custom_rates)]
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -190,12 +189,14 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
|||||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
verbose=bool(verbose_eval)))
|
verbose=bool(verbose_eval)))
|
||||||
if learning_rates is not None:
|
|
||||||
callbacks.append(callback.reset_learning_rate(learning_rates))
|
|
||||||
|
|
||||||
if evals_result is not None:
|
if evals_result is not None:
|
||||||
callbacks.append(callback.record_evaluation(evals_result))
|
callbacks.append(callback.record_evaluation(evals_result))
|
||||||
|
|
||||||
|
if learning_rates is not None:
|
||||||
|
warnings.warn("learning_rates parameter is deprecated - use callback API instead",
|
||||||
|
DeprecationWarning)
|
||||||
|
callbacks.append(callback.reset_learning_rate(learning_rates))
|
||||||
|
|
||||||
return _train_internal(params, dtrain,
|
return _train_internal(params, dtrain,
|
||||||
num_boost_round=num_boost_round,
|
num_boost_round=num_boost_round,
|
||||||
evals=evals,
|
evals=evals,
|
||||||
@ -287,8 +288,8 @@ def aggcv(rlist):
|
|||||||
|
|
||||||
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
|
||||||
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
|
||||||
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True, seed=0,
|
fpreproc=None, as_pandas=True, verbose_eval=None, show_stdv=True,
|
||||||
callbacks=None):
|
seed=0, callbacks=None):
|
||||||
# pylint: disable = invalid-name
|
# pylint: disable = invalid-name
|
||||||
"""Cross-validation with given paramaters.
|
"""Cross-validation with given paramaters.
|
||||||
|
|
||||||
@ -336,6 +337,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
Seed used to generate the folds (passed to numpy.random.seed).
|
Seed used to generate the folds (passed to numpy.random.seed).
|
||||||
callbacks : list of callback functions
|
callbacks : list of callback functions
|
||||||
List of callback functions that are applied at end of each iteration.
|
List of callback functions that are applied at end of each iteration.
|
||||||
|
It is possible to use predefined callbacks by using xgb.callback module.
|
||||||
|
Example: [xgb.callback.reset_learning_rate(custom_rates)]
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -372,6 +375,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
|
|||||||
callbacks.append(callback.early_stop(early_stopping_rounds,
|
callbacks.append(callback.early_stop(early_stopping_rounds,
|
||||||
maximize=maximize,
|
maximize=maximize,
|
||||||
verbose=False))
|
verbose=False))
|
||||||
|
|
||||||
if isinstance(verbose_eval, bool) and verbose_eval:
|
if isinstance(verbose_eval, bool) and verbose_eval:
|
||||||
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
callbacks.append(callback.print_evaluation(show_stdv=show_stdv))
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user