diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 219ad2698..f62a3e5af 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -47,6 +47,7 @@ from typing import (
     Callable,
     Dict,
     Generator,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -97,10 +98,12 @@ if TYPE_CHECKING:
     import dask
     import distributed
     from dask import array as da
+    from dask import bag as db
     from dask import dataframe as dd
 else:
     dd = LazyLoader("dd", globals(), "dask.dataframe")
     da = LazyLoader("da", globals(), "dask.array")
+    db = LazyLoader("db", globals(), "dask.bag")
     dask = LazyLoader("dask", globals(), "dask")
     distributed = LazyLoader("distributed", globals(), "dask.distributed")
 
@@ -509,12 +512,10 @@ async def map_worker_partitions(
     func: Callable[..., _MapRetT],
     *refs: Any,
     workers: Sequence[str],
-) -> List[_MapRetT]:
+) -> _MapRetT:
     """Map a function onto partitions of each worker."""
     # Note for function purity:
-    # XGBoost is deterministic in most of the cases, which means train function is
-    # supposed to be idempotent. One known exception is gblinear with shotgun updater.
-    # We haven't been able to do a full verification so here we keep pure to be False.
+    # XGBoost is sensitive to the data partition and uses a random number generator.
     client = _xgb_get_client(client)
     futures = []
     for addr in workers:
@@ -526,11 +527,26 @@
             else:
                 args.append(ref)
         fut = client.submit(
-            func, *args, pure=False, workers=[addr], allow_other_workers=False
+            # turn the result into a list for bag construction
+            lambda *args, **kwargs: [func(*args, **kwargs)],
+            *args,
+            pure=False,
+            workers=[addr],
+            allow_other_workers=False,
         )
         futures.append(fut)
-    results = await client.gather(futures)
-    return results
+
+    def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
+        for v in results:
+            if v is not None:
+                return v
+        return None
+
+    bag = db.from_delayed(futures)
+    fut = bag.reduction(first_valid, first_valid)
+    result = await client.compute(fut).result()
+
+    return result
 
 
 _DataParts = List[Dict[str, Any]]
@@ -882,29 +898,6 @@ def _get_workers_from_data(
     return list(X_worker_map)
 
 
-def _filter_empty(
-    booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
-) -> Optional[TrainReturnT]:
-    n_workers = collective.get_world_size()
-    non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
-    rank = collective.get_rank()
-    non_empty[rank] = int(is_valid)
-    non_empty = collective.allreduce(non_empty, collective.Op.SUM)
-    non_empty = non_empty.astype(bool)
-    ret: Optional[TrainReturnT] = {
-        "booster": booster,
-        "history": local_history,
-    }
-    for i in range(non_empty.size):
-        # This is the first valid worker
-        if non_empty[i] and i == rank:
-            return ret
-        if non_empty[i]:
-            return None
-
-    raise ValueError("None of the workers can provide a valid result.")
-
-
 async def _check_workers_are_alive(
     workers: List[str], client: "distributed.Client"
 ) -> None:
@@ -997,10 +990,17 @@
                 xgb_model=xgb_model,
                 callbacks=callbacks,
             )
-            # Don't return the boosters from empty workers. It's quite difficult to
-            # guarantee everything is in sync in the present of empty workers,
-            # especially with complex objectives like quantile.
-            return _filter_empty(booster, local_history, Xy.num_row() != 0)
+            # Don't return the boosters from empty workers. It's quite difficult to
+            # guarantee everything is in sync in the presence of empty workers,
+            # especially with complex objectives like quantile.
+            if Xy.num_row() != 0:
+                ret: Optional[TrainReturnT] = {
+                    "booster": booster,
+                    "history": local_history,
+                }
+            else:
+                ret = None
+            return ret
 
     async with distributed.MultiLock(workers, client):
         if evals is not None:
@@ -1012,7 +1012,7 @@
             evals_name = []
             evals_id = []
 
-        results = await map_worker_partitions(
+        result = await map_worker_partitions(
             client,
             dispatched_train,
             # extra function parameters
@@ -1025,7 +1025,7 @@
             # workers to be used for training
             workers=workers,
         )
-        return list(filter(lambda ret: ret is not None, results))[0]
+        return result
 
 
 @_deprecate_positional_args
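
For context, the patch replaces the collective-based _filter_empty with a client-side reduction over a dask bag: each worker's future is wrapped in a single-element list so it becomes one bag partition, and first_valid serves as both the per-partition reducer and the final aggregate, picking the first non-None result. Below is a minimal standalone sketch of that pattern; the LocalCluster setup and the toy make_result task are illustrative assumptions, not part of the patch.

# Standalone sketch of the first-valid bag reduction used in the patch.
# The cluster setup and make_result are illustrative stand-ins, not from the PR.
from typing import Iterable, Optional, TypeVar

from dask import bag as db
from distributed import Client, LocalCluster

T = TypeVar("T")


def first_valid(results: Iterable[Optional[T]]) -> Optional[T]:
    # Used both per partition and for the final aggregation:
    # return the first non-None element, or None if all are None.
    for v in results:
        if v is not None:
            return v
    return None


def make_result(rank: int) -> Optional[str]:
    # Stand-in for dispatched_train: only some "workers" produce a model.
    return f"booster-from-rank-{rank}" if rank % 2 == 1 else None


if __name__ == "__main__":
    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        # Each future resolves to a single-element list so that
        # db.from_delayed treats it as one bag partition.
        futures = [
            client.submit(lambda r: [make_result(r)], rank, pure=False)
            for rank in range(4)
        ]
        bag = db.from_delayed(futures)
        item = bag.reduction(first_valid, first_valid)
        print(client.compute(item).result())  # booster-from-rank-1

As in the patch, the empty "workers" return None and only the first valid result reaches the caller, so no cross-worker collective call is needed just to agree on which booster to keep.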