unittest for 'num_class > 2' added
This commit is contained in:
parent
ce5930c365
commit
7f2628acd7
@ -1,5 +1,6 @@
|
|||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from sklearn.preprocessing import MultiLabelBinarizer
|
||||||
from sklearn.cross_validation import KFold, train_test_split
|
from sklearn.cross_validation import KFold, train_test_split
|
||||||
from sklearn.metrics import mean_squared_error
|
from sklearn.metrics import mean_squared_error
|
||||||
from sklearn.grid_search import GridSearchCV
|
from sklearn.grid_search import GridSearchCV
|
||||||
@ -23,46 +24,69 @@ class TestTrainingContinuation(unittest.TestCase):
|
|||||||
'num_parallel_tree': num_parallel_tree
|
'num_parallel_tree': num_parallel_tree
|
||||||
}
|
}
|
||||||
|
|
||||||
|
xgb_params_03 = {
|
||||||
|
'silent': 1,
|
||||||
|
'nthread': 1,
|
||||||
|
'num_class': 5,
|
||||||
|
'num_parallel_tree': num_parallel_tree
|
||||||
|
}
|
||||||
|
|
||||||
def test_training_continuation(self):
|
def test_training_continuation(self):
|
||||||
digits = load_digits(2)
|
digits_2class = load_digits(2)
|
||||||
X = digits['data']
|
digits_5class = load_digits(5)
|
||||||
y = digits['target']
|
|
||||||
|
|
||||||
dtrain = xgb.DMatrix(X, label=y)
|
X_2class = digits_2class['data']
|
||||||
|
y_2class = digits_2class['target']
|
||||||
|
|
||||||
gbdt_01 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10)
|
X_5class = digits_5class['data']
|
||||||
|
y_5class = digits_5class['target']
|
||||||
|
|
||||||
|
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
|
||||||
|
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
|
||||||
|
|
||||||
|
gbdt_01 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10)
|
||||||
ntrees_01 = len(gbdt_01.get_dump())
|
ntrees_01 = len(gbdt_01.get_dump())
|
||||||
assert ntrees_01 == 10
|
assert ntrees_01 == 10
|
||||||
|
|
||||||
gbdt_02 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=0)
|
gbdt_02 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=0)
|
||||||
gbdt_02.save_model('xgb_tc.model')
|
gbdt_02.save_model('xgb_tc.model')
|
||||||
|
|
||||||
gbdt_02a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model=gbdt_02)
|
gbdt_02a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02)
|
||||||
gbdt_02b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=10, xgb_model="xgb_tc.model")
|
gbdt_02b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.model")
|
||||||
ntrees_02a = len(gbdt_02a.get_dump())
|
ntrees_02a = len(gbdt_02a.get_dump())
|
||||||
ntrees_02b = len(gbdt_02b.get_dump())
|
ntrees_02b = len(gbdt_02b.get_dump())
|
||||||
assert ntrees_02a == 10
|
assert ntrees_02a == 10
|
||||||
assert ntrees_02b == 10
|
assert ntrees_02b == 10
|
||||||
assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02a.predict(dtrain))
|
assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
|
||||||
assert mean_squared_error(y, gbdt_01.predict(dtrain)) == mean_squared_error(y, gbdt_02b.predict(dtrain))
|
mean_squared_error(y_2class, gbdt_02a.predict(dtrain_2class))
|
||||||
|
assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
|
||||||
|
mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||||
|
|
||||||
gbdt_03 = xgb.train(self.xgb_params_01, dtrain, num_boost_round=3)
|
gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=3)
|
||||||
gbdt_03.save_model('xgb_tc.model')
|
gbdt_03.save_model('xgb_tc.model')
|
||||||
|
|
||||||
gbdt_03a = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model=gbdt_03)
|
gbdt_03a = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03)
|
||||||
gbdt_03b = xgb.train(self.xgb_params_01, dtrain, num_boost_round=7, xgb_model="xgb_tc.model")
|
gbdt_03b = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.model")
|
||||||
ntrees_03a = len(gbdt_03a.get_dump())
|
ntrees_03a = len(gbdt_03a.get_dump())
|
||||||
ntrees_03b = len(gbdt_03b.get_dump())
|
ntrees_03b = len(gbdt_03b.get_dump())
|
||||||
assert ntrees_03a == 10
|
assert ntrees_03a == 10
|
||||||
assert ntrees_03b == 10
|
assert ntrees_03b == 10
|
||||||
assert mean_squared_error(y, gbdt_03a.predict(dtrain)) == mean_squared_error(y, gbdt_03b.predict(dtrain))
|
assert mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class)) == \
|
||||||
|
mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||||
|
|
||||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=3)
|
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=3)
|
||||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||||
assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
|
assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
|
||||||
mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
|
mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||||
|
|
||||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain, num_boost_round=7, xgb_model=gbdt_04)
|
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04)
|
||||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||||
assert mean_squared_error(y, gbdt_04.predict(dtrain)) == \
|
assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
|
||||||
mean_squared_error(y, gbdt_04.predict(dtrain, ntree_limit=gbdt_04.best_ntree_limit))
|
mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||||
|
|
||||||
|
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=7)
|
||||||
|
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||||
|
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05)
|
||||||
|
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||||
|
assert np.any(gbdt_05.predict(dtrain_5class) !=
|
||||||
|
gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit)) == False
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user