Raise if expected workers are not alive in xgboost.dask.train (#9421)
This commit is contained in:
@@ -850,8 +850,6 @@ async def _get_rabit_args(
|
||||
except Exception: # pylint: disable=broad-except
|
||||
sched_addr = None
|
||||
|
||||
# make sure all workers are online so that we can obtain reliable scheduler_info
|
||||
await client.wait_for_workers(n_workers) # type: ignore
|
||||
env = await client.run_on_scheduler(
|
||||
_start_tracker, n_workers, sched_addr, user_addr
|
||||
)
|
||||
@@ -907,6 +905,16 @@ def _filter_empty(
|
||||
raise ValueError("None of the workers can provide a valid result.")
|
||||
|
||||
|
||||
async def _check_workers_are_alive(
    workers: List[str], client: "distributed.Client"
) -> None:
    """Raise if any of the expected workers is unknown to the scheduler.

    Parameters
    ----------
    workers :
        Identifiers of the workers that must be present.  They are compared
        against the keys of the ``"workers"`` mapping reported by the scheduler.
    client :
        Client whose ``scheduler.identity()`` is awaited to obtain the current
        worker list.

    Raises
    ------
    RuntimeError
        If at least one entry of ``workers`` is absent from the scheduler's
        ``"workers"`` mapping.

    """
    info = await client.scheduler.identity()
    # Iterating the mapping yields its keys, so this is the set of requested
    # workers the scheduler does not currently list.
    missing_workers = set(workers).difference(info["workers"])
    if missing_workers:
        # Sort the names: set iteration order varies between interpreter runs,
        # which would make the message differ from one failure to the next.
        raise RuntimeError(f"Missing required workers: {sorted(missing_workers)}")
|
||||
|
||||
|
||||
async def _train_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
@@ -924,6 +932,7 @@ async def _train_async(
|
||||
custom_metric: Optional[Metric],
|
||||
) -> Optional[TrainReturnT]:
|
||||
workers = _get_workers_from_data(dtrain, evals)
|
||||
await _check_workers_are_alive(workers, client)
|
||||
_rabit_args = await _get_rabit_args(len(workers), dconfig, client)
|
||||
_check_distributed_params(params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user