[dask] Filter models on worker. (#9518)
This commit is contained in:
parent
972730cde0
commit
aa86bd5207
@ -47,6 +47,7 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -97,10 +98,12 @@ if TYPE_CHECKING:
|
||||
import dask
|
||||
import distributed
|
||||
from dask import array as da
|
||||
from dask import bag as db
|
||||
from dask import dataframe as dd
|
||||
else:
|
||||
dd = LazyLoader("dd", globals(), "dask.dataframe")
|
||||
da = LazyLoader("da", globals(), "dask.array")
|
||||
db = LazyLoader("db", globals(), "dask.bag")
|
||||
dask = LazyLoader("dask", globals(), "dask")
|
||||
distributed = LazyLoader("distributed", globals(), "dask.distributed")
|
||||
|
||||
@ -509,12 +512,10 @@ async def map_worker_partitions(
|
||||
func: Callable[..., _MapRetT],
|
||||
*refs: Any,
|
||||
workers: Sequence[str],
|
||||
) -> List[_MapRetT]:
|
||||
) -> _MapRetT:
|
||||
"""Map a function onto partitions of each worker."""
|
||||
# Note for function purity:
|
||||
# XGBoost is deterministic in most of the cases, which means train function is
|
||||
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
|
||||
# We haven't been able to do a full verification so here we keep pure to be False.
|
||||
# XGBoost is sensitive to data partition and uses random number generator.
|
||||
client = _xgb_get_client(client)
|
||||
futures = []
|
||||
for addr in workers:
|
||||
@ -526,11 +527,26 @@ async def map_worker_partitions(
|
||||
else:
|
||||
args.append(ref)
|
||||
fut = client.submit(
|
||||
func, *args, pure=False, workers=[addr], allow_other_workers=False
|
||||
# turn result into a list for bag construction
|
||||
lambda *args, **kwargs: [func(*args, **kwargs)],
|
||||
*args,
|
||||
pure=False,
|
||||
workers=[addr],
|
||||
allow_other_workers=False,
|
||||
)
|
||||
futures.append(fut)
|
||||
results = await client.gather(futures)
|
||||
return results
|
||||
|
||||
def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
    """Return the first element of ``results`` that is not None.

    Elements are compared by identity against None, so falsy values such as
    ``0`` or an empty container still count as valid.  Returns None when
    ``results`` is empty or contains only None.
    """
    return next((v for v in results if v is not None), None)
|
||||
|
||||
bag = db.from_delayed(futures)
|
||||
fut = await bag.reduction(first_valid, first_valid)
|
||||
result = await client.compute(fut).result()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_DataParts = List[Dict[str, Any]]
|
||||
@ -882,29 +898,6 @@ def _get_workers_from_data(
|
||||
return list(X_worker_map)
|
||||
|
||||
|
||||
def _filter_empty(
    booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
) -> Optional[TrainReturnT]:
    """Keep the training result only on the first worker that reports a valid one.

    Every worker writes ``int(is_valid)`` into its own slot of a zero vector
    with one entry per worker; the vectors are then combined with a SUM
    allreduce so that all workers see the same flags.

    Returns
    -------
    A dict with keys ``"booster"`` and ``"history"`` if this worker has the
    lowest rank among those whose flag is set, otherwise None.

    Raises
    ------
    ValueError
        If no worker has its flag set.
    """
    n_workers = collective.get_world_size()
    rank = collective.get_rank()

    # One slot per worker; only this worker's own slot is written locally.
    flags = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
    flags[rank] = int(is_valid)
    # After the SUM reduction each slot holds the flag of the matching rank.
    flags = collective.allreduce(flags, collective.Op.SUM).astype(bool)

    for i, has_result in enumerate(flags):
        if not has_result:
            continue
        # ``i`` is the first rank with a valid result: it alone returns the
        # model, every other worker returns None.
        if i == rank:
            return {
                "booster": booster,
                "history": local_history,
            }
        return None

    raise ValueError("None of the workers can provide a valid result.")
|
||||
|
||||
|
||||
async def _check_workers_are_alive(
|
||||
workers: List[str], client: "distributed.Client"
|
||||
) -> None:
|
||||
@ -998,9 +991,16 @@ async def _train_async(
|
||||
callbacks=callbacks,
|
||||
)
|
||||
# Don't return the boosters from empty workers. It's quite difficult to
|
||||
# guarantee everything is in sync in the presence of empty workers,
|
||||
# especially with complex objectives like quantile.
|
||||
return _filter_empty(booster, local_history, Xy.num_row() != 0)
|
||||
# guarantee everything is in sync in the presence of empty workers, especially
|
||||
# with complex objectives like quantile.
|
||||
if Xy.num_row() != 0:
|
||||
ret: Optional[TrainReturnT] = {
|
||||
"booster": booster,
|
||||
"history": local_history,
|
||||
}
|
||||
else:
|
||||
ret = None
|
||||
return ret
|
||||
|
||||
async with distributed.MultiLock(workers, client):
|
||||
if evals is not None:
|
||||
@ -1012,7 +1012,7 @@ async def _train_async(
|
||||
evals_name = []
|
||||
evals_id = []
|
||||
|
||||
results = await map_worker_partitions(
|
||||
result = await map_worker_partitions(
|
||||
client,
|
||||
dispatched_train,
|
||||
# extra function parameters
|
||||
@ -1025,7 +1025,7 @@ async def _train_async(
|
||||
# workers to be used for training
|
||||
workers=workers,
|
||||
)
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
return result
|
||||
|
||||
|
||||
@_deprecate_positional_args
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user