early_stopping_rounds for train() in Python wrapper 🔥

This commit is contained in:
Zygmunt Zając 2015-04-02 19:43:30 +02:00
parent 39093bc432
commit d7f9499f88

View File

@ -520,7 +520,7 @@ class Booster(object):
return fmap return fmap
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None): def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
""" """
Train a booster with given parameters. Train a booster with given parameters.
@ -532,28 +532,31 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
Data to be trained. Data to be trained.
num_boost_round: int num_boost_round: int
Number of boosting iterations. Number of boosting iterations.
If negative, train until validation error hasn't decreased in -num_boost_round rounds.
Requires at least one item in evals. If there's more than one, will use the last.
watchlist : list of pairs (DMatrix, string) watchlist : list of pairs (DMatrix, string)
List of items to be evaluated during training, this allows user to watch List of items to be evaluated during training, this allows user to watch
performance on the validation set. performance on the validation set.
obj : function obj : function
Customized objective function. Customized objective function.
feval : function feval : function
Customized evaluation function. Customized evaluation function.
early_stopping_rounds: int
Activates early stopping. Validation error needs to decrease at least
every <early_stopping_rounds> round(s) to continue training.
Requires at least one item in evals.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration.
Returns Returns
------- -------
booster : a trained booster model booster : a trained booster model
""" """
if num_boost_round < 0 and len(evals) < 1:
raise ValueError('For early stopping you need at least on set in evals.')
evals = list(evals) evals = list(evals)
bst = Booster(params, [dtrain] + [d[0] for d in evals]) bst = Booster(params, [dtrain] + [d[0] for d in evals])
if num_boost_round >= 0: if not early_stopping_rounds:
for i in range(num_boost_round): for i in range(num_boost_round):
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
if len(evals) != 0: if len(evals) != 0:
@ -562,11 +565,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
sys.stderr.write(bst_eval_set + '\n') sys.stderr.write(bst_eval_set + '\n')
else: else:
sys.stderr.write(bst_eval_set.decode() + '\n') sys.stderr.write(bst_eval_set.decode() + '\n')
return bst
else: else:
# early stopping # early stopping
# TODO: return model from the best iteration if len(evals) < 1:
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], -num_boost_round)) raise ValueError('For early stopping you need at least on set in evals.')
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds))
# is params a list of tuples? are we using multiple eval metrics? # is params a list of tuples? are we using multiple eval metrics?
if type(params) == list: if type(params) == list:
@ -588,9 +595,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
best_msg = '' best_msg = ''
best_score_i = 0 best_score_i = 0
i = 0
while True: for i in range(num_boost_round):
bst.update(dtrain, i, obj) bst.update(dtrain, i, obj)
bst_eval_set = bst.eval_set(evals, i, feval) bst_eval_set = bst.eval_set(evals, i, feval)
@ -607,14 +613,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
best_score = score best_score = score
best_score_i = i best_score_i = i
best_msg = msg best_msg = msg
elif i - best_score_i >= -num_boost_round: elif i - best_score_i >= early_stopping_rounds:
sys.stderr.write("Stopping. Best iteration:\n{}".format(best_msg)) sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
break bst.best_score = best_score
bst.best_iteration = best_score_i
i += 1 return bst
return bst return bst
class CVPack(object): class CVPack(object):
def __init__(self, dtrain, dtest, param): def __init__(self, dtrain, dtest, param):