[dask] Return the first valid booster instead of all valid ones. (#8993)
* [dask] Return the first valid booster instead of all valid ones. - Reduce memory footprint of the returned model. * mypy error. * lint. * duplicated.
This commit is contained in:
parent
6676c28cbc
commit
a58055075b
@ -888,6 +888,29 @@ def _get_workers_from_data(
|
||||
return list(X_worker_map)
|
||||
|
||||
|
||||
def _filter_empty(
|
||||
booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
|
||||
) -> Optional[TrainReturnT]:
|
||||
n_workers = collective.get_world_size()
|
||||
non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
|
||||
rank = collective.get_rank()
|
||||
non_empty[rank] = int(is_valid)
|
||||
non_empty = collective.allreduce(non_empty, collective.Op.SUM)
|
||||
non_empty = non_empty.astype(bool)
|
||||
ret: Optional[TrainReturnT] = {
|
||||
"booster": booster,
|
||||
"history": local_history,
|
||||
}
|
||||
for i in range(non_empty.size):
|
||||
# This is the first valid worker
|
||||
if non_empty[i] and i == rank:
|
||||
return ret
|
||||
if non_empty[i]:
|
||||
return None
|
||||
|
||||
raise ValueError("None of the workers can provide a valid result.")
|
||||
|
||||
|
||||
async def _train_async(
|
||||
client: "distributed.Client",
|
||||
global_config: Dict[str, Any],
|
||||
@ -973,14 +996,10 @@ async def _train_async(
|
||||
xgb_model=xgb_model,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
if Xy.num_row() != 0:
|
||||
ret: Optional[TrainReturnT] = {
|
||||
"booster": booster,
|
||||
"history": local_history,
|
||||
}
|
||||
else:
|
||||
ret = None
|
||||
return ret
|
||||
# Don't return the boosters from empty workers. It's quite difficult to
|
||||
# guarantee everything is in sync in the present of empty workers,
|
||||
# especially with complex objectives like quantile.
|
||||
return _filter_empty(booster, local_history, Xy.num_row() != 0)
|
||||
|
||||
async with distributed.MultiLock(workers, client):
|
||||
if evals is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user