Refactor tests for training continuation. (#9997)
This commit is contained in:
58
python-package/xgboost/testing/continuation.py
Normal file
58
python-package/xgboost/testing/continuation.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for training continuation."""
|
||||
import json
|
||||
from typing import Any, Dict, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
# pylint: disable=too-many-locals
|
||||
def run_training_continuation_model_output(device: str, tree_method: str) -> None:
|
||||
"""Run training continuation test."""
|
||||
datasets = pytest.importorskip("sklearn.datasets")
|
||||
n_samples = 64
|
||||
n_features = 32
|
||||
X, y = datasets.make_regression(n_samples, n_features, random_state=1)
|
||||
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
params = {
|
||||
"tree_method": tree_method,
|
||||
"max_depth": "2",
|
||||
"gamma": "0.1",
|
||||
"alpha": "0.01",
|
||||
"device": device,
|
||||
}
|
||||
bst_0 = xgb.train(params, dtrain, num_boost_round=64)
|
||||
dump_0 = bst_0.get_dump(dump_format="json")
|
||||
|
||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32)
|
||||
bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
|
||||
dump_1 = bst_1.get_dump(dump_format="json")
|
||||
|
||||
T = TypeVar("T", Dict[str, Any], float, str, int, list)
|
||||
|
||||
def recursive_compare(obj_0: T, obj_1: T) -> None:
|
||||
if isinstance(obj_0, float):
|
||||
assert np.isclose(obj_0, obj_1, atol=1e-6)
|
||||
elif isinstance(obj_0, str):
|
||||
assert obj_0 == obj_1
|
||||
elif isinstance(obj_0, int):
|
||||
assert obj_0 == obj_1
|
||||
elif isinstance(obj_0, dict):
|
||||
for i in range(len(obj_0.items())):
|
||||
assert list(obj_0.keys())[i] == list(obj_1.keys())[i]
|
||||
if list(obj_0.keys())[i] != "missing":
|
||||
recursive_compare(list(obj_0.values()), list(obj_1.values()))
|
||||
else:
|
||||
for i, lhs in enumerate(obj_0):
|
||||
rhs = obj_1[i]
|
||||
recursive_compare(lhs, rhs)
|
||||
|
||||
assert len(dump_0) == len(dump_1)
|
||||
|
||||
for i, lhs in enumerate(dump_0):
|
||||
obj_0 = json.loads(lhs)
|
||||
obj_1 = json.loads(dump_1[i])
|
||||
recursive_compare(obj_0, obj_1)
|
||||
Reference in New Issue
Block a user