Refactor tests for training continuation. (#9997)
This commit is contained in:
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
from xgboost.testing.continuation import run_training_continuation_model_output
|
||||
|
||||
rng = np.random.RandomState(1337)
|
||||
|
||||
@@ -15,54 +16,51 @@ class TestTrainingContinuation:
|
||||
|
||||
def generate_parameters(self):
|
||||
xgb_params_01_binary = {
|
||||
'nthread': 1,
|
||||
"nthread": 1,
|
||||
}
|
||||
|
||||
xgb_params_02_binary = {
|
||||
'nthread': 1,
|
||||
'num_parallel_tree': self.num_parallel_tree
|
||||
"nthread": 1,
|
||||
"num_parallel_tree": self.num_parallel_tree,
|
||||
}
|
||||
|
||||
xgb_params_03_binary = {
|
||||
'nthread': 1,
|
||||
'num_class': 5,
|
||||
'num_parallel_tree': self.num_parallel_tree
|
||||
"nthread": 1,
|
||||
"num_class": 5,
|
||||
"num_parallel_tree": self.num_parallel_tree,
|
||||
}
|
||||
|
||||
return [
|
||||
xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary
|
||||
]
|
||||
return [xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary]
|
||||
|
||||
def run_training_continuation(self, xgb_params_01, xgb_params_02,
|
||||
xgb_params_03):
|
||||
def run_training_continuation(self, xgb_params_01, xgb_params_02, xgb_params_03):
|
||||
from sklearn.datasets import load_digits
|
||||
from sklearn.metrics import mean_squared_error
|
||||
|
||||
digits_2class = load_digits(n_class=2)
|
||||
digits_5class = load_digits(n_class=5)
|
||||
|
||||
X_2class = digits_2class['data']
|
||||
y_2class = digits_2class['target']
|
||||
X_2class = digits_2class["data"]
|
||||
y_2class = digits_2class["target"]
|
||||
|
||||
X_5class = digits_5class['data']
|
||||
y_5class = digits_5class['target']
|
||||
X_5class = digits_5class["data"]
|
||||
y_5class = digits_5class["target"]
|
||||
|
||||
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
|
||||
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
|
||||
|
||||
gbdt_01 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10)
|
||||
gbdt_01 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=10)
|
||||
ntrees_01 = len(gbdt_01.get_dump())
|
||||
assert ntrees_01 == 10
|
||||
|
||||
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=0)
|
||||
gbdt_02.save_model('xgb_tc.json')
|
||||
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=0)
|
||||
gbdt_02.save_model("xgb_tc.json")
|
||||
|
||||
gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model=gbdt_02)
|
||||
gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=10, xgb_model="xgb_tc.json")
|
||||
gbdt_02a = xgb.train(
|
||||
xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02
|
||||
)
|
||||
gbdt_02b = xgb.train(
|
||||
xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.json"
|
||||
)
|
||||
ntrees_02a = len(gbdt_02a.get_dump())
|
||||
ntrees_02b = len(gbdt_02b.get_dump())
|
||||
assert ntrees_02a == 10
|
||||
@@ -76,20 +74,21 @@ class TestTrainingContinuation:
|
||||
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=3)
|
||||
gbdt_03.save_model('xgb_tc.json')
|
||||
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=3)
|
||||
gbdt_03.save_model("xgb_tc.json")
|
||||
|
||||
gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model=gbdt_03)
|
||||
gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
|
||||
num_boost_round=7, xgb_model="xgb_tc.json")
|
||||
gbdt_03a = xgb.train(
|
||||
xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03
|
||||
)
|
||||
gbdt_03b = xgb.train(
|
||||
xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.json"
|
||||
)
|
||||
ntrees_03a = len(gbdt_03a.get_dump())
|
||||
ntrees_03b = len(gbdt_03b.get_dump())
|
||||
assert ntrees_03a == 10
|
||||
assert ntrees_03b == 10
|
||||
|
||||
os.remove('xgb_tc.json')
|
||||
os.remove("xgb_tc.json")
|
||||
|
||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||
@@ -113,16 +112,14 @@ class TestTrainingContinuation:
|
||||
y_2class,
|
||||
gbdt_04.predict(
|
||||
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||
)
|
||||
),
|
||||
)
|
||||
assert res1 == res2
|
||||
|
||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
||||
num_boost_round=7)
|
||||
gbdt_05 = xgb.train(xgb_params_03,
|
||||
dtrain_5class,
|
||||
num_boost_round=3,
|
||||
xgb_model=gbdt_05)
|
||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7)
|
||||
gbdt_05 = xgb.train(
|
||||
xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05
|
||||
)
|
||||
|
||||
res1 = gbdt_05.predict(dtrain_5class)
|
||||
res2 = gbdt_05.predict(
|
||||
@@ -163,3 +160,7 @@ class TestTrainingContinuation:
|
||||
clf.set_params(eval_metric="error")
|
||||
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
|
||||
assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||
def test_model_output(self, tree_method: str) -> None:
|
||||
run_training_continuation_model_output("cpu", tree_method)
|
||||
|
||||
Reference in New Issue
Block a user