[dask] Return the first valid booster instead of all valid ones. (#8993)

* [dask] Return the first valid booster instead of all valid ones.

- Reduce memory footprint of the returned model.

* mypy error.

* lint.

* duplicated.
This commit is contained in:
Jiaming Yuan 2023-03-30 03:16:18 +08:00 committed by GitHub
parent 6676c28cbc
commit a58055075b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -888,6 +888,29 @@ def _get_workers_from_data(
return list(X_worker_map)
def _filter_empty(
booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
) -> Optional[TrainReturnT]:
n_workers = collective.get_world_size()
non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
rank = collective.get_rank()
non_empty[rank] = int(is_valid)
non_empty = collective.allreduce(non_empty, collective.Op.SUM)
non_empty = non_empty.astype(bool)
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
for i in range(non_empty.size):
# This is the first valid worker
if non_empty[i] and i == rank:
return ret
if non_empty[i]:
return None
raise ValueError("None of the workers can provide a valid result.")
async def _train_async(
client: "distributed.Client",
global_config: Dict[str, Any],
@ -973,14 +996,10 @@ async def _train_async(
xgb_model=xgb_model,
callbacks=callbacks,
)
if Xy.num_row() != 0:
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
else:
ret = None
return ret
# Don't return the boosters from empty workers. It's quite difficult to
# guarantee everything is in sync in the present of empty workers,
# especially with complex objectives like quantile.
return _filter_empty(booster, local_history, Xy.num_row() != 0)
async with distributed.MultiLock(workers, client):
if evals is not None: