From f958e326832c9acc20f4c9548132f036cf1785af Mon Sep 17 00:00:00 2001
From: Hendrik Makait
Date: Thu, 3 Aug 2023 14:14:07 +0200
Subject: [PATCH] Raise if expected workers are not alive in `xgboost.dask.train` (#9421)

---
 python-package/xgboost/dask.py                       | 13 ++++++-
 .../test_with_dask/test_with_dask.py                 | 38 ++++++++++++++++++-
 2 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 271a5e458..219ad2698 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -850,8 +850,6 @@ async def _get_rabit_args(
     except Exception:  # pylint: disable=broad-except
         sched_addr = None
 
-    # make sure all workers are online so that we can obtain reliable scheduler_info
-    await client.wait_for_workers(n_workers)  # type: ignore
     env = await client.run_on_scheduler(
         _start_tracker, n_workers, sched_addr, user_addr
     )
@@ -907,6 +905,16 @@ def _filter_empty(
     raise ValueError("None of the workers can provide a valid result.")
 
 
+async def _check_workers_are_alive(
+    workers: List[str], client: "distributed.Client"
+) -> None:
+    info = await client.scheduler.identity()
+    current_workers = info["workers"].keys()
+    missing_workers = set(workers) - current_workers
+    if missing_workers:
+        raise RuntimeError(f"Missing required workers: {missing_workers}")
+
+
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
@@ -924,6 +932,7 @@ async def _train_async(
     custom_metric: Optional[Metric],
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
+    await _check_workers_are_alive(workers, client)
     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
     _check_distributed_params(params)
 
diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py
index 23bfc2d23..3add01192 100644
--- a/tests/test_distributed/test_with_dask/test_with_dask.py
+++ b/tests/test_distributed/test_with_dask/test_with_dask.py
@@ -36,7 +36,8 @@ pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
 import dask
 import dask.array as da
 import dask.dataframe as dd
-from distributed import Client, LocalCluster
+from distributed import Client, LocalCluster, Nanny, Worker
+from distributed.utils_test import async_poll_for, gen_cluster
 from toolz import sliding_window  # dependency of dask
 
 from xgboost.dask import DaskDMatrix
@@ -2226,3 +2227,38 @@ class TestDaskCallbacks:
         )
         for i in range(1, 10):
             assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
+
+
+@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
+async def test_worker_left(c, s, a, b):
+    async with Worker(s.address):
+        dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
+        dy = da.random.random((1000,)).rechunk(chunks=(10,))
+        d_train = await xgb.dask.DaskDMatrix(
+            c, dx, dy,
+        )
+    await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
+    with pytest.raises(RuntimeError, match="Missing"):
+        await xgb.dask.train(
+            c,
+            {},
+            d_train,
+            evals=[(d_train, "train")],
+        )
+
+
+@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
+async def test_worker_restarted(c, s, a, b):
+    dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
+    dy = da.random.random((1000,)).rechunk(chunks=(10,))
+    d_train = await xgb.dask.DaskDMatrix(
+        c, dx, dy,
+    )
+    await c.restart_workers([a.worker_address])
+    with pytest.raises(RuntimeError, match="Missing"):
+        await xgb.dask.train(
+            c,
+            {},
+            d_train,
+            evals=[(d_train, "train")],
+        )
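Note (illustration only, not part of the patch): a minimal sketch of the failure mode this change guards against, using the synchronous API. The local two-worker cluster, the retired-worker scenario, and the training parameters below are assumptions for demonstration; only the "Missing required workers" RuntimeError comes from the patch itself.

    # Sketch: train against data whose recorded worker has since left the cluster.
    # With this patch, xgboost.dask.train fails fast with RuntimeError instead of
    # stalling while waiting for the departed worker during tracker setup.
    import dask.array as da
    import xgboost as xgb
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random((1000,), chunks=(100,))
            # DaskDMatrix records which workers hold the data partitions.
            dtrain = xgb.dask.DaskDMatrix(client, X, y)

            # Simulate one of those workers leaving after the data was placed.
            gone = list(client.scheduler_info()["workers"])[0]
            client.retire_workers([gone], close_workers=True)

            try:
                xgb.dask.train(
                    client, {"objective": "reg:squarederror"}, dtrain, num_boost_round=2
                )
            except RuntimeError as err:
                print(err)  # e.g. "Missing required workers: {...}"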