early stopping for Python wrapper

Zygmunt Zając 2015-03-30 19:53:47 +02:00
parent 431277d5ca
commit f9e157011f


@@ -1,7 +1,10 @@
 # coding: utf-8
 """
 xgboost: eXtreme Gradient Boosting library
 Authors: Tianqi Chen, Bing Xu
+Early stopping by Zygmunt Zając
 """
 from __future__ import absolute_import
@@ -529,6 +532,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
         Data to be trained.
     num_boost_round: int
         Number of boosting iterations.
+        If negative, train until the validation error hasn't decreased in -num_boost_round rounds.
+        Requires at least one item in evals. If there's more than one, the last one will be used.
     watchlist : list of pairs (DMatrix, string)
         List of items to be evaluated during training, this allows user to watch
         performance on the validation set.
@@ -541,16 +546,73 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
     -------
     booster : a trained booster model
     """
+    if num_boost_round < 0 and len(evals) < 1:
+        raise ValueError('For early stopping you need at least one set in evals.')
+
     evals = list(evals)
     bst = Booster(params, [dtrain] + [d[0] for d in evals])
-    for i in range(num_boost_round):
-        bst.update(dtrain, i, obj)
-        if len(evals) != 0:
-            bst_eval_set = bst.eval_set(evals, i, feval)
-            if isinstance(bst_eval_set, string_types):
-                sys.stderr.write(bst_eval_set + '\n')
-            else:
-                sys.stderr.write(bst_eval_set.decode() + '\n')
+
+    if num_boost_round >= 0:
+        for i in range(num_boost_round):
+            bst.update(dtrain, i, obj)
+            if len(evals) != 0:
+                bst_eval_set = bst.eval_set(evals, i, feval)
+                if isinstance(bst_eval_set, string_types):
+                    sys.stderr.write(bst_eval_set + '\n')
+                else:
+                    sys.stderr.write(bst_eval_set.decode() + '\n')
+    else:
+        # early stopping
+        # TODO: return model from the best iteration
+        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], -num_boost_round))
+
+        # is params a list of tuples? are we using multiple eval metrics?
+        if isinstance(params, list):
+            if len(params) != len(dict(params).items()):
+                raise ValueError('Check your params. Early stopping works with single eval metric only.')
+            params = dict(params)
+
+        # either minimize loss or maximize AUC/MAP/NDCG
+        maximize_score = False
+        if 'eval_metric' in params:
+            maximize_metrics = ('auc', 'map', 'ndcg')
+            if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
+                maximize_score = True
+
+        if maximize_score:
+            best_score = 0.0
+        else:
+            best_score = float('inf')
+
+        best_msg = ''
+        best_score_i = 0
+
+        i = 0
+        while True:
+            bst.update(dtrain, i, obj)
+            bst_eval_set = bst.eval_set(evals, i, feval)
+            if isinstance(bst_eval_set, string_types):
+                msg = bst_eval_set
+            else:
+                msg = bst_eval_set.decode()
+            sys.stderr.write(msg + '\n')
+
+            # the tracked score is whatever follows the last colon,
+            # i.e. the metric of the last eval set
+            score = float(msg.rsplit(':', 1)[1])
+            if (maximize_score and score > best_score) or \
+               (not maximize_score and score < best_score):
+                best_score = score
+                best_score_i = i
+                best_msg = msg
+            elif i - best_score_i >= -num_boost_round:
+                sys.stderr.write("Stopping. Best iteration:\n{}\n".format(best_msg))
+                break
+            i += 1
+
     return bst
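
With this change, passing a negative num_boost_round switches train() into early-stopping mode. A minimal usage sketch, assuming the wrapper is importable as xgboost; the file names and parameter values are placeholders, not part of the commit:

    import xgboost as xgb

    dtrain = xgb.DMatrix('train.svm.txt')  # placeholder data files
    dvalid = xgb.DMatrix('valid.svm.txt')

    params = {'objective': 'binary:logistic', 'eval_metric': 'error'}

    # train until 'valid' error hasn't decreased in 50 rounds;
    # the last pair in evals is the one that drives early stopping
    bst = xgb.train(params, dtrain, num_boost_round=-50,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')])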
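
The stopping criterion doesn't query the booster for a score; it parses the evaluation message printed each round. The tracked value is whatever follows the last colon, which is why only the last item in evals matters. A standalone illustration with a made-up message in the usual format:

    # a typical eval_set message for two watched sets
    msg = '[7]\ttrain-error:0.042\tvalid-error:0.051'

    # the same extraction train() uses
    score = float(msg.rsplit(':', 1)[1])
    print(score)  # 0.051 -- the metric of the last eval set, 'valid'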
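
Whether a new score counts as an improvement depends on the metric: auc, map and ndcg are maximized, everything else is minimized, and the startswith check also catches their @-suffixed variants. A small sketch of that test; the helper name is ours, not part of the wrapper:

    MAXIMIZE_METRICS = ('auc', 'map', 'ndcg')

    def is_maximized(eval_metric):
        # 'ndcg@5' and 'map@3' should count too, hence startswith
        return any(eval_metric.startswith(m) for m in MAXIMIZE_METRICS)

    print(is_maximized('auc'))     # True
    print(is_maximized('ndcg@5'))  # True
    print(is_maximized('error'))   # False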