Refactor tests for training continuation. (#9997)
This commit is contained in:
parent
5062a3ab46
commit
d12cc1090a
58
python-package/xgboost/testing/continuation.py
Normal file
58
python-package/xgboost/testing/continuation.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for training continuation."""
|
||||||
|
import json
|
||||||
|
from typing import Any, Dict, TypeVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=too-many-locals
|
||||||
|
def run_training_continuation_model_output(device: str, tree_method: str) -> None:
|
||||||
|
"""Run training continuation test."""
|
||||||
|
datasets = pytest.importorskip("sklearn.datasets")
|
||||||
|
n_samples = 64
|
||||||
|
n_features = 32
|
||||||
|
X, y = datasets.make_regression(n_samples, n_features, random_state=1)
|
||||||
|
|
||||||
|
dtrain = xgb.DMatrix(X, y)
|
||||||
|
params = {
|
||||||
|
"tree_method": tree_method,
|
||||||
|
"max_depth": "2",
|
||||||
|
"gamma": "0.1",
|
||||||
|
"alpha": "0.01",
|
||||||
|
"device": device,
|
||||||
|
}
|
||||||
|
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
|
||||||
|
dump_0 = bst_0.get_dump(dump_format="json")
|
||||||
|
|
||||||
|
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
|
||||||
|
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
|
||||||
|
dump_1 = bst_1.get_dump(dump_format="json")
|
||||||
|
|
||||||
|
T = TypeVar("T", Dict[str, Any], float, str, int, list)
|
||||||
|
|
||||||
|
def recursive_compare(obj_0: T, obj_1: T) -> None:
|
||||||
|
if isinstance(obj_0, float):
|
||||||
|
assert np.isclose(obj_0, obj_1, atol=1e-6)
|
||||||
|
elif isinstance(obj_0, str):
|
||||||
|
assert obj_0 == obj_1
|
||||||
|
elif isinstance(obj_0, int):
|
||||||
|
assert obj_0 == obj_1
|
||||||
|
elif isinstance(obj_0, dict):
|
||||||
|
for i in range(len(obj_0.items())):
|
||||||
|
assert list(obj_0.keys())[i] == list(obj_1.keys())[i]
|
||||||
|
if list(obj_0.keys())[i] != "missing":
|
||||||
|
recursive_compare(list(obj_0.values()), list(obj_1.values()))
|
||||||
|
else:
|
||||||
|
for i, lhs in enumerate(obj_0):
|
||||||
|
rhs = obj_1[i]
|
||||||
|
recursive_compare(lhs, rhs)
|
||||||
|
|
||||||
|
assert len(dump_0) == len(dump_1)
|
||||||
|
|
||||||
|
for i, lhs in enumerate(dump_0):
|
||||||
|
obj_0 = json.loads(lhs)
|
||||||
|
obj_1 = json.loads(dump_1[i])
|
||||||
|
recursive_compare(obj_0, obj_1)
|
||||||
@ -28,6 +28,7 @@ class LintersPaths:
|
|||||||
"tests/python/test_predict.py",
|
"tests/python/test_predict.py",
|
||||||
"tests/python/test_quantile_dmatrix.py",
|
"tests/python/test_quantile_dmatrix.py",
|
||||||
"tests/python/test_tree_regularization.py",
|
"tests/python/test_tree_regularization.py",
|
||||||
|
"tests/python/test_training_continuation.py",
|
||||||
"tests/python/test_shap.py",
|
"tests/python/test_shap.py",
|
||||||
"tests/python/test_model_io.py",
|
"tests/python/test_model_io.py",
|
||||||
"tests/python/test_with_pandas.py",
|
"tests/python/test_with_pandas.py",
|
||||||
@ -91,6 +92,7 @@ class LintersPaths:
|
|||||||
"tests/python/test_multi_target.py",
|
"tests/python/test_multi_target.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||||
"tests/python-gpu/load_pickle.py",
|
"tests/python-gpu/load_pickle.py",
|
||||||
|
"tests/python-gpu/test_gpu_training_continuation.py",
|
||||||
"tests/python/test_model_io.py",
|
"tests/python/test_model_io.py",
|
||||||
"tests/test_distributed/test_with_spark/test_data.py",
|
"tests/test_distributed/test_with_spark/test_data.py",
|
||||||
"tests/test_distributed/test_gpu_with_spark/test_data.py",
|
"tests/test_distributed/test_gpu_with_spark/test_data.py",
|
||||||
|
|||||||
@ -1,54 +1,12 @@
|
|||||||
import json
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import xgboost as xgb
|
from xgboost.testing.continuation import run_training_continuation_model_output
|
||||||
|
|
||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
|
||||||
class TestGPUTrainingContinuation:
|
class TestGPUTrainingContinuation:
|
||||||
def test_training_continuation(self):
|
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
|
||||||
kRows = 64
|
def test_model_output(self, tree_method: str) -> None:
|
||||||
kCols = 32
|
run_training_continuation_model_output("cuda", tree_method)
|
||||||
X = np.random.randn(kRows, kCols)
|
|
||||||
y = np.random.randn(kRows)
|
|
||||||
dtrain = xgb.DMatrix(X, y)
|
|
||||||
params = {
|
|
||||||
"tree_method": "gpu_hist",
|
|
||||||
"max_depth": "2",
|
|
||||||
"gamma": "0.1",
|
|
||||||
"alpha": "0.01",
|
|
||||||
}
|
|
||||||
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
|
|
||||||
dump_0 = bst_0.get_dump(dump_format="json")
|
|
||||||
|
|
||||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
|
|
||||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
|
|
||||||
dump_1 = bst_1.get_dump(dump_format="json")
|
|
||||||
|
|
||||||
def recursive_compare(obj_0, obj_1):
|
|
||||||
if isinstance(obj_0, float):
|
|
||||||
assert np.isclose(obj_0, obj_1, atol=1e-6)
|
|
||||||
elif isinstance(obj_0, str):
|
|
||||||
assert obj_0 == obj_1
|
|
||||||
elif isinstance(obj_0, int):
|
|
||||||
assert obj_0 == obj_1
|
|
||||||
elif isinstance(obj_0, dict):
|
|
||||||
keys_0 = list(obj_0.keys())
|
|
||||||
keys_1 = list(obj_1.keys())
|
|
||||||
values_0 = list(obj_0.values())
|
|
||||||
values_1 = list(obj_1.values())
|
|
||||||
for i in range(len(obj_0.items())):
|
|
||||||
assert keys_0[i] == keys_1[i]
|
|
||||||
if list(obj_0.keys())[i] != "missing":
|
|
||||||
recursive_compare(values_0[i], values_1[i])
|
|
||||||
else:
|
|
||||||
for i in range(len(obj_0)):
|
|
||||||
recursive_compare(obj_0[i], obj_1[i])
|
|
||||||
|
|
||||||
assert len(dump_0) == len(dump_1)
|
|
||||||
for i in range(len(dump_0)):
|
|
||||||
obj_0 = json.loads(dump_0[i])
|
|
||||||
obj_1 = json.loads(dump_1[i])
|
|
||||||
recursive_compare(obj_0, obj_1)
|
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import pytest
|
|||||||
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
from xgboost import testing as tm
|
from xgboost import testing as tm
|
||||||
|
from xgboost.testing.continuation import run_training_continuation_model_output
|
||||||
|
|
||||||
rng = np.random.RandomState(1337)
|
rng = np.random.RandomState(1337)
|
||||||
|
|
||||||
@ -15,54 +16,51 @@ class TestTrainingContinuation:
|
|||||||
|
|
||||||
def generate_parameters(self):
|
def generate_parameters(self):
|
||||||
xgb_params_01_binary = {
|
xgb_params_01_binary = {
|
||||||
'nthread': 1,
|
"nthread": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
xgb_params_02_binary = {
|
xgb_params_02_binary = {
|
||||||
'nthread': 1,
|
"nthread": 1,
|
||||||
'num_parallel_tree': self.num_parallel_tree
|
"num_parallel_tree": self.num_parallel_tree,
|
||||||
}
|
}
|
||||||
|
|
||||||
xgb_params_03_binary = {
|
xgb_params_03_binary = {
|
||||||
'nthread': 1,
|
"nthread": 1,
|
||||||
'num_class': 5,
|
"num_class": 5,
|
||||||
'num_parallel_tree': self.num_parallel_tree
|
"num_parallel_tree": self.num_parallel_tree,
|
||||||
}
|
}
|
||||||
|
|
||||||
return [
|
return [xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary]
|
||||||
xgb_params_01_binary, xgb_params_02_binary, xgb_params_03_binary
|
|
||||||
]
|
|
||||||
|
|
||||||
def run_training_continuation(self, xgb_params_01, xgb_params_02,
|
def run_training_continuation(self, xgb_params_01, xgb_params_02, xgb_params_03):
|
||||||
xgb_params_03):
|
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
from sklearn.metrics import mean_squared_error
|
from sklearn.metrics import mean_squared_error
|
||||||
|
|
||||||
digits_2class = load_digits(n_class=2)
|
digits_2class = load_digits(n_class=2)
|
||||||
digits_5class = load_digits(n_class=5)
|
digits_5class = load_digits(n_class=5)
|
||||||
|
|
||||||
X_2class = digits_2class['data']
|
X_2class = digits_2class["data"]
|
||||||
y_2class = digits_2class['target']
|
y_2class = digits_2class["target"]
|
||||||
|
|
||||||
X_5class = digits_5class['data']
|
X_5class = digits_5class["data"]
|
||||||
y_5class = digits_5class['target']
|
y_5class = digits_5class["target"]
|
||||||
|
|
||||||
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
|
dtrain_2class = xgb.DMatrix(X_2class, label=y_2class)
|
||||||
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
|
dtrain_5class = xgb.DMatrix(X_5class, label=y_5class)
|
||||||
|
|
||||||
gbdt_01 = xgb.train(xgb_params_01, dtrain_2class,
|
gbdt_01 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=10)
|
||||||
num_boost_round=10)
|
|
||||||
ntrees_01 = len(gbdt_01.get_dump())
|
ntrees_01 = len(gbdt_01.get_dump())
|
||||||
assert ntrees_01 == 10
|
assert ntrees_01 == 10
|
||||||
|
|
||||||
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class,
|
gbdt_02 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=0)
|
||||||
num_boost_round=0)
|
gbdt_02.save_model("xgb_tc.json")
|
||||||
gbdt_02.save_model('xgb_tc.json')
|
|
||||||
|
|
||||||
gbdt_02a = xgb.train(xgb_params_01, dtrain_2class,
|
gbdt_02a = xgb.train(
|
||||||
num_boost_round=10, xgb_model=gbdt_02)
|
xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model=gbdt_02
|
||||||
gbdt_02b = xgb.train(xgb_params_01, dtrain_2class,
|
)
|
||||||
num_boost_round=10, xgb_model="xgb_tc.json")
|
gbdt_02b = xgb.train(
|
||||||
|
xgb_params_01, dtrain_2class, num_boost_round=10, xgb_model="xgb_tc.json"
|
||||||
|
)
|
||||||
ntrees_02a = len(gbdt_02a.get_dump())
|
ntrees_02a = len(gbdt_02a.get_dump())
|
||||||
ntrees_02b = len(gbdt_02b.get_dump())
|
ntrees_02b = len(gbdt_02b.get_dump())
|
||||||
assert ntrees_02a == 10
|
assert ntrees_02a == 10
|
||||||
@ -76,20 +74,21 @@ class TestTrainingContinuation:
|
|||||||
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
res2 = mean_squared_error(y_2class, gbdt_02b.predict(dtrain_2class))
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
|
|
||||||
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class,
|
gbdt_03 = xgb.train(xgb_params_01, dtrain_2class, num_boost_round=3)
|
||||||
num_boost_round=3)
|
gbdt_03.save_model("xgb_tc.json")
|
||||||
gbdt_03.save_model('xgb_tc.json')
|
|
||||||
|
|
||||||
gbdt_03a = xgb.train(xgb_params_01, dtrain_2class,
|
gbdt_03a = xgb.train(
|
||||||
num_boost_round=7, xgb_model=gbdt_03)
|
xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model=gbdt_03
|
||||||
gbdt_03b = xgb.train(xgb_params_01, dtrain_2class,
|
)
|
||||||
num_boost_round=7, xgb_model="xgb_tc.json")
|
gbdt_03b = xgb.train(
|
||||||
|
xgb_params_01, dtrain_2class, num_boost_round=7, xgb_model="xgb_tc.json"
|
||||||
|
)
|
||||||
ntrees_03a = len(gbdt_03a.get_dump())
|
ntrees_03a = len(gbdt_03a.get_dump())
|
||||||
ntrees_03b = len(gbdt_03b.get_dump())
|
ntrees_03b = len(gbdt_03b.get_dump())
|
||||||
assert ntrees_03a == 10
|
assert ntrees_03a == 10
|
||||||
assert ntrees_03b == 10
|
assert ntrees_03b == 10
|
||||||
|
|
||||||
os.remove('xgb_tc.json')
|
os.remove("xgb_tc.json")
|
||||||
|
|
||||||
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
res1 = mean_squared_error(y_2class, gbdt_03a.predict(dtrain_2class))
|
||||||
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
res2 = mean_squared_error(y_2class, gbdt_03b.predict(dtrain_2class))
|
||||||
@ -113,16 +112,14 @@ class TestTrainingContinuation:
|
|||||||
y_2class,
|
y_2class,
|
||||||
gbdt_04.predict(
|
gbdt_04.predict(
|
||||||
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
dtrain_2class, iteration_range=(0, gbdt_04.num_boosted_rounds())
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
|
|
||||||
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class,
|
gbdt_05 = xgb.train(xgb_params_03, dtrain_5class, num_boost_round=7)
|
||||||
num_boost_round=7)
|
gbdt_05 = xgb.train(
|
||||||
gbdt_05 = xgb.train(xgb_params_03,
|
xgb_params_03, dtrain_5class, num_boost_round=3, xgb_model=gbdt_05
|
||||||
dtrain_5class,
|
)
|
||||||
num_boost_round=3,
|
|
||||||
xgb_model=gbdt_05)
|
|
||||||
|
|
||||||
res1 = gbdt_05.predict(dtrain_5class)
|
res1 = gbdt_05.predict(dtrain_5class)
|
||||||
res2 = gbdt_05.predict(
|
res2 = gbdt_05.predict(
|
||||||
@ -163,3 +160,7 @@ class TestTrainingContinuation:
|
|||||||
clf.set_params(eval_metric="error")
|
clf.set_params(eval_metric="error")
|
||||||
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
|
clf.fit(X, y, eval_set=[(X, y)], xgb_model=loaded)
|
||||||
assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])
|
assert tm.non_increasing(clf.evals_result()["validation_0"]["error"])
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("tree_method", ["hist", "approx", "exact"])
|
||||||
|
def test_model_output(self, tree_method: str) -> None:
|
||||||
|
run_training_continuation_model_output("cpu", tree_method)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user