diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index cd2680e0e..03e24bdba 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -10,7 +10,8 @@ import numpy as np
 from .core import Booster, STRING_TYPES
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
-          early_stopping_rounds=None, evals_result=None, verbose_eval=True, learning_rates=None):
+          maximize=False, early_stopping_rounds=None, evals_result=None,
+          verbose_eval=True, learning_rates=None, xgb_model=None):
     # pylint: disable=too-many-statements,too-many-branches, attribute-defined-outside-init
     """Train a booster with given parameters.
 
@@ -29,6 +30,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Customized objective function.
     feval : function
         Customized evaluation function.
+    maximize : bool
+        Whether to maximize feval.
     early_stopping_rounds: int
         Activates early stopping. Validation error needs to decrease at least
         every <early_stopping_rounds> round(s) to continue training.
@@ -50,13 +53,23 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
         Learning rate for each boosting round (yields learning rate decay).
         - list l: eta = l[boosting round]
         - function f: eta = f(boosting round, num_boost_round)
+    xgb_model : file name of stored xgb model or 'Booster' instance
+        Xgb model to be loaded before training (allows training continuation).
 
     Returns
     -------
     booster : a trained booster model
     """
     evals = list(evals)
-    bst = Booster(params, [dtrain] + [d[0] for d in evals])
+    ntrees = 0
+    if xgb_model is not None:
+        if not isinstance(xgb_model, STRING_TYPES):
+            xgb_model = xgb_model.save_raw()
+        bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model)
+        ntrees = len(bst.get_dump())
+    else:
+        bst = Booster(params, [dtrain] + [d[0] for d in evals])
+
 
     if evals_result is not None:
         if not isinstance(evals_result, dict):
@@ -69,6 +82,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
     if not early_stopping_rounds:
         for i in range(num_boost_round):
             bst.update(dtrain, i, obj)
+            ntrees += 1
             if len(evals) != 0:
                 bst_eval_set = bst.eval_set(evals, i, feval)
                 if isinstance(bst_eval_set, STRING_TYPES):
@@ -91,6 +105,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
                             evals_result[key][res_key].append(res_val)
                         else:
                             evals_result[key][res_key] = [res_val]
+        bst.best_iteration = (ntrees - 1)
         return bst

     else:
@@ -115,6 +130,8 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             maximize_metrics = ('auc', 'map', 'ndcg')
             if any(params['eval_metric'].startswith(x) for x in maximize_metrics):
                 maximize_score = True
+        if feval is not None:
+            maximize_score = maximize

         if maximize_score:
             best_score = 0.0
@@ -122,7 +139,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             best_score = float('inf')

         best_msg = ''
-        best_score_i = 0
+        best_score_i = ntrees

         if isinstance(learning_rates, list) and len(learning_rates) != num_boost_round:
             raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
@@ -134,6 +151,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             else:
                 bst.set_param({'eta': learning_rates(i, num_boost_round)})
             bst.update(dtrain, i, obj)
+            ntrees += 1

             bst_eval_set = bst.eval_set(evals, i, feval)
             if isinstance(bst_eval_set, STRING_TYPES):
@@ -162,7 +180,7 @@ def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
             if (maximize_score and score > best_score) or \
                     (not maximize_score and score < best_score):
                 best_score = score
-                best_score_i = i
+                best_score_i = (ntrees - 1)
                 best_msg = msg
             elif i - best_score_i >= early_stopping_rounds:
                 sys.stderr.write("Stopping. Best iteration:\n{}\n\n".format(best_msg))
diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py
new file mode 100644
index 000000000..fec7a6a62
--- /dev/null
+++ b/tests/python/test_training_continuation.py
@@ -0,0 +1,52 @@
+import xgboost as xgb
+import numpy as np
+from sklearn.cross_validation import KFold, train_test_split
+from sklearn.metrics import mean_squared_error
+from sklearn.grid_search import GridSearchCV
+from sklearn.datasets import load_iris, load_digits, load_boston
+import unittest
+
+rng = np.random.RandomState(1337)
+
+
+class TestTrainingContinuation(unittest.TestCase):
+
+    xgb_params = {
+        'colsample_bytree': 0.7,
+        'silent': 1,
+        'nthread': 1,
+    }
+
+    def test_training_continuation(self):
+        digits = load_digits(2)
+        X = digits['data']
+        y = digits['target']
+
+        dtrain = xgb.DMatrix(X, label=y)
+
+        gbdt_01 = xgb.train(self.xgb_params, dtrain, num_boost_round=10)
+        ntrees_01 = len(gbdt_01.get_dump())
+        assert ntrees_01 == 10
+
+        gbdt_02 = xgb.train(self.xgb_params, dtrain, num_boost_round=0)
+        gbdt_02.save_model('xgb_tc.model')
+
+        gbdt_02a = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model=gbdt_02)
+        gbdt_02b = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
+        ntrees_02a = len(gbdt_02a.get_dump())
+        ntrees_02b = len(gbdt_02b.get_dump())
+        assert ntrees_02a == 10
+        assert ntrees_02b == 10
+        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
+        assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
+
+        gbdt_03 = xgb.train(self.xgb_params, dtrain, num_boost_round=3)
+        gbdt_03.save_model('xgb_tc.model')
+
+        gbdt_03a = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model=gbdt_03)
+        gbdt_03b = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
+        ntrees_03a = len(gbdt_03a.get_dump())
+        ntrees_03b = len(gbdt_03b.get_dump())
+        assert ntrees_03a == 10
+        assert ntrees_03b == 10
+        assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))
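
For reference, a minimal usage sketch of the two parameters this patch introduces: training continuation via xgb_model, and maximize for a custom feval under early stopping. It is not part of the diff; the synthetic data, the params dict, and the accuracy helper are illustrative assumptions, while xgb.train, xgb.DMatrix, and the new keyword arguments come from the patch and the existing xgboost API.

    import numpy as np
    import xgboost as xgb

    # Toy data, for illustration only.
    rng = np.random.RandomState(0)
    X = rng.randn(200, 5)
    y = rng.randint(0, 2, 200)
    dtrain = xgb.DMatrix(X, label=y)
    params = {'objective': 'binary:logistic', 'silent': 1}

    # Train 5 rounds, then continue for 5 more from the existing Booster
    # (a saved model file name works the same way).
    bst = xgb.train(params, dtrain, num_boost_round=5)
    bst = xgb.train(params, dtrain, num_boost_round=5, xgb_model=bst)
    assert len(bst.get_dump()) == 10  # 5 + 5 trees

    # Hypothetical custom metric: higher is better, so pass maximize=True
    # when combining it with early stopping.
    def accuracy(preds, dmat):
        labels = dmat.get_label()
        return 'acc', float(np.mean((preds > 0.5) == labels))

    bst = xgb.train(params, dtrain, num_boost_round=50,
                    evals=[(dtrain, 'train')], feval=accuracy,
                    maximize=True, early_stopping_rounds=5)

Passing an existing Booster serializes it via save_raw() and reloads it into the new Booster, so continuation behaves the same whether a file name or an in-memory model is given, which is what the test above checks.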