Use UBJ in Python checkpoint. (#9958)

This commit is contained in:
Jiaming Yuan
2024-01-09 03:22:15 +08:00
committed by GitHub
parent fa5e2f6c45
commit b3eb5d0945
7 changed files with 104 additions and 46 deletions

View File

@@ -31,6 +31,8 @@ class LintersPaths:
"tests/python/test_with_pandas.py",
"tests/python-gpu/",
"tests/python-sycl/",
"tests/test_distributed/test_with_dask/",
"tests/test_distributed/test_gpu_with_dask/",
"tests/test_distributed/test_with_spark/",
"tests/test_distributed/test_gpu_with_spark/",
# demo
@@ -91,6 +93,7 @@ class LintersPaths:
# demo
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",

View File

@@ -244,7 +244,7 @@ class TestCallbacks:
assert booster.num_boosted_rounds() == booster.best_iteration + 1
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'model.json')
path = os.path.join(tmpdir, "model.json")
cls.save_model(path)
cls = xgb.XGBClassifier()
cls.load_model(path)
@@ -378,7 +378,7 @@ class TestCallbacks:
scheduler = xgb.callback.LearningRateScheduler
dtrain, dtest = tm.load_agaricus(__file__)
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
watchlist = [(dtest, "eval"), (dtrain, "train")]
param = {
"max_depth": 2,
@@ -429,7 +429,7 @@ class TestCallbacks:
assert tree_3th_0["split_conditions"] != tree_3th_1["split_conditions"]
@pytest.mark.parametrize("tree_method", ["hist", "approx", "approx"])
def test_eta_decay(self, tree_method):
def test_eta_decay(self, tree_method: str) -> None:
self.run_eta_decay(tree_method)
@pytest.mark.parametrize(
@@ -446,7 +446,7 @@ class TestCallbacks:
def test_eta_decay_leaf_output(self, tree_method: str, objective: str) -> None:
self.run_eta_decay_leaf_output(tree_method, objective)
def test_check_point(self):
def test_check_point(self) -> None:
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
@@ -463,7 +463,12 @@ class TestCallbacks:
callbacks=[check_point],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, interval=1, as_pickle=True, name="model"
@@ -478,7 +483,7 @@ class TestCallbacks:
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
def test_callback_list(self):
def test_callback_list(self) -> None:
X, y = tm.data.get_california_housing()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]

View File

@@ -1590,7 +1590,7 @@ class TestWithDask:
@given(
params=hist_parameter_strategy,
cache_param=hist_cache_strategy,
dataset=tm.make_dataset_strategy()
dataset=tm.make_dataset_strategy(),
)
@settings(
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
@@ -2250,16 +2250,27 @@ class TestDaskCallbacks:
],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_left(c, s, a, b):
async with Worker(s.address):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
with pytest.raises(RuntimeError, match="Missing"):
@@ -2271,12 +2282,19 @@ async def test_worker_left(c, s, a, b):
)
@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
Worker=Nanny,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_restarted(c, s, a, b):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await c.restart_workers([a.worker_address])
with pytest.raises(RuntimeError, match="Missing"):