[dask] Filter models on worker. (#9518)

This commit is contained in:
Jiaming Yuan 2023-08-25 20:23:47 +08:00 committed by GitHub
parent 972730cde0
commit aa86bd5207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -47,6 +47,7 @@ from typing import (
Callable, Callable,
Dict, Dict,
Generator, Generator,
Iterable,
List, List,
Optional, Optional,
Sequence, Sequence,
@ -97,10 +98,12 @@ if TYPE_CHECKING:
import dask import dask
import distributed import distributed
from dask import array as da from dask import array as da
from dask import bag as db
from dask import dataframe as dd from dask import dataframe as dd
else: else:
dd = LazyLoader("dd", globals(), "dask.dataframe") dd = LazyLoader("dd", globals(), "dask.dataframe")
da = LazyLoader("da", globals(), "dask.array") da = LazyLoader("da", globals(), "dask.array")
db = LazyLoader("db", globals(), "dask.bag")
dask = LazyLoader("dask", globals(), "dask") dask = LazyLoader("dask", globals(), "dask")
distributed = LazyLoader("distributed", globals(), "dask.distributed") distributed = LazyLoader("distributed", globals(), "dask.distributed")
@ -509,12 +512,10 @@ async def map_worker_partitions(
func: Callable[..., _MapRetT], func: Callable[..., _MapRetT],
*refs: Any, *refs: Any,
workers: Sequence[str], workers: Sequence[str],
) -> List[_MapRetT]: ) -> _MapRetT:
"""Map a function onto partitions of each worker.""" """Map a function onto partitions of each worker."""
# Note for function purity: # Note for function purity:
# XGBoost is deterministic in most of the cases, which means train function is # XGBoost is sensitive to data partition and uses random number generator.
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
# We haven't been able to do a full verification so here we keep pure to be False.
client = _xgb_get_client(client) client = _xgb_get_client(client)
futures = [] futures = []
for addr in workers: for addr in workers:
@ -526,11 +527,26 @@ async def map_worker_partitions(
else: else:
args.append(ref) args.append(ref)
fut = client.submit( fut = client.submit(
func, *args, pure=False, workers=[addr], allow_other_workers=False # turn result into a list for bag construction
lambda *args, **kwargs: [func(*args, **kwargs)],
*args,
pure=False,
workers=[addr],
allow_other_workers=False,
) )
futures.append(fut) futures.append(fut)
results = await client.gather(futures)
return results def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
for v in results:
if v is not None:
return v
return None
bag = db.from_delayed(futures)
fut = await bag.reduction(first_valid, first_valid)
result = await client.compute(fut).result()
return result
_DataParts = List[Dict[str, Any]] _DataParts = List[Dict[str, Any]]
@ -882,29 +898,6 @@ def _get_workers_from_data(
return list(X_worker_map) return list(X_worker_map)
def _filter_empty(
    booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
) -> Optional[TrainReturnT]:
    """Return the training result only on the first worker that had data.

    Every worker participates in a `collective.allreduce` so that all of them
    learn which ranks held a non-empty partition (``is_valid``).  Exactly one
    worker — the lowest-ranked valid one — returns the ``{"booster", "history"}``
    dict; every other worker returns ``None``.

    Parameters
    ----------
    booster :
        The locally trained booster (identical across workers after training,
        presumably — NOTE(review): not verifiable from this function alone).
    local_history :
        Evaluation history accumulated by the local training callbacks.
    is_valid :
        True when this worker's partition was non-empty.

    Returns
    -------
    The result dict on the first valid rank, ``None`` elsewhere.

    Raises
    ------
    ValueError
        If no worker reported a valid (non-empty) partition.
    """
    n_workers = collective.get_world_size()
    # One slot per rank; this rank marks its own slot with 0/1 validity.
    non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
    rank = collective.get_rank()
    non_empty[rank] = int(is_valid)
    # SUM-allreduce merges the per-rank flags so every worker sees the full
    # validity vector.  All workers must reach this call (collective op).
    non_empty = collective.allreduce(non_empty, collective.Op.SUM)
    non_empty = non_empty.astype(bool)
    ret: Optional[TrainReturnT] = {
        "booster": booster,
        "history": local_history,
    }
    # Scan ranks in order; the first valid rank wins.
    for i in range(non_empty.size):
        # This is the first valid worker
        if non_empty[i] and i == rank:
            return ret
        # A lower-ranked worker is the designated one; this rank yields None.
        if non_empty[i]:
            return None
    raise ValueError("None of the workers can provide a valid result.")
async def _check_workers_are_alive( async def _check_workers_are_alive(
workers: List[str], client: "distributed.Client" workers: List[str], client: "distributed.Client"
) -> None: ) -> None:
@ -997,10 +990,17 @@ async def _train_async(
xgb_model=xgb_model, xgb_model=xgb_model,
callbacks=callbacks, callbacks=callbacks,
) )
# Don't return the boosters from empty workers. It's quite difficult to # Don't return the boosters from empty workers. It's quite difficult to
# guarantee everything is in sync in the presence of empty workers, # guarantee everything is in sync in the presence of empty workers, especially
# especially with complex objectives like quantile. # with complex objectives like quantile.
return _filter_empty(booster, local_history, Xy.num_row() != 0) if Xy.num_row() != 0:
ret: Optional[TrainReturnT] = {
"booster": booster,
"history": local_history,
}
else:
ret = None
return ret
async with distributed.MultiLock(workers, client): async with distributed.MultiLock(workers, client):
if evals is not None: if evals is not None:
@ -1012,7 +1012,7 @@ async def _train_async(
evals_name = [] evals_name = []
evals_id = [] evals_id = []
results = await map_worker_partitions( result = await map_worker_partitions(
client, client,
dispatched_train, dispatched_train,
# extra function parameters # extra function parameters
@ -1025,7 +1025,7 @@ async def _train_async(
# workers to be used for training # workers to be used for training
workers=workers, workers=workers,
) )
return list(filter(lambda ret: ret is not None, results))[0] return result
@_deprecate_positional_args @_deprecate_positional_args