Merge pull request #598 from Far0n/py_train

best_ntree_limit attribute & training continuation bugfix
This commit is contained in:
Yuan (Terry) Tang
2015-11-12 06:16:19 -06:00
3 changed files with 81 additions and 30 deletions

View File

@@ -38,8 +38,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
Requires at least one item in evals.
If there's more than one, will use the last.
Returns the model from the last iteration (not the best one).
If early stopping occurs, the model will have two additional fields:
bst.best_score and bst.best_iteration.
If early stopping occurs, the model will have three additional fields:
bst.best_score, bst.best_iteration and bst.best_ntree_limit.
evals_result: dict
This dictionary stores the evaluation results of all the items in watchlist.
Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
@@ -75,15 +75,24 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
params += [('eval_metric', eval_metric)]
bst = Booster(params, [dtrain] + [d[0] for d in evals])
ntrees = 0
nboost = 0
num_parallel_tree = 1
if xgb_model is not None:
if not isinstance(xgb_model, STRING_TYPES):
xgb_model = xgb_model.save_raw()
bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
ntrees = len(bst.get_dump())
nboost = len(bst.get_dump())
else:
bst = Booster(params, [dtrain] + [d[0] for d in evals])
_params = dict(params) if isinstance(params, list) else params
if 'num_parallel_tree' in _params:
num_parallel_tree = _params['num_parallel_tree']
nboost //= num_parallel_tree
if 'num_class' in _params:
nboost //= _params['num_class']
if evals_result is not None:
if not isinstance(evals_result, dict):
raise TypeError('evals_result has to be a dictionary')
@@ -95,7 +104,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
if not early_stopping_rounds:
for i in range(num_boost_round):
bst.update(dtrain, i, obj)
ntrees += 1
nboost += 1
if len(evals) != 0:
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
@@ -118,7 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
evals_result[key][res_key].append(res_val)
else:
evals_result[key][res_key] = [res_val]
bst.best_iteration = (ntrees - 1)
bst.best_iteration = (nboost - 1)
bst.best_ntree_limit = nboost * num_parallel_tree
return bst
else:
@@ -154,7 +164,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
best_score = float('inf')
best_msg = ''
best_score_i = ntrees
best_score_i = (nboost - 1)
if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
@@ -166,7 +176,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
else:
bst.set_param({'eta': learning_rates(i, num_boost_round)})
bst.update(dtrain, i, obj)
ntrees += 1
nboost += 1
bst_eval_set = bst.eval_set(evals, i, feval)
if isinstance(bst_eval_set, STRING_TYPES):
@@ -195,7 +205,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
if (maximize_score and score > best_score) or \
(not maximize_score and score < best_score):
best_score = score
best_score_i = (ntrees - 1)
best_score_i = (nboost - 1)
best_msg = msg
elif i - best_score_i >= early_stopping_rounds:
sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
@@ -204,6 +214,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
break
bst.best_score = best_score
bst.best_iteration = best_score_i
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
return bst