best_ntree_limit attribute added

- best_ntree_limit added as a new booster attribute
- usage of bst.best_ntree_limit added to the Python doc
- fixed wrong 'best_iteration' after training continuation
Far0n 2015-11-04 10:06:18 +01:00
parent f91ce704f3
commit ce5930c365
3 changed files with 51 additions and 24 deletions


@@ -121,7 +121,7 @@ Early stopping requires at least one set in `evals`. If there's more than one, it will use the last.
 The model will train until the validation score stops improving. Validation error needs to decrease at least every `early_stopping_rounds` to continue training.
-If early stopping occurs, the model will have two additional fields: `bst.best_score` and `bst.best_iteration`. Note that `train()` will return a model from the last iteration, not the best one.
+If early stopping occurs, the model will have three additional fields: `bst.best_score`, `bst.best_iteration` and `bst.best_ntree_limit`. Note that `train()` will return a model from the last iteration, not the best one.
 This works with both metrics to minimize (RMSE, log loss, etc.) and to maximize (MAP, NDCG, AUC). Note that if you specify more than one evaluation metric the last one in `param['eval_metric']` is used for early stopping.
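
As background for this doc change, a minimal sketch of reading the new attribute after early stopping; the file names and parameters here are illustrative, not part of this commit:

```python
import xgboost as xgb

# Hypothetical input files; any train/validation DMatrix pair works.
dtrain = xgb.DMatrix('train.svm.txt')
dvalid = xgb.DMatrix('valid.svm.txt')
param = {'objective': 'binary:logistic', 'eval_metric': 'auc'}

bst = xgb.train(param, dtrain, num_boost_round=1000,
                evals=[(dvalid, 'valid')], early_stopping_rounds=10)

# Only set if early stopping actually triggered:
print(bst.best_score)        # best validation score seen
print(bst.best_iteration)    # 0-based index of the best boosting round
print(bst.best_ntree_limit)  # value to pass as predict(..., ntree_limit=...)
```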
@@ -135,9 +135,9 @@ dtest = xgb.DMatrix(data)
 ypred = bst.predict(xgmat)
 ```
-If early stopping is enabled during training, you can predict with the best iteration.
+If early stopping is enabled during training, you can get predictions from the best iteration with `bst.best_ntree_limit`:
 ```python
-ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)
+ypred = bst.predict(xgmat,ntree_limit=bst.best_ntree_limit)
 ```
 
 Plotting
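
To make the distinction concrete, a hedged usage sketch reusing the hypothetical `bst` and `dvalid` names from the sketch above: since `train()` returns the model from the last iteration, uncapped predictions include the rounds built after the best one.

```python
# All trees built before stopping (the last iteration's model):
ypred_last = bst.predict(dvalid)

# Only the trees up to and including the best iteration:
ypred_best = bst.predict(dvalid, ntree_limit=bst.best_ntree_limit)
```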


@@ -38,8 +38,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     Requires at least one item in evals.
     If there's more than one, will use the last.
     Returns the model from the last iteration (not the best one).
-    If early stopping occurs, the model will have two additional fields:
-    bst.best_score and bst.best_iteration.
+    If early stopping occurs, the model will have three additional fields:
+    bst.best_score, bst.best_iteration and bst.best_ntree_limit.
     evals_result: dict
         This dictionary stores the evaluation results of all the items in watchlist.
         Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
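
The docstring example above is cut off by the hunk boundary; as a hedged illustration of what it describes, `evals_result` is filled in place, keyed first by watchlist name and then by metric. The setup below assumes `param`, `dtrain` and `dtest` as in the intro doc, and the metric values are invented:

```python
evals_result = {}
bst = xgb.train(param, dtrain, num_boost_round=2,
                evals=[(dtest, 'eval'), (dtrain, 'train')],
                evals_result=evals_result)

# Hypothetical result for an eval_metric of 'logloss':
# {'eval':  {'logloss': ['0.48', '0.36']},
#  'train': {'logloss': ['0.47', '0.35']}}
```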
@@ -75,15 +75,24 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             params += [('eval_metric', eval_metric)]
 
     bst = Booster(params, [dtrain] + [d[0] for d in evals])
-    ntrees = 0
+    nboost = 0
+    num_parallel_tree = 1
+
     if xgb_model is not None:
         if not isinstance(xgb_model, STRING_TYPES):
             xgb_model = xgb_model.save_raw()
         bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
-        ntrees = len(bst.get_dump())
+        nboost = len(bst.get_dump())
     else:
         bst = Booster(params, [dtrain] + [d[0] for d in evals])
 
+    _params = dict(params) if isinstance(params, list) else params
+    if 'num_parallel_tree' in _params:
+        num_parallel_tree = _params['num_parallel_tree']
+        nboost //= num_parallel_tree
+    if 'num_class' in _params:
+        nboost //= _params['num_class']
+
     if evals_result is not None:
         if not isinstance(evals_result, dict):
             raise TypeError('evals_result has to be a dictionary')
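
The divisions in the added block matter because `get_dump()` returns one entry per tree while `nboost` should count boosting rounds: each round grows `num_parallel_tree` trees, and multiclass training grows one such group per class. A worked sketch of the arithmetic (numbers invented):

```python
# A loaded model trained for 5 rounds with num_parallel_tree=3 and
# num_class=4 dumps 5 * 3 * 4 = 60 individual trees.
trees_dumped = 60

nboost = trees_dumped
nboost //= 3        # undo num_parallel_tree: 60 -> 20
nboost //= 4        # undo num_class:         20 -> 5
assert nboost == 5  # completed boosting rounds
```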
@@ -95,7 +104,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
-            ntrees += 1
+            nboost += 1
             if len(evals) != 0:
                 bst_eval_set = bst.eval_set(evals, i, feval)
                 if isinstance(bst_eval_set, STRING_TYPES):
@@ -118,7 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                             evals_result[key][res_key].append(res_val)
                         else:
                             evals_result[key][res_key] = [res_val]
-        bst.best_iteration = (ntrees - 1)
+        bst.best_iteration = (nboost - 1)
+        bst.best_ntree_limit = nboost * num_parallel_tree
         return bst
 
     else:
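
Without early stopping the best iteration is by definition the last one, so the limit spans every round. A small numeric sketch (values invented) of the relation being set here:

```python
# E.g. a model continued from 5 rounds for 10 more, with num_parallel_tree=3:
nboost = 5 + 10
num_parallel_tree = 3

best_iteration = nboost - 1                    # 14 (0-based)
best_ntree_limit = nboost * num_parallel_tree  # 45 trees usable by predict()
assert best_ntree_limit == (best_iteration + 1) * num_parallel_tree
```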
@@ -154,7 +164,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         best_score = float('inf')
         best_msg = ''
-        best_score_i = ntrees
+        best_score_i = (nboost - 1)
 
         if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
             raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
@@ -166,7 +176,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             else:
                 bst.set_param({'eta': learning_rates(i, num_boost_round)})
             bst.update(dtrain, i, obj)
-            ntrees += 1
+            nboost += 1
             bst_eval_set = bst.eval_set(evals, i, feval)
             if isinstance(bst_eval_set, STRING_TYPES):
@@ -195,7 +205,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             if (maximize_score and score > best_score) or \
                     (not maximize_score and score < best_score):
                 best_score = score
-                best_score_i = (ntrees - 1)
+                best_score_i = (nboost - 1)
                 best_msg = msg
             elif i - best_score_i >= early_stopping_rounds:
                 sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
@@ -204,6 +214,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                 break
         bst.best_score = best_score
         bst.best_iteration = best_score_i
+        bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
         return bst
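
Seeding `best_score_i` with `(nboost - 1)` instead of the old `ntrees` is what fixes the wrong `best_iteration` after training continuation: a resumed booster now counts from its already-built rounds rather than from zero. A hedged sketch of the scenario, reusing the hypothetical `param`, `dtrain` and `dvalid` names from earlier:

```python
# Train 3 rounds, then continue with early stopping enabled.
gbdt = xgb.train(param, dtrain, num_boost_round=3)
gbdt = xgb.train(param, dtrain, num_boost_round=100, xgb_model=gbdt,
                 evals=[(dvalid, 'valid')], early_stopping_rounds=5)

# best_iteration is counted over all rounds (the initial 3 plus the resumed
# ones), so (best_iteration + 1) * num_parallel_tree remains a valid limit.
print(gbdt.best_iteration, gbdt.best_ntree_limit)
```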


@@ -8,30 +8,37 @@ import unittest
 rng = np.random.RandomState(1337)
 
 class TestTrainingContinuation(unittest.TestCase):
-    xgb_params = {
-        'colsample_bytree': 0.7,
+    num_parallel_tree = 3
+
+    xgb_params_01 = {
         'silent': 1,
         'nthread': 1,
     }
+
+    xgb_params_02 = {
+        'silent': 1,
+        'nthread': 1,
+        'num_parallel_tree': num_parallel_tree
+    }
 
     def test_training_continuation(self):
         digits = load_digits(2)
         X = digits['data']
         y = digits['target']
-        dtrain = xgb.DMatrix(X,label=y)
+        dtrain = xgb.DMatrix(X, label=y)
 
-        gbdt_01 = xgb.train(self.xgb_params, dtrain, num_boost_round=10)
+        gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
         ntrees_01 = len(gbdt_01.get_dump())
         assert ntrees_01 == 10
 
-        gbdt_02 = xgb.train(self.xgb_params, dtrain, num_boost_round=0)
+        gbdt_02 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=0)
         gbdt_02.save_model('xgb_tc.model')
 
-        gbdt_02a = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model=gbdt_02)
-        gbdt_02b = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
+        gbdt_02a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model=gbdt_02)
+        gbdt_02b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
         ntrees_02a = len(gbdt_02a.get_dump())
         ntrees_02b = len(gbdt_02b.get_dump())
         assert ntrees_02a == 10
@@ -39,14 +46,23 @@ class TestTrainingContinuation(unittest.TestCase):
         assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
         assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
 
-        gbdt_03 = xgb.train(self.xgb_params, dtrain, num_boost_round=3)
+        gbdt_03 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=3)
         gbdt_03.save_model('xgb_tc.model')
 
-        gbdt_03a = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model=gbdt_03)
-        gbdt_03b = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
+        gbdt_03a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model=gbdt_03)
+        gbdt_03b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
         ntrees_03a = len(gbdt_03a.get_dump())
         ntrees_03b = len(gbdt_03b.get_dump())
         assert ntrees_03a == 10
         assert ntrees_03b == 10
 
         assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))
+
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=3)
+        assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
+        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
+            mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
+
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=7, xgb_model=gbdt_04)
+        assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
+        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
+            mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))