Raise if expected workers are not alive in xgboost.dask.train (#9421)

Hendrik Makait, 2023-08-03 14:14:07 +02:00 (committed by GitHub)
parent 7129988847
commit f958e32683
2 changed files with 48 additions and 3 deletions

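The patch makes xgboost.dask.train fail fast when a worker that holds partitions of the input data is no longer known to the scheduler, rather than relying on client.wait_for_workers before starting the tracker. Below is a minimal sketch of how the new behaviour surfaces to a caller, assuming a distributed.Client and a DaskDMatrix whose partitions landed on a worker that has since left; the wrapper name, parameters, and objective are illustrative, not part of the patch:

import xgboost as xgb

def train_or_report(client, dtrain):
    # Hypothetical wrapper around xgboost.dask.train, for illustration only.
    try:
        return xgb.dask.train(
            client,
            {"objective": "reg:squarederror"},
            dtrain,
            num_boost_round=10,
            evals=[(dtrain, "train")],
        )
    except RuntimeError as exc:
        # With this patch, a missing data-holding worker raises
        # "Missing required workers: {...}" before tracker setup; previously
        # the code only waited for n_workers to be connected.
        print(f"training aborted: {exc}")
        raise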

@@ -850,8 +850,6 @@ async def _get_rabit_args(
     except Exception:  # pylint: disable=broad-except
         sched_addr = None
 
-    # make sure all workers are online so that we can obtain reliable scheduler_info
-    await client.wait_for_workers(n_workers)  # type: ignore
     env = await client.run_on_scheduler(
         _start_tracker, n_workers, sched_addr, user_addr
     )
@@ -907,6 +905,16 @@ def _filter_empty(
     raise ValueError("None of the workers can provide a valid result.")
 
 
+async def _check_workers_are_alive(
+    workers: List[str], client: "distributed.Client"
+) -> None:
+    info = await client.scheduler.identity()
+    current_workers = info["workers"].keys()
+    missing_workers = set(workers) - current_workers
+    if missing_workers:
+        raise RuntimeError(f"Missing required workers: {missing_workers}")
+
+
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
@@ -924,6 +932,7 @@ async def _train_async(
     custom_metric: Optional[Metric],
 ) -> Optional[TrainReturnT]:
     workers = _get_workers_from_data(dtrain, evals)
+    await _check_workers_are_alive(workers, client)
     _rabit_args = await _get_rabit_args(len(workers), dconfig, client)
     _check_distributed_params(params)

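The second file adds regression tests to the dask test suite. For reference, the same liveness check can be sketched against the public distributed API, using the synchronous Client.scheduler_info() rather than the internal client.scheduler.identity() call the patch uses inside the library; the helper name below is made up for illustration:

from typing import Iterable

import distributed


def assert_required_workers_alive(client: distributed.Client, required: Iterable[str]) -> None:
    # Addresses of the workers the scheduler currently knows about.
    current = set(client.scheduler_info()["workers"])
    missing = set(required) - current
    if missing:
        # Same error message as the patched _check_workers_are_alive helper.
        raise RuntimeError(f"Missing required workers: {missing}")

The tests below trigger exactly this condition in two ways: test_worker_left builds the DaskDMatrix while a temporary Worker is attached and then lets it exit, and test_worker_restarted restarts a Nanny-managed worker so that the address recorded for the training data is no longer alive.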

@@ -36,7 +36,8 @@ pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
 import dask
 import dask.array as da
 import dask.dataframe as dd
-from distributed import Client, LocalCluster
+from distributed import Client, LocalCluster, Nanny, Worker
+from distributed.utils_test import async_poll_for, gen_cluster
 from toolz import sliding_window  # dependency of dask
 
 from xgboost.dask import DaskDMatrix
@@ -2226,3 +2227,38 @@ class TestDaskCallbacks:
         )
         for i in range(1, 10):
             assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
+
+
+@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
+async def test_worker_left(c, s, a, b):
+    async with Worker(s.address):
+        dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
+        dy = da.random.random((1000,)).rechunk(chunks=(10,))
+        d_train = await xgb.dask.DaskDMatrix(
+            c, dx, dy,
+        )
+    await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
+    with pytest.raises(RuntimeError, match="Missing"):
+        await xgb.dask.train(
+            c,
+            {},
+            d_train,
+            evals=[(d_train, "train")],
+        )
+
+
+@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
+async def test_worker_restarted(c, s, a, b):
+    dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
+    dy = da.random.random((1000,)).rechunk(chunks=(10,))
+    d_train = await xgb.dask.DaskDMatrix(
+        c, dx, dy,
+    )
+    await c.restart_workers([a.worker_address])
+    with pytest.raises(RuntimeError, match="Missing"):
+        await xgb.dask.train(
+            c,
+            {},
+            d_train,
+            evals=[(d_train, "train")],
+        )