Merge pull request #212 from zygmuntz/master

Early stopping for Python wrapper
Committed by Tianqi Chen on 2015-04-02 17:31:44 -07:00 (commit 9b0dee986f)


@@ -1,7 +1,10 @@
# coding: utf-8
"""
xgboost: eXtreme Gradient Boosting library
Authors: Tianqi Chen, Bing Xu
Early stopping by Zygmunt Zając
"""
from __future__ import absolute_import
@@ -527,7 +530,7 @@ class Booster(object):
        return fmap
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
"""
Train a booster with given parameters.
@@ -542,27 +545,93 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
    evals : list of pairs (DMatrix, string)
        List of items to be evaluated during training, this allows the user to watch
        performance on the validation set.
    obj : function
        Customized objective function.
    feval : function
        Customized evaluation function.
    early_stopping_rounds : int
        Activates early stopping. Validation error needs to decrease at least
        every <early_stopping_rounds> round(s) to continue training.
        Requires at least one item in evals. If there's more than one, the last one is used.
        Returns the model from the last iteration (not the best one).
        If early stopping occurs, the model will have two additional fields:
        bst.best_score and bst.best_iteration.
    Returns
    -------
    booster : a trained booster model
    """
    evals = list(evals)
    bst = Booster(params, [dtrain] + [d[0] for d in evals])

    if not early_stopping_rounds:
        for i in range(num_boost_round):
            bst.update(dtrain, i, obj)
            if len(evals) != 0:
                bst_eval_set = bst.eval_set(evals, i, feval)
                if isinstance(bst_eval_set, string_types):
                    sys.stderr.write(bst_eval_set + '\n')
                else:
                    sys.stderr.write(bst_eval_set.decode() + '\n')
        return bst
    else:
        # early stopping
        if len(evals) < 1:
            raise ValueError('For early stopping you need at least one set in evals.')
        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(
            evals[-1][1], early_stopping_rounds))

        # is params a list of tuples? are we using multiple eval metrics?
        if isinstance(params, list):
            if len(params) != len(dict(params)):
                raise ValueError('Check your params. Early stopping works with a single eval metric only.')
            params = dict(params)

        # either minimize loss or maximize AUC/MAP/NDCG
        maximize_score = False
        if 'eval_metric' in params:
            maximize_metrics = ('auc', 'map', 'ndcg')
            # use any() rather than a bare filter(): on Python 3 a filter object
            # is always truthy, even when no metric matches
            if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
                maximize_score = True

        if maximize_score:
            best_score = 0.0
        else:
            best_score = float('inf')

        best_msg = ''
        best_score_i = 0

        for i in range(num_boost_round):
            bst.update(dtrain, i, obj)
            bst_eval_set = bst.eval_set(evals, i, feval)
            if isinstance(bst_eval_set, string_types):
                msg = bst_eval_set
            else:
                msg = bst_eval_set.decode()
            sys.stderr.write(msg + '\n')
            # the last metric in the eval message belongs to the last item in evals
            score = float(msg.rsplit(':', 1)[1])
            if (maximize_score and score > best_score) or \
               (not maximize_score and score < best_score):
                best_score = score
                best_score_i = i
                best_msg = msg
            elif i - best_score_i >= early_stopping_rounds:
                sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
                bst.best_score = best_score
                bst.best_iteration = best_score_i
                return bst
        return bst
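
For context, a minimal usage sketch of the new parameter follows. The file names, parameter values, and round counts here are illustrative only and are not part of this commit:

import xgboost as xgb

# hypothetical LibSVM-format data files
dtrain = xgb.DMatrix('train.libsvm')
dvalid = xgb.DMatrix('valid.libsvm')

params = {'objective': 'binary:logistic', 'eval_metric': 'error'}

# Early stopping watches the last item in evals (here 'valid') and stops
# training once its metric has not improved for 10 consecutive rounds.
bst = xgb.train(params, dtrain, num_boost_round=500,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=10)

# best_score and best_iteration are set only if early stopping actually fired
if hasattr(bst, 'best_iteration'):
    print('best round: {}'.format(bst.best_iteration))

Since the returned booster is the model from the last iteration rather than the best one, callers who want the best round should read bst.best_iteration.
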
class CVPack(object):
    def __init__(self, dtrain, dtest, param):