Raise if expected workers are not alive in xgboost.dask.train (#9421)

This commit is contained in:
Hendrik Makait
2023-08-03 14:14:07 +02:00
committed by GitHub
parent 7129988847
commit f958e32683
2 changed files with 48 additions and 3 deletions

View File

@@ -36,7 +36,8 @@ pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
import dask
import dask.array as da
import dask.dataframe as dd
from distributed import Client, LocalCluster
from distributed import Client, LocalCluster, Nanny, Worker
from distributed.utils_test import async_poll_for, gen_cluster
from toolz import sliding_window # dependency of dask
from xgboost.dask import DaskDMatrix
@@ -2226,3 +2227,38 @@ class TestDaskCallbacks:
)
for i in range(1, 10):
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
async def test_worker_left(c, s, a, b):
async with Worker(s.address):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
)
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
with pytest.raises(RuntimeError, match="Missing"):
await xgb.dask.train(
c,
{},
d_train,
evals=[(d_train, "train")],
)
@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
async def test_worker_restarted(c, s, a, b):
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
dy = da.random.random((1000,)).rechunk(chunks=(10,))
d_train = await xgb.dask.DaskDMatrix(
c, dx, dy,
)
await c.restart_workers([a.worker_address])
with pytest.raises(RuntimeError, match="Missing"):
await xgb.dask.train(
c,
{},
d_train,
evals=[(d_train, "train")],
)