Use UBJ in Python checkpoint. (#9958)
This commit is contained in:
@@ -244,7 +244,7 @@ class TestCallbacks:
|
||||
assert booster.num_boosted_rounds() == booster.best_iteration + 1
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'model.json')
|
||||
path = os.path.join(tmpdir, "model.json")
|
||||
cls.save_model(path)
|
||||
cls = xgb.XGBClassifier()
|
||||
cls.load_model(path)
|
||||
@@ -378,7 +378,7 @@ class TestCallbacks:
|
||||
scheduler = xgb.callback.LearningRateScheduler
|
||||
|
||||
dtrain, dtest = tm.load_agaricus(__file__)
|
||||
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
|
||||
watchlist = [(dtest, "eval"), (dtrain, "train")]
|
||||
|
||||
param = {
|
||||
"max_depth": 2,
|
||||
@@ -429,7 +429,7 @@ class TestCallbacks:
|
||||
assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]
|
||||
|
||||
@pytest.mark.parametrize("tree_method", ["hist", "approx", "approx"])
|
||||
def test_eta_decay(self, tree_method):
|
||||
def test_eta_decay(self, tree_method: str) -> None:
|
||||
self.run_eta_decay(tree_method)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -446,7 +446,7 @@ class TestCallbacks:
|
||||
def test_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
|
||||
self.run_eta_decay_leaf_output(tree_method, objective)
|
||||
|
||||
def test_check_point(self):
|
||||
def test_check_point(self) -> None:
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
@@ -463,7 +463,12 @@ class TestCallbacks:
|
||||
callbacks=[check_point],
|
||||
)
|
||||
for i in range(1, 10):
|
||||
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
||||
assert os.path.exists(
|
||||
os.path.join(
|
||||
tmpdir,
|
||||
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
|
||||
)
|
||||
)
|
||||
|
||||
check_point = xgb.callback.TrainingCheckPoint(
|
||||
directory=tmpdir, interval=1, as_pickle=True, name="model"
|
||||
@@ -478,7 +483,7 @@ class TestCallbacks:
|
||||
for i in range(1, 10):
|
||||
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
|
||||
|
||||
def test_callback_list(self):
|
||||
def test_callback_list(self) -> None:
|
||||
X, y = tm.data.get_california_housing()
|
||||
m = xgb.DMatrix(X, y)
|
||||
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
|
||||
|
||||
Reference in New Issue
Block a user