early_stopping_rounds for train() in Python wrapper 🔥
parent 39093bc432
commit d7f9499f88
@@ -520,7 +520,7 @@ class Booster(object):
         return fmap


-def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
+def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
     """
     Train a booster with given parameters.

@@ -532,28 +532,31 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
         Data to be trained.
     num_boost_round: int
         Number of boosting iterations.
-        If negative, train until validation error hasn't decreased in -num_boost_round rounds.
-        Requires at least one item in evals. If there's more than one, will use the last.
     watchlist : list of pairs (DMatrix, string)
         List of items to be evaluated during training, this allows user to watch
         performance on the validation set.
-    obj : function
+    obj : function
         Customized objective function.
     feval : function
         Customized evaluation function.
+    early_stopping_rounds: int
+        Activates early stopping. Validation error needs to decrease at least
+        every <early_stopping_rounds> round(s) to continue training.
+        Requires at least one item in evals.
+        If there's more than one, will use the last.
+        Returns the model from the last iteration (not the best one).
+        If early stopping occurs, the model will have two additional fields:
+        bst.best_score and bst.best_iteration.

     Returns
     -------
     booster : a trained booster model
     """

-    if num_boost_round < 0 and len(evals) < 1:
-        raise ValueError('For early stopping you need at least on set in evals.')
-
     evals = list(evals)
     bst = Booster(params, [dtrain] + [d[0] for d in evals])

-    if num_boost_round >= 0:
+    if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
             if len(evals) != 0:
@@ -562,11 +565,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
                     sys.stderr.write(bst_eval_set + '\n')
                 else:
                     sys.stderr.write(bst_eval_set.decode() + '\n')
         return bst
+
     else:
         # early stopping

-        # TODO: return model from the best iteration
-        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], -num_boost_round))
+        if len(evals) < 1:
+            raise ValueError('For early stopping you need at least on set in evals.')
+
+        sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds))
+
         # is params a list of tuples? are we using multiple eval metrics?
         if type(params) == list:
@@ -588,9 +595,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):

         best_msg = ''
         best_score_i = 0
-        i = 0

-        while True:
+        for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
             bst_eval_set = bst.eval_set(evals, i, feval)

@@ -607,14 +613,15 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
                 best_score = score
                 best_score_i = i
                 best_msg = msg
-            elif i - best_score_i >= -num_boost_round:
-                sys.stderr.write("Stopping. Best iteration:\n{}".format(best_msg))
-                break
-
-            i += 1
-
-        return bst
+            elif i - best_score_i >= early_stopping_rounds:
+                sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
+                bst.best_score = best_score
+                bst.best_iteration = best_score_i
+                return bst
+
+        return bst
+


 class CVPack(object):
     def __init__(self, dtrain, dtest, param):
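The hunks above only show the ends of the new early-stopping loop; the improvement check that feeds the elif branch sits in code between them. As a reminder of the overall pattern: the loop remembers the best score seen so far and the round it occurred in, and stops once early_stopping_rounds rounds have passed without improvement. Below is a minimal, self-contained sketch of that bookkeeping; the update and evaluate callables and the assumption that a lower score is better are placeholders, not part of the actual wrapper.

import sys

def early_stopping_loop(update, evaluate, num_boost_round, early_stopping_rounds):
    # update(i) runs one boosting round; evaluate(i) returns (score, msg) for the
    # watched eval set. Both are illustrative stand-ins for the wrapper's internals.
    best_score = float('inf')   # assumes lower is better (e.g. an error rate)
    best_score_i = 0
    best_msg = ''
    for i in range(num_boost_round):
        update(i)
        score, msg = evaluate(i)
        if score < best_score:
            # validation error decreased: remember this round as the best so far
            best_score, best_score_i, best_msg = score, i, msg
        elif i - best_score_i >= early_stopping_rounds:
            # no improvement for early_stopping_rounds rounds in a row: stop
            sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
            break
    return best_score, best_score_i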
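From the caller's side, the new keyword argument replaces the old convention of passing a negative num_boost_round. A minimal usage sketch, assuming the wrapper is importable as xgboost; the synthetic data and parameter values are purely illustrative:

import numpy as np
import xgboost as xgb

# Illustrative train/validation split; any real data works the same way.
rng = np.random.RandomState(0)
dtrain = xgb.DMatrix(rng.rand(500, 10), label=rng.randint(2, size=500))
dvalid = xgb.DMatrix(rng.rand(200, 10), label=rng.randint(2, size=200))

params = {'objective': 'binary:logistic', 'max_depth': 3, 'eta': 0.1}

# Early stopping watches the *last* entry in evals; training stops once the
# validation error has not decreased for 10 consecutive rounds.
bst = xgb.train(params, dtrain,
                num_boost_round=1000,
                evals=[(dtrain, 'train'), (dvalid, 'validation')],
                early_stopping_rounds=10)

# Per the docstring, these fields only exist if early stopping actually fired.
if hasattr(bst, 'best_score'):
    print('best score:     {}'.format(bst.best_score))
    print('best iteration: {}'.format(bst.best_iteration))

Note that the returned booster is the model from the last iteration, not the best one, so bst.best_iteration is how a caller recovers which round scored best.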