Enable flake8
This commit is contained in:
@@ -1,10 +1,7 @@
|
||||
import xgboost as xgb
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import MultiLabelBinarizer
|
||||
from sklearn.cross_validation import KFold, train_test_split
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.grid_search import GridSearchCV
|
||||
from sklearn.datasets import load_iris, load_digits, load_boston
|
||||
from sklearn.datasets import load_digits
|
||||
import unittest
|
||||
|
||||
rng = np.random.RandomState(1337)
|
||||
@@ -57,10 +54,14 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
ntrees_02b = len(gbdt_02b.get_dump())
|
||||
assert ntrees_02a == 10
|
||||
assert ntrees_02b == 10
|
||||
assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
|
||||
mean_squared_error(y_2class, gbdt_02a.predict(dtrain_2class))
|
||||
assert mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class)) == \
|
||||
mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_02a.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_01.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_03 = xgb.train(self.xgb_params_01, dtrain_2class, num_boost_round=3)
|
||||
gbdt_03.save_model('xgb_tc.model')
|
||||
@@ -71,22 +72,30 @@ class TestTrainingContinuation(unittest.TestCase):
|
||||
ntrees_03b = len(gbdt_03b.get_dump())
|
||||
assert ntrees_03a == 10
|
||||
assert ntrees_03b == 10
|
||||
assert mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class)) == \
|
||||
mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=3)
|
||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||
assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
|
||||
mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_04 = xgb.train(self.xgb_params_02, dtrain_2class, num_boost_round=7, xgb_model=gbdt_04)
|
||||
assert gbdt_04.best_ntree_limit == (gbdt_04.best_iteration + 1) * self.num_parallel_tree
|
||||
assert mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class)) == \
|
||||
mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_04.predict(dtrain_2class, ntree_limit=gbdt_04.best_ntree_limit))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=7)
|
||||
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
gbdt_05 = xgb.train(self.xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05)
|
||||
assert gbdt_05.best_ntree_limit == (gbdt_05.best_iteration + 1) * self.num_parallel_tree
|
||||
assert np.any(gbdt_05.predict(dtrain_5class) !=
|
||||
gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit)) == False
|
||||
|
||||
res1 = gbdt_05.predict(dtrain_5class)
|
||||
res2 = gbdt_05.predict(dtrain_5class, ntree_limit=gbdt_05.best_ntree_limit)
|
||||
np.testing.assert_almost_equal(res1, res2)
|
||||
|
||||
Reference in New Issue
Block a user