[dask] Return the first valid booster instead of all valid ones. (#8993)
* [dask] Return the first valid booster instead of all valid ones. - Reduce memory footprint of the returned model. * mypy error. * lint. * duplicated.
This commit is contained in:
parent
6676c28cbc
commit
a58055075b
@ -888,6 +888,29 @@ def _get_workers_from_data(
|
|||||||
return list(X_worker_map)
|
return list(X_worker_map)
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_empty(
|
||||||
|
booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
|
||||||
|
) -> Optional[TrainReturnT]:
|
||||||
|
n_workers = collective.get_world_size()
|
||||||
|
non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
|
||||||
|
rank = collective.get_rank()
|
||||||
|
non_empty[rank] = int(is_valid)
|
||||||
|
non_empty = collective.allreduce(non_empty, collective.Op.SUM)
|
||||||
|
non_empty = non_empty.astype(bool)
|
||||||
|
ret: Optional[TrainReturnT] = {
|
||||||
|
"booster": booster,
|
||||||
|
"history": local_history,
|
||||||
|
}
|
||||||
|
for i in range(non_empty.size):
|
||||||
|
# This is the first valid worker
|
||||||
|
if non_empty[i] and i == rank:
|
||||||
|
return ret
|
||||||
|
if non_empty[i]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
raise ValueError("None of the workers can provide a valid result.")
|
||||||
|
|
||||||
|
|
||||||
async def _train_async(
|
async def _train_async(
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
@ -973,14 +996,10 @@ async def _train_async(
|
|||||||
xgb_model=xgb_model,
|
xgb_model=xgb_model,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
if Xy.num_row() != 0:
|
# Don't return the boosters from empty workers. It's quite difficult to
|
||||||
ret: Optional[TrainReturnT] = {
|
# guarantee everything is in sync in the present of empty workers,
|
||||||
"booster": booster,
|
# especially with complex objectives like quantile.
|
||||||
"history": local_history,
|
return _filter_empty(booster, local_history, Xy.num_row() != 0)
|
||||||
}
|
|
||||||
else:
|
|
||||||
ret = None
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async with distributed.MultiLock(workers, client):
|
async with distributed.MultiLock(workers, client):
|
||||||
if evals is not None:
|
if evals is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user