diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index a17fbad70..8c679b75b 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -888,6 +888,43 @@ def _get_workers_from_data(
     return list(X_worker_map)
 
 
+def _filter_empty(
+    booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
+) -> Optional[TrainReturnT]:
+    """Return the training result only on the first non-empty worker.
+
+    Every worker writes ``is_valid`` into its own rank slot and the slots are
+    combined with a SUM allreduce.  The lowest rank whose slot is set returns
+    the booster and history; every other worker returns None.
+
+    Raises
+    ------
+    ValueError
+        If no worker reports a valid result.
+
+    """
+    world_size = collective.get_world_size()
+    rank = collective.get_rank()
+
+    # One slot per worker; only the local slot is set before the reduction.
+    flags = numpy.zeros(shape=(world_size,), dtype=numpy.int32)
+    flags[rank] = int(is_valid)
+    flags = collective.allreduce(flags, collective.Op.SUM)
+
+    valid_ranks = numpy.flatnonzero(flags.astype(bool))
+    if valid_ranks.size == 0:
+        raise ValueError("None of the workers can provide a valid result.")
+    if valid_ranks[0] != rank:
+        # Either this worker is empty or a lower rank already holds a result.
+        return None
+
+    result: Optional[TrainReturnT] = {
+        "booster": booster,
+        "history": local_history,
+    }
+    return result
+
+
 async def _train_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
@@ -973,14 +1010,10 @@ async def _train_async(
                 xgb_model=xgb_model,
                 callbacks=callbacks,
             )
-        if Xy.num_row() != 0:
-            ret: Optional[TrainReturnT] = {
-                "booster": booster,
-                "history": local_history,
-            }
-        else:
-            ret = None
-        return ret
+            # Don't return the boosters from empty workers. It's quite difficult to
+            # guarantee everything is in sync in the presence of empty workers,
+            # especially with complex objectives like quantile.
+            return _filter_empty(booster, local_history, Xy.num_row() != 0)
 
     async with distributed.MultiLock(workers, client):
         if evals is not None: