From 7f2628acd706938cc737c824807db051d8fd3df5 Mon Sep 17 00:00:00 2001 From: Faron Date: Thu, 12 Nov 2015 08:21:19 +0100 Subject: [PATCH] unittest for 'num_class > 2' added --- tests/python/test_training_continuation.py | 64 +++++++++++++++------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/tests/python/test_training_continuation.py b/tests/python/test_training_continuation.py index e75ff9d43..ac6deca26 100644 --- a/tests/python/test_training_continuation.py +++ b/tests/python/test_training_continuation.py @@ -1,5 +1,6 @@ import xgboost as xgb import numpy as np +from sklearn.preprocessing import MultiLabelBinarizer from sklearn.cross_validation import KFold, train_test_split from sklearn.metrics import mean_squared_error from sklearn.grid_search import GridSearchCV @@ -23,46 +24,69 @@ class TestTrainingContinuation(unittest.TestCase): 'num_parallel_tree': num_parallel_tree } + xgb_params_03 = { + 'silent': 1, + 'nthread': 1, + 'num_class': 5, + 'num_parallel_tree': num_parallel_tree + } + def test_training_continuation(self): - digits = load_digits(2) - X = digits['data'] - y = digits['target'] + digits_2class = load_digits(2) + digits_5class = load_digits(5) - dtrain = xgb.DMatrix(X, label=y) + X_2class = digits_2class['data'] + y_2class = digits_2class['target'] - gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10) + X_5class = digits_5class['data'] + y_5class = digits_5class['target'] + + dtrain_2class = xgb.DMatrix(X_2class, label=y_2class) + dtrain_5class = xgb.DMatrix(X_5class, label=y_5class) + + gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10) ntrees_01 = len(gbdt_01.get_dump()) assert ntrees_01 == 10 - gbdt_02 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=0) + gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=0) gbdt_02.save_model('xgb_tc.model') - gbdt_02a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model=gbdt_02) - gbdt_02b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model="xgb_tc.model") + gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02) + gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.model") ntrees_02a = len(gbdt_02a.get_dump()) ntrees_02b = len(gbdt_02b.get_dump()) assert ntrees_02a == 10 assert ntrees_02b == 10 - assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain)) - assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain)) + assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \ + mean_squared_error(y_2class, gbdt_02a.predict(dtrain_2class)) + assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \ + mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class)) - gbdt_03 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=3) + gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=3) gbdt_03.save_model('xgb_tc.model') - gbdt_03a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model=gbdt_03) - gbdt_03b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model="xgb_tc.model") + gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03) + gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.model") ntrees_03a = len(gbdt_03a.get_dump()) ntrees_03b = len(gbdt_03b.get_dump()) assert ntrees_03a == 10 assert ntrees_03b == 10 - assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain)) + assert mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class)) == \ + mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class)) - gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=3) + gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=3) assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree - assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \ - mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit)) + assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \ + mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit)) - gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=7, xgb_model=gbdt_04) + gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04) assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree - assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \ - mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit)) + assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \ + mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit)) + + gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=7) + assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree + gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05) + assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree + assert np.any(gbdt_05.predict(dtrain_5class) != + gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit)) == False