early_stopping_rounds for train() in Python wrapper 🔥

This commit is contained in:
Zygmunt Zając 2015-04-02 19:43:30 +02:00
parent 39093bc432
commit d7f9499f88

View File

@ -520,7 +520,7 @@ class Booster(object):
return fmap
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
"""
Train a booster with given parameters.
@ -532,28 +532,31 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
Data to be trained.
num_boost_round: int
Number of boosting iterations.
If negative, train until validation error hasn't decreased in -num_boost_round rounds.
Requires at least one item in evals. If there's more than one, will use the last.
watchlist : list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch
performance on the validation set.
obj : function
obj : function
Customized objective function.
feval : function
Customized evaluation function.
early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least
every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration.
Returns
-------
booster : a trained booster model
"""
if num_boost_round < 0 and len(evals) < 1:
raise ValueError('For early stopping you need at least on set in evals.')
evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals])
if num_boost_round >= 0:
if not early_stopping_rounds:
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
if len(evals) != 0:
@ -562,11 +565,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
sys.stderr.write(bst_eval_set + '\n')
else:
sys.stderr.write(bst_eval_set.decode() + '\n')
return bst
else:
# early stopping
# TODO: return model from the best iteration
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], -num_boost_round))
if len(evals) < 1:
raise ValueError('For early stopping you need at least on set in evals.')
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics?
if type(params) == list:
@ -588,9 +595,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
best_msg = ''
best_score_i = 0
i = 0
while True:
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
bst_eval_set = bst.eval_set(evals, i, feval)
@ -607,14 +613,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
best_score = score
best_score_i = i
best_msg = msg
elif i - best_score_i >= -num_boost_round:
sys.stderr.write("Stopping. Best iteration:\n{}".format(best_msg))
break
i += 1
return bst
elif i - best_score_i >= early_stopping_rounds:
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
bst.best_score = best_score
bst.best_iteration = best_score_i
return bst
return bst
class CVPack(object):
def __init__(self, dtrain, dtest, param):