218 lines
7.5 KiB
Python

# coding: utf-8
# pylint: disable= invalid-name
"""Training Library containing training routines."""
from __future__ import absolute_import
from . import rabit
from .core import EarlyStopException
def _fmt_metric(value, show_stdv=True):
"""format metric string"""
if len(value) == 2:
return '%s:%g' % (value[0], value[1])
elif len(value) == 3:
if show_stdv:
return '%s:%g+%g' % (value[0], value[1], value[2])
else:
return '%s:%g' % (value[0], value[1])
else:
raise ValueError("wrong metric value")
def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints evaluation results.

    Parameters
    ----------
    period : int
        The period to log the evaluation results. The final boosting round
        is always logged as well. A falsy value (``0``, ``False``, ``None``)
        disables logging entirely instead of failing on ``i % period``.
    show_stdv : bool, optional
        Whether to show stdv if provided.

    Returns
    -------
    callback : function
        A callback that prints evaluation every period iterations.
    """
    def callback(env):
        """internal function"""
        # Only the rank-0 worker logs, and only when there is something to log.
        if not period or env.rank != 0 or not env.evaluation_result_list:
            return
        i = env.iteration
        # end_iteration is exclusive, so i + 1 == end_iteration is the last
        # round; log it even when it does not fall on a period boundary.
        if i % period == 0 or i + 1 == env.end_iteration:
            msg = '\t'.join(_fmt_metric(x, show_stdv)
                            for x in env.evaluation_result_list)
            rabit.tracker_print('[%d]\t%s\n' % (i, msg))
    return callback
def record_evaluation(eval_result):
    """Create a call back that records the evaluation history into eval_result.

    Each evaluation entry named ``'<set>-<metric>'`` is stored as
    ``eval_result[<set>][<metric>]``, a list with one value per round.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results. It is cleared when the
        callback is created.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dict.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result has to be a dictionary')
    eval_result.clear()

    def callback(env):
        """internal function"""
        for entry in env.evaluation_result_list:
            # Entries are (name, value) or (name, mean, std); record the
            # second field either way.
            name, value = entry[0], entry[1]
            # Split on the first '-' only: the metric part may itself
            # contain '-' (e.g. 'ndcg@5-'), which a bare split() rejects.
            key, metric = name.split('-', 1)
            # setdefault also covers metrics that first appear after round 0.
            eval_result.setdefault(key, {}).setdefault(metric, []).append(value)
    return callback
def reset_learning_rate(learning_rates):
    """Reset learning rate after iteration 1

    NOTE: the initial learning rate will still take in-effect on first iteration.
    (NOTE(review): the callback is marked ``before_iteration``; confirm this
    note still matches how the training loop invokes it.)

    Parameters
    ----------
    learning_rates: sequence or function
        Sequence (list, tuple, array, ...) of learning rate for each boosting
        round, or a customized function that calculates eta in terms of the
        current number of round and the total number of boosting round
        (e.g. yields learning rate decay)

        - sequence l: eta = l[boosting round]
        - function f: eta = f(boosting round, num_boost_round)

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    ValueError
        (from the callback) if a sequence's length differs from the number
        of boosting rounds.
    """
    def callback(env):
        """internal function"""
        bst = env.model
        i = env.iteration
        # Dispatch on callability rather than on `list`, so tuples and
        # array-likes are indexed instead of being "called" (TypeError).
        if callable(learning_rates):
            rate = learning_rates(i, env.end_iteration)
        else:
            if len(learning_rates) != env.end_iteration:
                raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
            rate = learning_rates[i]
        bst.set_param('learning_rate', rate)
    callback.before_iteration = True
    return callback
def early_stop(stopping_rounds, maximize=False, verbose=True):
    """Create a callback that activates early stopping.

    Validation error needs to decrease at least
    every <stopping_rounds> round(s) to continue training.
    Requires at least one item in evals.
    If there's more than one, will use the last.
    Returns the model from the last iteration (not the best one).
    If early stopping occurs, the model will have three additional fields:
    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
    (Use bst.best_ntree_limit to get the correct value if num_parallel_tree
    and/or num_class appears in the parameters)

    Parameters
    ----------
    stopping_rounds : int
        The stopping rounds before the trend occur.
    maximize : bool
        Whether to maximize evaluation metric. Metrics whose name starts with
        'auc', 'map' or 'ndcg' are always maximized.
    verbose : optional, bool
        Whether to print message about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    ValueError
        (from the callback) if there is no evaluation result to monitor.
    EarlyStopException
        (from the callback) carrying the best iteration once the score has
        not improved for ``stopping_rounds`` rounds.
    """
    # Shared between init() and callback(); empty until the first call.
    state = {}

    def init(env):
        """Validate the evaluation setup and seed the best-score state."""
        bst = env.model
        if not env.evaluation_result_list:
            raise ValueError('For early stopping you need at least one set in evals.')
        metric = env.evaluation_result_list[-1][0]
        # Like every other message here, only rank 0 prints.
        if len(env.evaluation_result_list) > 1 and verbose and env.rank == 0:
            msg = ("Multiple eval metrics have been passed: "
                   "'{0}' will be used for early stopping.\n\n")
            rabit.tracker_print(msg.format(metric))
        maximize_metrics = ('auc', 'map', 'ndcg')
        # Names look like '<set>-<metric>'; take everything after the first '-'
        # (or the whole name if there is none).
        metric_name = metric.split('-', 1)[-1]
        maximize_score = maximize or any(metric_name.startswith(x)
                                         for x in maximize_metrics)
        if verbose and env.rank == 0:
            msg = "Will train until {} hasn't improved in {} rounds.\n"
            rabit.tracker_print(msg.format(metric, stopping_rounds))
        state['maximize_score'] = maximize_score
        state['best_iteration'] = 0
        state['best_score'] = float('-inf') if maximize_score else float('inf')
        # Default so stopping before any improvement (e.g. NaN scores, which
        # never compare better) does not raise KeyError.
        state['best_msg'] = ''
        if bst is not None:
            if bst.attr('best_score') is not None:
                # Resume the best-so-far bookkeeping stored on the model.
                state['best_score'] = float(bst.attr('best_score'))
                state['best_iteration'] = int(bst.attr('best_iteration'))
                state['best_msg'] = bst.attr('best_msg')
            else:
                bst.set_attr(best_iteration=str(state['best_iteration']),
                             best_score=str(state['best_score']))
        else:
            assert env.cvfolds is not None

    def callback(env):
        """internal function"""
        # init() must run before the result list is indexed, so an empty
        # list yields its ValueError rather than an IndexError.
        if not state:
            init(env)
        score = env.evaluation_result_list[-1][1]
        best_score = state['best_score']
        best_iteration = state['best_iteration']
        if state['maximize_score']:
            improved = score > best_score
        else:
            improved = score < best_score
        if improved:
            msg = '[%d]\t%s' % (
                env.iteration,
                '\t'.join([_fmt_metric(x) for x in env.evaluation_result_list]))
            state['best_msg'] = msg
            state['best_score'] = score
            state['best_iteration'] = env.iteration
            # save the property to attributes, so they will occur in checkpoint.
            if env.model is not None:
                env.model.set_attr(best_score=str(state['best_score']),
                                   best_iteration=str(state['best_iteration']),
                                   best_msg=state['best_msg'])
        elif env.iteration - best_iteration >= stopping_rounds:
            best_msg = state['best_msg']
            if verbose and env.rank == 0:
                msg = "Stopping. Best iteration:\n{}\n\n"
                rabit.tracker_print(msg.format(best_msg))
            raise EarlyStopException(best_iteration)
    return callback