Merge pull request #212 from zygmuntz/master
Early stopping for Python wrapper
This commit is contained in:
commit
9b0dee986f
@ -1,7 +1,10 @@
|
|||||||
|
# coding: utf-8
|
||||||
|
|
||||||
"""
|
"""
|
||||||
xgboost: eXtreme Gradient Boosting library
|
xgboost: eXtreme Gradient Boosting library
|
||||||
|
|
||||||
Authors: Tianqi Chen, Bing Xu
|
Authors: Tianqi Chen, Bing Xu
|
||||||
|
Early stopping by Zygmunt Zając
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
@ -527,7 +530,7 @@ class Booster(object):
|
|||||||
return fmap
|
return fmap
|
||||||
|
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None, early_stopping_rounds=None):
|
||||||
"""
|
"""
|
||||||
Train a booster with given parameters.
|
Train a booster with given parameters.
|
||||||
|
|
||||||
@ -542,27 +545,93 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None):
|
|||||||
watchlist : list of pairs (DMatrix, string)
|
watchlist : list of pairs (DMatrix, string)
|
||||||
List of items to be evaluated during training, this allows user to watch
|
List of items to be evaluated during training, this allows user to watch
|
||||||
performance on the validation set.
|
performance on the validation set.
|
||||||
obj : function
|
obj : function
|
||||||
Customized objective function.
|
Customized objective function.
|
||||||
feval : function
|
feval : function
|
||||||
Customized evaluation function.
|
Customized evaluation function.
|
||||||
|
early_stopping_rounds: int
|
||||||
|
Activates early stopping. Validation error needs to decrease at least
|
||||||
|
every <early_stopping_rounds> round(s) to continue training.
|
||||||
|
Requires at least one item in evals.
|
||||||
|
If there's more than one, will use the last.
|
||||||
|
Returns the model from the last iteration (not the best one).
|
||||||
|
If early stopping occurs, the model will have two additional fields:
|
||||||
|
bst.best_score and bst.best_iteration.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
booster : a trained booster model
|
booster : a trained booster model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
evals = list(evals)
|
evals = list(evals)
|
||||||
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
bst = Booster(params, [dtrain] + [d[0] for d in evals])
|
||||||
for i in range(num_boost_round):
|
|
||||||
bst.update(dtrain, i, obj)
|
if not early_stopping_rounds:
|
||||||
if len(evals) != 0:
|
for i in range(num_boost_round):
|
||||||
bst_eval_set = bst.eval_set(evals, i, feval)
|
bst.update(dtrain, i, obj)
|
||||||
if isinstance(bst_eval_set, string_types):
|
if len(evals) != 0:
|
||||||
sys.stderr.write(bst_eval_set + '\n')
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
else:
|
if isinstance(bst_eval_set, string_types):
|
||||||
sys.stderr.write(bst_eval_set.decode() + '\n')
|
sys.stderr.write(bst_eval_set + '\n')
|
||||||
return bst
|
else:
|
||||||
|
sys.stderr.write(bst_eval_set.decode() + '\n')
|
||||||
|
return bst
|
||||||
|
|
||||||
|
else:
|
||||||
|
# early stopping
|
||||||
|
|
||||||
|
if len(evals) < 1:
|
||||||
|
raise ValueError('For early stopping you need at least on set in evals.')
|
||||||
|
|
||||||
|
sys.stderr.write("Will train until {} error hasn't decreased in {} rounds.\n".format(evals[-1][1], early_stopping_rounds))
|
||||||
|
|
||||||
|
# is params a list of tuples? are we using multiple eval metrics?
|
||||||
|
if type(params) == list:
|
||||||
|
if len(params) != len(dict(params).items()):
|
||||||
|
raise ValueError('Check your params. Early stopping works with single eval metric only.')
|
||||||
|
params = dict(params)
|
||||||
|
|
||||||
|
# either minimize loss or maximize AUC/MAP/NDCG
|
||||||
|
maximize_score = False
|
||||||
|
if 'eval_metric' in params:
|
||||||
|
maximize_metrics = ('auc', 'map', 'ndcg')
|
||||||
|
if filter(lambda x: params['eval_metric'].startswith(x), maximize_metrics):
|
||||||
|
maximize_score = True
|
||||||
|
|
||||||
|
if maximize_score:
|
||||||
|
best_score = 0.0
|
||||||
|
else:
|
||||||
|
best_score = float('inf')
|
||||||
|
|
||||||
|
best_msg = ''
|
||||||
|
best_score_i = 0
|
||||||
|
|
||||||
|
for i in range(num_boost_round):
|
||||||
|
bst.update(dtrain, i, obj)
|
||||||
|
bst_eval_set = bst.eval_set(evals, i, feval)
|
||||||
|
|
||||||
|
if isinstance(bst_eval_set, string_types):
|
||||||
|
msg = bst_eval_set
|
||||||
|
else:
|
||||||
|
msg = bst_eval_set.decode()
|
||||||
|
|
||||||
|
sys.stderr.write(msg + '\n')
|
||||||
|
|
||||||
|
score = float(msg.rsplit(':', 1)[1])
|
||||||
|
if (maximize_score and score > best_score) or \
|
||||||
|
(not maximize_score and score < best_score):
|
||||||
|
best_score = score
|
||||||
|
best_score_i = i
|
||||||
|
best_msg = msg
|
||||||
|
elif i - best_score_i >= early_stopping_rounds:
|
||||||
|
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
|
||||||
|
bst.best_score = best_score
|
||||||
|
bst.best_iteration = best_score_i
|
||||||
|
return bst
|
||||||
|
|
||||||
|
return bst
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CVPack(object):
|
class CVPack(object):
|
||||||
def __init__(self, dtrain, dtest, param):
|
def __init__(self, dtrain, dtest, param):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user