From ce5930c3656cb5b8e8e0b958528ce90536e57bdc Mon Sep 17 00:00:00 2001
From: Far0n
Date: Wed, 4 Nov 2015 10:06:18 +0100
Subject: [PATCH 1/2] best_ntree_limit attribute added

- best_ntree_limit as new booster attribute added
- usage of bst.best_ntree_limit in python doc added
- fixed wrong 'best_iteration' after training continuation
---
 doc/python/python_intro.md                 |  6 ++--
 python-package/xgboost/training.py         | 29 +++++++++++-----
 tests/python/test_training_continuation.py | 40 +++++++++++++++-------
 3 files changed, 51 insertions(+), 24 deletions(-)

diff --git a/doc/python/python_intro.md b/doc/python/python_intro.md
index c0a269a83..9e07d3c73 100644
--- a/doc/python/python_intro.md
+++ b/doc/python/python_intro.md
@@ -121,7 +121,7 @@ Early stopping requires at least one set in `evals`. If there's more than one, i
 
 The model will train until the validation score stops improving. Validation error needs to decrease at least every `early_stopping_rounds` to continue training.
 
-If early stopping occurs, the model will have two additional fields: `bst.best_score` and `bst.best_iteration`. Note that `train()` will return a model from the last iteration, not the best one.
+If early stopping occurs, the model will have three additional fields: `bst.best_score`, `bst.best_iteration` and `bst.best_ntree_limit`. Note that `train()` will return a model from the last iteration, not the best one.
 
 This works with both metrics to minimize (RMSE, log loss, etc.) and to maximize (MAP, NDCG, AUC). Note that if you specify more than one evaluation metric the last one in `param['eval_metric']` is used for early stopping.
 
@@ -135,9 +135,9 @@
 dtest = xgb.DMatrix(data)
 ypred = bst.predict(xgmat)
 ```
-If early stopping is enabled during training, you can predict with the best iteration.
+If early stopping is enabled during training, you can get predictions from the best iteration with `bst.best_ntree_limit`:
 ```python
-ypred = bst.predict(xgmat,ntree_limit=bst.best_iteration)
+ypred = bst.predict(xgmat, ntree_limit=bst.best_ntree_limit)
 ```
 
 Plotting
diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 5110295ad..f3aceaf48 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -38,8 +38,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Requires at least one item in evals.
         If there's more than one, will use the last.
         Returns the model from the last iteration (not the best one).
-        If early stopping occurs, the model will have two additional fields:
-        bst.best_score and bst.best_iteration.
+        If early stopping occurs, the model will have three additional fields:
+        bst.best_score, bst.best_iteration and bst.best_ntree_limit.
     evals_result: dict
         This dictionary stores the evaluation results of all the items in watchlist.
         Example: with a watchlist containing [(dtest,'eval'), (dtrain,'train')] and
@@ -75,15 +75,24 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             params += [('eval_metric', eval_metric)]
 
     bst = Booster(params, [dtrain] + [d[0] for d in evals])
-    ntrees = 0
+    nboost = 0
+    num_parallel_tree = 1
+
     if xgb_model is not None:
         if not isinstance(xgb_model, STRING_TYPES):
             xgb_model = xgb_model.save_raw()
         bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
-        ntrees = len(bst.get_dump())
+        nboost = len(bst.get_dump())
     else:
         bst = Booster(params, [dtrain] + [d[0] for d in evals])
 
+    _params = dict(params) if isinstance(params, list) else params
+    if 'num_parallel_tree' in _params:
+        num_parallel_tree = _params['num_parallel_tree']
+        nboost //= num_parallel_tree
+    if 'num_class' in _params:
+        nboost //= _params['num_class']
+
     if evals_result is not None:
         if not isinstance(evals_result, dict):
             raise TypeError('evals_result has to be a dictionary')
@@ -95,7 +104,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
-            ntrees += 1
+            nboost += 1
             if len(evals) != 0:
                 bst_eval_set = bst.eval_set(evals, i, feval)
                 if isinstance(bst_eval_set, STRING_TYPES):
@@ -118,7 +127,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                             evals_result[key][res_key].append(res_val)
                         else:
                             evals_result[key][res_key] = [res_val]
-        bst.best_iteration = (ntrees - 1)
+        bst.best_iteration = (nboost - 1)
+        bst.best_ntree_limit = nboost * num_parallel_tree
         return bst
 
     else:
@@ -154,7 +164,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             best_score = float('inf')
 
         best_msg = ''
-        best_score_i = ntrees
+        best_score_i = (nboost - 1)
 
         if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
             raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
@@ -166,7 +176,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             else:
                 bst.set_param({'eta': learning_rates(i, num_boost_round)})
             bst.update(dtrain, i, obj)
-            ntrees += 1
+            nboost += 1
             bst_eval_set = bst.eval_set(evals, i, feval)
 
             if isinstance(bst_eval_set, STRING_TYPES):
@@ -195,7 +205,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             if (maximize_score and score > best_score) or \
                     (not maximize_score and score < best_score):
                 best_score = score
-                best_score_i = (ntrees - 1)
+                best_score_i = (nboost - 1)
                 best_msg = msg
             elif i - best_score_i >= early_stopping_rounds:
                 sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
@@ -204,6 +214,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                 break
     bst.best_score = best_score
     bst.best_iteration = best_score_i
+    bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
     return bst
 
 
diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
index fec7a6a62..e75ff9d43 100644
--- a/tests/python/test_training_continuation.py
+++ b/tests/python/test_training_continuation.py
@@ -8,30 +8,37 @@ import unittest
 
 rng = np.random.RandomState(1337)
 
-class TestTrainingContinuation(unittest.TestCase):
-    xgb_params = {
-        'colsample_bytree': 0.7,
+class TestTrainingContinuation(unittest.TestCase):
+    num_parallel_tree = 3
+
+    xgb_params_01 = {
         'silent': 1,
         'nthread': 1,
     }
 
+    xgb_params_02 = {
+        'silent': 1,
+        'nthread': 1,
+        'num_parallel_tree': num_parallel_tree
+    }
+
     def test_training_continuation(self):
         digits = load_digits(2)
         X = digits['data']
         y = digits['target']
 
-        dtrain = xgb.DMatrix(X,label=y)
+        dtrain = xgb.DMatrix(X, label=y)
 
-        gbdt_01 = xgb.train(self.xgb_params, dtrain, num_boost_round=10)
+        gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
         ntrees_01 = len(gbdt_01.get_dump())
         assert ntrees_01 == 10
 
-        gbdt_02 = xgb.train(self.xgb_params, dtrain, num_boost_round=0)
+        gbdt_02 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=0)
         gbdt_02.save_model('xgb_tc.model')
 
-        gbdt_02a = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model=gbdt_02)
-        gbdt_02b = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
+        gbdt_02a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model=gbdt_02)
+        gbdt_02b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
         ntrees_02a = len(gbdt_02a.get_dump())
         ntrees_02b = len(gbdt_02b.get_dump())
         assert ntrees_02a == 10
@@ -39,14 +46,23 @@ class TestTrainingContinuation(unittest.TestCase):
         assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
         assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
 
-        gbdt_03 = xgb.train(self.xgb_params, dtrain, num_boost_round=3)
+        gbdt_03 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=3)
         gbdt_03.save_model('xgb_tc.model')
 
-        gbdt_03a = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model=gbdt_03)
-        gbdt_03b = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
+        gbdt_03a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model=gbdt_03)
+        gbdt_03b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
         ntrees_03a = len(gbdt_03a.get_dump())
         ntrees_03b = len(gbdt_03b.get_dump())
         assert ntrees_03a == 10
         assert ntrees_03b == 10
         assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))
-
+
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=3)
+        assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
+        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
+               mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
+
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=7, xgb_model=gbdt_04)
+        assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
+        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
+               mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))

From 7f2628acd706938cc737c824807db051d8fd3df5 Mon Sep 17 00:00:00 2001
From: Faron
Date: Thu, 12 Nov 2015 08:21:19 +0100
Subject: [PATCH 2/2] unittest for 'num_class > 2' added

---
 tests/python/test_training_continuation.py | 64 +++++++++++++++-------
 1 file changed, 44 insertions(+), 20 deletions(-)

diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
index e75ff9d43..ac6deca26 100644
--- a/tests/python/test_training_continuation.py
+++ b/tests/python/test_training_continuation.py
@@ -1,5 +1,6 @@
 import xgboost as xgb
 import numpy as np
+from sklearn.preprocessing import MultiLabelBinarizer
 from sklearn.cross_validation import KFold, train_test_split
 from sklearn.metrics import mean_squared_error
 from sklearn.grid_search import GridSearchCV
@@ -23,46 +24,69 @@ class TestTrainingContinuation(unittest.TestCase):
         'num_parallel_tree': num_parallel_tree
     }
 
+    xgb_params_03 = {
+        'silent': 1,
+        'nthread': 1,
+        'num_class': 5,
+        'num_parallel_tree': num_parallel_tree
+    }
+
     def test_training_continuation(self):
-        digits = load_digits(2)
-        X = digits['data']
-        y = digits['target']
+        digits_2class = load_digits(2)
+        digits_5class = load_digits(5)
 
-        dtrain = xgb.DMatrix(X, label=y)
+        X_2class = digits_2class['data']
+        y_2class = digits_2class['target']
+
+        X_5class = digits_5class['data']
+        y_5class = digits_5class['target']
+
+        dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
+        dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
+
+        gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10)
 
-        gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
         ntrees_01 = len(gbdt_01.get_dump())
         assert ntrees_01 == 10
 
-        gbdt_02 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=0)
+        gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=0)
         gbdt_02.save_model('xgb_tc.model')
 
-        gbdt_02a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model=gbdt_02)
-        gbdt_02b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
+        gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02)
+        gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.model")
         ntrees_02a = len(gbdt_02a.get_dump())
         ntrees_02b = len(gbdt_02b.get_dump())
         assert ntrees_02a == 10
         assert ntrees_02b == 10
-        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
-        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
+        assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
+               mean_squared_error(y_2class, gbdt_02a.predict(dtrain_2class))
+        assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
+               mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
 
-        gbdt_03 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=3)
+        gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=3)
         gbdt_03.save_model('xgb_tc.model')
 
-        gbdt_03a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model=gbdt_03)
-        gbdt_03b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
+        gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03)
+        gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.model")
         ntrees_03a = len(gbdt_03a.get_dump())
         ntrees_03b = len(gbdt_03b.get_dump())
         assert ntrees_03a == 10
         assert ntrees_03b == 10
-        assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))
+        assert mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class)) == \
+               mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
 
-        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=3)
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=3)
         assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
-        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
-               mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
+        assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
+               mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
 
-        gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=7, xgb_model=gbdt_04)
+        gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04)
         assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
-        assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
-               mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
+        assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
+               mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
+
+        gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=7)
+        assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
+        gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05)
+        assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
+        assert not np.any(gbdt_05.predict(dtrain_5class) !=
+                          gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit))
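
Note for reviewers (not part of either patch): a minimal, self-contained sketch of how `bst.best_ntree_limit` is meant to be used once this lands, covering both early stopping and training continuation. The dataset, objective, split, and round counts below are illustrative choices only, not anything the patch mandates:

```python
import xgboost as xgb
from sklearn.datasets import load_digits

digits = load_digits(2)
X, y = digits['data'], digits['target']
split = int(0.8 * len(X))
dtrain = xgb.DMatrix(X[:split], label=y[:split])
dvalid = xgb.DMatrix(X[split:], label=y[split:])

# With num_parallel_tree=3 every boosting round grows three trees, which is
# why the old advice of ntree_limit=bst.best_iteration was not a valid tree
# count (it also had an off-by-one, since iterations are zero-based).
params = {'objective': 'binary:logistic', 'silent': 1, 'num_parallel_tree': 3}

bst = xgb.train(params, dtrain, num_boost_round=50,
                evals=[(dvalid, 'valid')], early_stopping_rounds=5)

# best_ntree_limit already folds in num_parallel_tree (and num_class for
# multiclass models), so it can be passed straight through to predict():
assert bst.best_ntree_limit == (bst.best_iteration + 1) * 3
ypred = bst.predict(dvalid, ntree_limit=bst.best_ntree_limit)

# Training continuation: the boosting-round count is recovered from the
# loaded model, so the attributes stay consistent when continuing.
bst = xgb.train(params, dtrain, num_boost_round=10, xgb_model=bst)
assert bst.best_ntree_limit == (bst.best_iteration + 1) * 3
```

The same `(best_iteration + 1) * num_parallel_tree` relationship is exactly what the new unit tests assert, for both the `num_parallel_tree` and the `num_class > 2` configurations.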