[dask] Filter models on worker. (#9518)

commit aa86bd5207
parent 972730cde0
@@ -47,6 +47,7 @@ from typing import (
     Callable,
     Dict,
     Generator,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -97,10 +98,12 @@ if TYPE_CHECKING:
     import dask
     import distributed
     from dask import array as da
+    from dask import bag as db
     from dask import dataframe as dd
 else:
     dd = LazyLoader("dd", globals(), "dask.dataframe")
     da = LazyLoader("da", globals(), "dask.array")
+    db = LazyLoader("db", globals(), "dask.bag")
     dask = LazyLoader("dask", globals(), "dask")
     distributed = LazyLoader("distributed", globals(), "dask.distributed")
 
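The new `db` binding follows the module's existing LazyLoader pattern, so dask stays an optional dependency that is only imported when an attribute is first touched. A generic, self-contained sketch of that pattern (not xgboost's actual LazyLoader implementation):

import importlib
import types

class LazyModule(types.ModuleType):
    """Defer the real import until an attribute is first accessed."""

    def __init__(self, name: str) -> None:
        super().__init__(name)
        self._mod = None

    def __getattr__(self, item: str):
        # only called when normal lookup fails, i.e. for the real module's names
        if self._mod is None:
            self._mod = importlib.import_module(self.__name__)
        return getattr(self._mod, item)

db = LazyModule("dask.bag")  # importing dask is deferred...
# db.from_delayed(...)       # ...until a call like this actually touches it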
@@ -509,12 +512,10 @@ async def map_worker_partitions(
     func: Callable[..., _MapRetT],
     *refs: Any,
     workers: Sequence[str],
-) -> List[_MapRetT]:
+) -> _MapRetT:
     """Map a function onto partitions of each worker."""
     # Note for function purity:
-    # XGBoost is deterministic in most of the cases, which means train function is
-    # supposed to be idempotent. One known exception is gblinear with shotgun updater.
-    # We haven't been able to do a full verification so here we keep pure to be False.
+    # XGBoost is sensitive to data partition and uses random number generator.
     client = _xgb_get_client(client)
     futures = []
     for addr in workers:
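The purity note above is about `pure=False` in the submit call that follows: dask builds task keys by hashing pure calls, so two identical pure submissions would collapse into one task, while per-worker training must run independently. A minimal sketch, assuming an in-process local cluster and a hypothetical `train_stub` stand-in:

from distributed import Client

def train_stub(addr: str) -> str:
    # stand-in for the real per-worker training function
    return f"trained on {addr}"

client = Client(processes=False)  # in-process cluster, for illustration only
a = client.submit(train_stub, "w0", pure=False)
b = client.submit(train_stub, "w0", pure=False)
assert a.key != b.key  # pure=False: each submission is a distinct task
print(a.result(), b.result())
client.close()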
@@ -526,11 +527,26 @@ async def map_worker_partitions(
         else:
             args.append(ref)
         fut = client.submit(
-            func, *args, pure=False, workers=[addr], allow_other_workers=False
+            # turn result into a list for bag construction
+            lambda *args, **kwargs: [func(*args, **kwargs)],
+            *args,
+            pure=False,
+            workers=[addr],
+            allow_other_workers=False,
         )
         futures.append(fut)
-    results = await client.gather(futures)
-    return results
+
+    def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]:
+        for v in results:
+            if v is not None:
+                return v
+        return None
+
+    bag = db.from_delayed(futures)
+    fut = await bag.reduction(first_valid, first_valid)
+    result = await client.compute(fut).result()
+
+    return result
 
 
 _DataParts = List[Dict[str, Any]]
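How the new reduction collapses per-worker outputs: each future yields a one-element list, `db.from_delayed` treats each list as one bag partition, and `first_valid` runs both per partition and as the final aggregate, so the first non-None result wins. A minimal local sketch with assumed stand-in values in place of the commit's actual futures:

from typing import Iterable, Optional

import dask
from dask import bag as db

def first_valid(results: Iterable[Optional[str]]) -> Optional[str]:
    for v in results:
        if v is not None:
            return v
    return None

# stand-ins for the per-worker futures: only the second worker had data
parts = [dask.delayed([None]), dask.delayed(["model"]), dask.delayed([None])]
bag = db.from_delayed(parts)  # one partition per worker result
print(bag.reduction(first_valid, first_valid).compute())  # -> model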
@@ -882,29 +898,6 @@ def _get_workers_from_data(
     return list(X_worker_map)
 
 
-def _filter_empty(
-    booster: Booster, local_history: TrainingCallback.EvalsLog, is_valid: bool
-) -> Optional[TrainReturnT]:
-    n_workers = collective.get_world_size()
-    non_empty = numpy.zeros(shape=(n_workers,), dtype=numpy.int32)
-    rank = collective.get_rank()
-    non_empty[rank] = int(is_valid)
-    non_empty = collective.allreduce(non_empty, collective.Op.SUM)
-    non_empty = non_empty.astype(bool)
-    ret: Optional[TrainReturnT] = {
-        "booster": booster,
-        "history": local_history,
-    }
-    for i in range(non_empty.size):
-        # This is the first valid worker
-        if non_empty[i] and i == rank:
-            return ret
-        if non_empty[i]:
-            return None
-
-    raise ValueError("None of the workers can provide a valid result.")
-
-
 async def _check_workers_are_alive(
     workers: List[str], client: "distributed.Client"
 ) -> None:
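For contrast, the removed `_filter_empty` selected the surviving model with a collective allreduce: every rank contributed a 0/1 flag, the summed vector was seen identically by all ranks, and only the first rank with data kept its booster. A plain-numpy sketch of that selection logic, assuming four workers of which ranks 2 and 3 hold data:

import numpy

def select(rank: int, flags: numpy.ndarray) -> bool:
    """flags is the allreduce(SUM) result, identical on every rank."""
    for i in range(flags.size):
        if flags[i]:
            return i == rank  # only the first valid worker keeps its booster
    raise ValueError("None of the workers can provide a valid result.")

flags = numpy.array([0, 0, 1, 1], dtype=numpy.int32)
assert [select(r, flags) for r in range(4)] == [False, False, True, False]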
@@ -997,10 +990,17 @@ async def _train_async(
                 xgb_model=xgb_model,
                 callbacks=callbacks,
             )
             # Don't return the boosters from empty workers. It's quite difficult to
-            # guarantee everything is in sync in the present of empty workers,
-            # especially with complex objectives like quantile.
-            return _filter_empty(booster, local_history, Xy.num_row() != 0)
+            # guarantee everything is in sync in the present of empty workers, especially
+            # with complex objectives like quantile.
+            if Xy.num_row() != 0:
+                ret: Optional[TrainReturnT] = {
+                    "booster": booster,
+                    "history": local_history,
+                }
+            else:
+                ret = None
+            return ret
 
     async with distributed.MultiLock(workers, client):
         if evals is not None:
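The worker-side contract after this hunk: `dispatched_train` returns `Optional[TrainReturnT]`, with `None` standing in for empty workers, and the driver keeps the first non-None entry via the bag reduction above. A hedged, self-contained sketch with a hypothetical `dispatched_train_stub`:

from typing import Any, Dict, List, Optional

TrainReturnT = Dict[str, Any]  # {"booster": ..., "history": ...}

def dispatched_train_stub(num_rows: int) -> Optional[TrainReturnT]:
    # empty workers now return None instead of a placeholder result
    if num_rows == 0:
        return None
    return {"booster": object(), "history": {}}

per_worker: List[Optional[TrainReturnT]] = [
    dispatched_train_stub(n) for n in (0, 128, 64)
]
result = next(r for r in per_worker if r is not None)  # what first_valid keeps
assert "booster" in result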
@@ -1012,7 +1012,7 @@ async def _train_async(
             evals_name = []
             evals_id = []
 
-        results = await map_worker_partitions(
+        result = await map_worker_partitions(
            client,
            dispatched_train,
            # extra function parameters
@@ -1025,7 +1025,7 @@ async def _train_async(
             # workers to be used for training
             workers=workers,
         )
-        return list(filter(lambda ret: ret is not None, results))[0]
+        return result
 
 
 @_deprecate_positional_args