early stopping for Python wrapper

This commit is contained in:
Zygmunt Zając 2015-03-30 19:53:47 +02:00
parent 431277d5ca
commit f9e157011f

View File

@ -1,7 +1,10 @@
# coding: utf-8
"""
xgboost: eXtreme Gradient Boosting library
Authors: Tianqi Chen, Bing Xu
Early stopping by Zygmunt Zając
"""
from __future__ import absolute_import
@ -529,6 +532,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
Data to be trained.
num_boost_round: int
Number of boosting iterations.
If negative, train until validation error hasn't decreased in -num_boost_round rounds.
Requires at least one item in evals. If there's more than one, will use the last.
watchlist : list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch
performance on the validation set.
@ -541,16 +546,73 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
-------
booster : a trained booster model
"""
if num_boost_round < 0 and len(evals) < 1:
raise ValueError('For early stopping you need at least on set in evals.')
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
if len(evals) != 0:
if num_boost_round >= 0:
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types):
sys.stderr.write(bst_eval_set + '\n')
else:
sys.stderr.write(bst_eval_set.decode() + '\n')
else:
# early stopping
# TODO: return model from the best iteration
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], -num_boost_round))
# is params a list of tuples? are we using multiple eval metrics?
if type(params) == list:
if len(params) != len(dict(params).items()):
raise ValueError('Check your params. Early stopping works with single eval metric only.')
params = dict(params)
# either minimize loss or maximize AUC/MAP/NDCG
maximize_score = False
if 'eval_metric' in params:
maximize_metrics = ('auc', 'map', 'ndcg')
if filter( lambda x: params['eval_metric'].startswith(x), maximize_metrics ):
maximize_score = True
if maximize_score:
best_score = 0.0
else:
best_score = float('inf')
best_msg = ''
best_score_i = 0
i = 0
while True:
bst.update(dtrain, i, obj)
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, string_types):
sys.stderr.write(bst_eval_set + '\n')
msg = bst_eval_set
else:
sys.stderr.write(bst_eval_set.decode() + '\n')
msg = bst_eval_set.decode()
sys.stderr.write(msg + '\n')
score = float(msg.rsplit( ':', 1 )[1])
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
best_score = score
best_score_i = i
best_msg = msg
elif i - best_score_i >= -num_boost_round:
sys.stderr.write("Stopping. Best iteration:\n{}".format(best_msg))
break
i += 1
return bst