Use UBJ in Python checkpoint. (#9958)

This commit is contained in:
Jiaming Yuan
2024-01-09 03:22:15 +08:00
committed by GitHub
parent fa5e2f6c45
commit b3eb5d0945
7 changed files with 104 additions and 46 deletions

View File

@@ -244,7 +244,7 @@ class TestCallbacks:
assert booster.num_boosted_rounds() == booster.best_iteration + 1
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
path = os.path.join(tmpdir, "model.json")
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
@@ -378,7 +378,7 @@ class TestCallbacks:
scheduler = xgb.callback.LearningRateScheduler
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
param = {
"max_depth": 2,
@@ -429,7 +429,7 @@ class TestCallbacks:
assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]
@pytest.mark.parametrize("tree_method", ["hist", "approx", "approx"])
def test_eta_decay(self, tree_method):
def test_eta_decay(self, tree_method: str) -> None:
self.run_eta_decay(tree_method)
@pytest.mark.parametrize(
@@ -446,7 +446,7 @@ class TestCallbacks:
def test_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
self.run_eta_decay_leaf_output(tree_method, objective)
def test_check_point(self):
def test_check_point(self) -> None:
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
@@ -463,7 +463,12 @@ class TestCallbacks:
callbacks=[check_point],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, interval=1, as_pickle=True, name="model"
@@ -478,7 +483,7 @@ class TestCallbacks:
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
def test_callback_list(self):
def test_callback_list(self) -> None:
X, y = tm.data.get_california_housing()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]