From 8e1adddc2bce874b736ecde72fd540261dbe0e9f Mon Sep 17 00:00:00 2001 From: Far0n Date: Tue, 3 Nov 2015 14:44:17 +0100 Subject: [PATCH] added unittest for training continuation --- tests/python/test_training_continuation.py | 52 ++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/python/test_training_continuation.py diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py new file mode 100644 index 000000000..fec7a6a62 --- /dev/null +++ b/tests/python/test_training_continuation.py @@ -0,0 +1,52 @@ +import xgboost as xgb +import numpy as np +from sklearn.cross_validation import KFold, train_test_split +from sklearn.metrics import mean_squared_error +from sklearn.grid_search import GridSearchCV +from sklearn.datasets import load_iris, load_digits, load_boston +import unittest + +rng = np.random.RandomState(1337) + +class TestTrainingContinuation(unittest.TestCase): + + xgb_params = { + 'colsample_bytree': 0.7, + 'silent': 1, + 'nthread': 1, + } + + def test_training_continuation(self): + digits = load_digits(2) + X = digits['data'] + y = digits['target'] + + dtrain = xgb.DMatrix(X,label=y) + + gbdt_01 = xgb.train(self.xgb_params, dtrain, num_boost_round=10) + ntrees_01 = len(gbdt_01.get_dump()) + assert ntrees_01 == 10 + + gbdt_02 = xgb.train(self.xgb_params, dtrain, num_boost_round=0) + gbdt_02.save_model('xgb_tc.model') + + gbdt_02a = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model=gbdt_02) + gbdt_02b = xgb.train(self.xgb_params, dtrain, num_boost_round=10, xgb_model="xgb_tc.model") + ntrees_02a = len(gbdt_02a.get_dump()) + ntrees_02b = len(gbdt_02b.get_dump()) + assert ntrees_02a == 10 + assert ntrees_02b == 10 + assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain)) + assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain)) + + gbdt_03 = xgb.train(self.xgb_params, dtrain, num_boost_round=3) + gbdt_03.save_model('xgb_tc.model') + + gbdt_03a = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model=gbdt_03) + gbdt_03b = xgb.train(self.xgb_params, dtrain, num_boost_round=7, xgb_model="xgb_tc.model") + ntrees_03a = len(gbdt_03a.get_dump()) + ntrees_03b = len(gbdt_03b.get_dump()) + assert ntrees_03a == 10 + assert ntrees_03b == 10 + assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain)) +