Raise if expected workers are not alive in xgboost.dask.train (#9421)
This commit is contained in:
parent
7129988847
commit
f958e32683
@ -850,8 +850,6 @@ async def _get_rabit_args(
|
|||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
sched_addr = None
|
sched_addr = None
|
||||||
|
|
||||||
# make sure all workers are online so that we can obtain reliable scheduler_info
|
|
||||||
await client.wait_for_workers(n_workers) # type: ignore
|
|
||||||
env = await client.run_on_scheduler(
|
env = await client.run_on_scheduler(
|
||||||
_start_tracker, n_workers, sched_addr, user_addr
|
_start_tracker, n_workers, sched_addr, user_addr
|
||||||
)
|
)
|
||||||
@ -907,6 +905,16 @@ def _filter_empty(
|
|||||||
raise ValueError("None of the workers can provide a valid result.")
|
raise ValueError("None of the workers can provide a valid result.")
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_workers_are_alive(
|
||||||
|
workers: List[str], client: "distributed.Client"
|
||||||
|
) -> None:
|
||||||
|
info = await client.scheduler.identity()
|
||||||
|
current_workers = info["workers"].keys()
|
||||||
|
missing_workers = set(workers) - current_workers
|
||||||
|
if missing_workers:
|
||||||
|
raise RuntimeError(f"Missing required workers: {missing_workers}")
|
||||||
|
|
||||||
|
|
||||||
async def _train_async(
|
async def _train_async(
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
@ -924,6 +932,7 @@ async def _train_async(
|
|||||||
custom_metric: Optional[Metric],
|
custom_metric: Optional[Metric],
|
||||||
) -> Optional[TrainReturnT]:
|
) -> Optional[TrainReturnT]:
|
||||||
workers = _get_workers_from_data(dtrain, evals)
|
workers = _get_workers_from_data(dtrain, evals)
|
||||||
|
await _check_workers_are_alive(workers, client)
|
||||||
_rabit_args = await _get_rabit_args(len(workers), dconfig, client)
|
_rabit_args = await _get_rabit_args(len(workers), dconfig, client)
|
||||||
_check_distributed_params(params)
|
_check_distributed_params(params)
|
||||||
|
|
||||||
|
|||||||
@ -36,7 +36,8 @@ pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
|
|||||||
import dask
|
import dask
|
||||||
import dask.array as da
|
import dask.array as da
|
||||||
import dask.dataframe as dd
|
import dask.dataframe as dd
|
||||||
from distributed import Client, LocalCluster
|
from distributed import Client, LocalCluster, Nanny, Worker
|
||||||
|
from distributed.utils_test import async_poll_for, gen_cluster
|
||||||
from toolz import sliding_window # dependency of dask
|
from toolz import sliding_window # dependency of dask
|
||||||
|
|
||||||
from xgboost.dask import DaskDMatrix
|
from xgboost.dask import DaskDMatrix
|
||||||
@ -2226,3 +2227,38 @@ class TestDaskCallbacks:
|
|||||||
)
|
)
|
||||||
for i in range(1, 10):
|
for i in range(1, 10):
|
||||||
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".json"))
|
||||||
|
|
||||||
|
|
||||||
|
@gen_cluster(client=True, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
|
||||||
|
async def test_worker_left(c, s, a, b):
|
||||||
|
async with Worker(s.address):
|
||||||
|
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
|
||||||
|
dy = da.random.random((1000,)).rechunk(chunks=(10,))
|
||||||
|
d_train = await xgb.dask.DaskDMatrix(
|
||||||
|
c, dx, dy,
|
||||||
|
)
|
||||||
|
await async_poll_for(lambda: len(s.workers) == 2, timeout=5)
|
||||||
|
with pytest.raises(RuntimeError, match="Missing"):
|
||||||
|
await xgb.dask.train(
|
||||||
|
c,
|
||||||
|
{},
|
||||||
|
d_train,
|
||||||
|
evals=[(d_train, "train")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@gen_cluster(client=True, Worker=Nanny, clean_kwargs={"processes": False, "threads": False}, allow_unclosed=True)
|
||||||
|
async def test_worker_restarted(c, s, a, b):
|
||||||
|
dx = da.random.random((1000, 10)).rechunk(chunks=(10, None))
|
||||||
|
dy = da.random.random((1000,)).rechunk(chunks=(10,))
|
||||||
|
d_train = await xgb.dask.DaskDMatrix(
|
||||||
|
c, dx, dy,
|
||||||
|
)
|
||||||
|
await c.restart_workers([a.worker_address])
|
||||||
|
with pytest.raises(RuntimeError, match="Missing"):
|
||||||
|
await xgb.dask.train(
|
||||||
|
c,
|
||||||
|
{},
|
||||||
|
d_train,
|
||||||
|
evals=[(d_train, "train")],
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user