Use UBJ in Python checkpoint. (#9958)

This commit is contained in:
Jiaming Yuan
2024-01-09 03:22:15 +08:00
committed by GitHub
parent fa5e2f6c45
commit b3eb5d0945
7 changed files with 104 additions and 46 deletions

View File

@@ -1590,7 +1590,7 @@ class TestWithDask:
@given(
params=hist_parameter_strategy,
cache_param=hist_cache_strategy,
dataset=tm.make_dataset_strategy()
dataset=tm.make_dataset_strategy(),
)
@settings(
deadline=None, max_examples=10, suppress_health_check=suppress, print_blob=True
@@ -2250,16 +2250,27 @@ class TestDaskCallbacks:
],
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
assert os.path.exists(
os.path.join(
tmpdir,
f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
)
)
@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_left(c, s, a, b):
async with Worker(s.address):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
with pytest.raises(RuntimeError, match="Missing"):
@@ -2271,12 +2282,19 @@ async def test_worker_left(c, s, a, b):
)
@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
@gen_cluster(
client=True,
Worker=Nanny,
clean_kwargs={"processes": False, "threads": False},
allow_unclosed=True,
)
async def test_worker_restarted(c, s, a, b):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
c,
dx,
dy,
)
await c.restart_workers([a.worker_address])
with pytest.raises(RuntimeError, match="Missing"):