[dask] Filter models on worker. (#9518)

Jiaming Yuan 2023-08-25 20:23:47 +08:00 committed by GitHub
parent 972730cde0
commit aa86bd5207

@@ -47,6 +47,7 @@ from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
@@ -97,10 +98,12 @@ if TYPE_CHECKING:
import dask
import distributed
from dask import array as da
from dask import bag as db
from dask import dataframe as dd
else:
dd = LazyLoader("dd", globals(), "dask.dataframe")
da = LazyLoader("da", globals(), "dask.array")
db = LazyLoader("db", globals(), "dask.bag")
dask = LazyLoader("dask", globals(), "dask")
distributed = LazyLoader("distributed", globals(), "dask.distributed")
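The hunk above registers dask.bag as another lazily loaded module, so importing xgboost.dask does not pull dask.bag in until it is actually touched. Below is a minimal sketch of how such deferred loading can be built on importlib; the _LazyModule class here is an illustrative assumption, not XGBoost's actual LazyLoader implementation.

import importlib
import types
from typing import Any, Dict


class _LazyModule(types.ModuleType):
    """Illustrative lazy loader: resolve the real module on first attribute access."""

    def __init__(self, local_name: str, parent_globals: Dict[str, Any], name: str) -> None:
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Replace the placeholder so later lookups hit the real module directly.
        self._parent_globals[self._local_name] = module
        return module

    def __getattr__(self, item: str) -> Any:
        return getattr(self._load(), item)


# Nothing is imported until an attribute of `db` is first touched.
db = _LazyModule("db", globals(), "dask.bag")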
@@ -509,12 +512,10 @@ async def map_worker_partitions(
func: Callable[..., _MapRetT],
*refs: Any,
workers: Sequence[str],
) -> List[_MapRetT]:
) -> _MapRetT:
"""Map a function onto partitions of each worker."""
# Note for function purity:
# XGBoost is deterministic in most cases, which means the train function is
# supposed to be idempotent. One known exception is gblinear with the shotgun updater.
# We haven't been able to do a full verification, so we keep `pure` as False here.
# XGBoost is sensitive to data partitioning and uses a random number generator.
client = _xgb_get_client(client)
futures = []
for addr in workers:
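The loop above submits one task per worker address; as the next hunk shows, each submit call pins the task with workers=[addr] and allow_other_workers=False so training runs on the worker that already holds the data. A self-contained sketch of that pinning pattern against a local cluster; the who_am_i payload is a made-up stand-in for the real dispatched training function.

from distributed import Client, LocalCluster, get_worker


def who_am_i(part: int) -> str:
    # Hypothetical payload: report which worker executed the task.
    return f"partition {part} ran on {get_worker().address}"


if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster, Client(cluster) as client:
        workers = list(client.scheduler_info()["workers"])
        futures = [
            # Pin each task to one specific worker, mirroring the submit call below.
            client.submit(who_am_i, i, pure=False, workers=[addr], allow_other_workers=False)
            for i, addr in enumerate(workers)
        ]
        print(client.gather(futures))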
@@ -526,11 +527,26 @@
else:
args.append(ref)
fut = client.submit(
func, *args, pure=False, workers=[addr], allow_other_workers=False
# turn result into a list for bag construction
lambda *args, **kwargs: [func(*args, **kwargs)],
*args,
pure=False,
workers=[addr],
allow_other_workers=False,
)
futures.append(fut)
results = await client.gather(futures)
return results
def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
for v in results:
if v is not None:
return v
return None
bag = db.from_delayed(futures)
fut = await bag.reduction(first_valid, first_valid)
result = await client.compute(fut).result()
return result
_DataParts = List[Dict[str, Any]]
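The submit call now wraps each worker's return value in a single-element list so the futures can be assembled into a dask.bag with db.from_delayed, and first_valid serves as both the per-partition and the aggregate step of Bag.reduction, collapsing everything to the first non-None result. A small self-contained sketch of that reduction, with dummy per-worker outputs standing in for the real training results:

from typing import Iterable, Optional

from dask import bag as db
from dask import delayed


def first_valid(results: Iterable[Optional[dict]]) -> Optional[dict]:
    # Return the first non-None entry, or None if every worker came back empty.
    for v in results:
        if v is not None:
            return v
    return None


def as_partition(v: Optional[dict]) -> list:
    # Mirror the list-wrapping above: each bag partition holds exactly one worker result.
    return [v]


# Dummy per-worker outputs: workers 0 and 2 had no data, worker 1 trained a model.
per_worker = [None, {"booster": "model-from-worker-1"}, None]
parts = [delayed(as_partition)(v) for v in per_worker]
bag = db.from_delayed(parts)

# first_valid runs once per partition, then once more over the partial results.
print(bag.reduction(first_valid, first_valid).compute())
# -> {'booster': 'model-from-worker-1'}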
@@ -882,29 +898,6 @@ def _get_workers_from_data(
return list(X_worker_map)
def _filter_empty(
booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
) -> Optional[TrainReturnT]:
n_workers = collective.get_world_size()
non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
rank = collective.get_rank()
non_empty[rank] = int(is_valid)
non_empty = collective.allreduce(non_empty, collective.Op.SUM)
non_empty = non_empty.astype(bool)
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
for i in range(non_empty.size):
# This is the first valid worker
if non_empty[i] and i == rank:
return ret
if non_empty[i]:
return None
raise ValueError("None of the workers can provide a valid result.")
async def _check_workers_are_alive(
workers: List[str], client: "distributed.Client"
) -> None:
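The removed _filter_empty coordinated the workers through an allreduce: each rank contributes a 0/1 flag saying whether it had data, the summed vector becomes visible to every rank, and only the first rank with data returns the booster while the rest return None. A hedged sketch that simulates that selection rule locally with numpy, without the collective communicator:

import numpy as np


def pick_first_valid_rank(is_valid_per_rank: list) -> int:
    """Simulate the allreduce-based choice: the lowest rank with data wins."""
    # In the real code each rank fills only its own slot and the vector is summed
    # across workers via collective.allreduce; here we simply build it locally.
    non_empty = np.asarray(is_valid_per_rank, dtype=np.int32).astype(bool)
    for rank in range(non_empty.size):
        if non_empty[rank]:
            return rank
    raise ValueError("None of the workers can provide a valid result.")


# Workers 0 and 2 are empty, so worker 1 is the one that keeps its booster.
print(pick_first_valid_rank([0, 1, 0]))  # 1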
@@ -998,9 +991,16 @@ async def _train_async(
callbacks=callbacks,
)
# Don't return the boosters from empty workers. It's quite difficult to
# guarantee everything is in sync in the presence of empty workers,
# especially with complex objectives like quantile.
return _filter_empty(booster, local_history, Xy.num_row() != 0)
# guarantee everything is in sync in the presence of empty workers, especially
# with complex objectives like quantile.
if Xy.num_row() != 0:
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
else:
ret = None
return ret
async with distributed.MultiLock(workers, client):
if evals is not None:
@@ -1012,7 +1012,7 @@
evals_name = []
evals_id = []
results = await map_worker_partitions(
result = await map_worker_partitions(
client,
dispatched_train,
# extra function parameters
@@ -1025,7 +1025,7 @@
# workers to be used for training
workers=workers,
)
return list(filter(lambda ret: ret is not None, results))[0]
return result
@_deprecate_positional_args