[dask] Use distributed.MultiLock (#6743)
* [dask] Use `distributed.MultiLock` This enables training multiple models in parallel. * Conditionally import `MultiLock`. * Use async train directly in scikit learn interface. * Use `worker_client` when available.
This commit is contained in:
@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
|
||||
"""
|
||||
import platform
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from threading import Thread
|
||||
@@ -93,6 +94,34 @@ except ImportError:
|
||||
LOGGER = logging.getLogger('[xgboost.dask]')
|
||||
|
||||
|
||||
def _multi_lock() -> Any:
|
||||
"""MultiLock is only available on latest distributed. See:
|
||||
|
||||
https://github.com/dask/distributed/pull/4503
|
||||
|
||||
"""
|
||||
try:
|
||||
from distributed import MultiLock
|
||||
except ImportError:
|
||||
class MultiLock: # type:ignore
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "MultiLock":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
async def __aenter__(self) -> "MultiLock":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
return MultiLock
|
||||
|
||||
|
||||
def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
||||
"""Start Rabit tracker """
|
||||
env = {'DMLC_NUM_WORKER': n_workers}
|
||||
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
|
||||
def _get_workers_from_data(
|
||||
dtrain: DaskDMatrix,
|
||||
evals: Optional[List[Tuple[DaskDMatrix, str]]]
|
||||
) -> Set[str]:
|
||||
) -> List[str]:
|
||||
X_worker_map: Set[str] = set(dtrain.worker_map.keys())
|
||||
if evals:
|
||||
for e in evals:
|
||||
@@ -780,7 +809,7 @@ def _get_workers_from_data(
|
||||
continue
|
||||
worker_map = set(e[0].worker_map.keys())
|
||||
X_worker_map = X_worker_map.union(worker_map)
|
||||
return X_worker_map
|
||||
return list(X_worker_map)
|
||||
|
||||
|
||||
async def _train_async(
|
||||
@@ -795,9 +824,9 @@ async def _train_async(
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose_eval: Union[int, bool],
|
||||
xgb_model: Optional[Booster],
|
||||
callbacks: Optional[List[TrainingCallback]]
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
) -> Optional[TrainReturnT]:
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
workers = _get_workers_from_data(dtrain, evals)
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
|
||||
if params.get("booster", None) == "gblinear":
|
||||
@@ -858,29 +887,32 @@ async def _train_async(
|
||||
# XGBoost is deterministic in most of the cases, which means train function is
|
||||
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
|
||||
# We haven't been able to do a full verification so here we keep pure to be False.
|
||||
futures = []
|
||||
for i, worker_addr in enumerate(workers):
|
||||
if evals:
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
|
||||
for e, name in evals]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(workers[i]),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr]
|
||||
)
|
||||
futures.append(f)
|
||||
async with _multi_lock()(workers, client):
|
||||
futures = []
|
||||
for worker_addr in workers:
|
||||
if evals:
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [
|
||||
(e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
|
||||
]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(worker_addr),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr],
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
results = await client.gather(futures)
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
results = await client.gather(futures, asynchronous=True)
|
||||
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
|
||||
def train( # pylint: disable=unused-argument
|
||||
@@ -927,9 +959,8 @@ def train( # pylint: disable=unused-argument
|
||||
"""
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
# Get global configuration before transferring computation to another thread or
|
||||
# process.
|
||||
return client.sync(_train_async, global_config=config.get_config(), **locals())
|
||||
args = locals()
|
||||
return client.sync(_train_async, global_config=config.get_config(), **args)
|
||||
|
||||
|
||||
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
|
||||
@@ -1366,6 +1397,9 @@ def inplace_predict( # pylint: disable=unused-argument
|
||||
"""
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
# When used in asynchronous environment, the `client` object should have
|
||||
# `asynchronous` attribute as True. When invoked by the skl interface, it's
|
||||
# responsible for setting up the client.
|
||||
return client.sync(
|
||||
_inplace_predict_async, global_config=config.get_config(), **locals()
|
||||
)
|
||||
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
|
||||
return train_dmatrix, awaited
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_worker_client(
|
||||
model: "DaskScikitLearnBase", client: "distributed.Client"
|
||||
) -> Generator:
|
||||
"""Temporarily set the client for sklearn model."""
|
||||
try:
|
||||
model.client = client
|
||||
yield model
|
||||
finally:
|
||||
model.client = None
|
||||
|
||||
|
||||
class DaskScikitLearnBase(XGBModel):
|
||||
"""Base class for implementing scikit-learn interface with Dask"""
|
||||
|
||||
@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
async def _() -> Awaitable[Any]:
|
||||
return self
|
||||
|
||||
return self.client.sync(_).__await__()
|
||||
return self._client_sync(_).__await__()
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
this = self.__dict__.copy()
|
||||
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
|
||||
|
||||
@property
|
||||
def client(self) -> "distributed.Client":
|
||||
"""The dask client used in this model."""
|
||||
"""The dask client used in this model. The `Client` object can not be serialized for
|
||||
transmission, so if task is launched from a worker instead of directly from the
|
||||
client process, this attribute needs to be set at that worker.
|
||||
|
||||
"""
|
||||
|
||||
client = _xgb_get_client(self._client)
|
||||
return client
|
||||
|
||||
@client.setter
|
||||
def client(self, clt: "distributed.Client") -> None:
|
||||
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
|
||||
# we have to pass it ourselves.
|
||||
self._asynchronous = clt.asynchronous if clt is not None else False
|
||||
self._client = clt
|
||||
|
||||
def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
|
||||
"""Get the correct client, when method is invoked inside a worker we
|
||||
should use `worker_client' instead of default client.
|
||||
|
||||
"""
|
||||
asynchronous = getattr(self, "_asynchronous", False)
|
||||
if self._client is None:
|
||||
try:
|
||||
distributed.get_worker()
|
||||
in_worker = True
|
||||
except ValueError:
|
||||
in_worker = False
|
||||
if in_worker:
|
||||
with distributed.worker_client() as client:
|
||||
with _set_worker_client(self, client) as this:
|
||||
ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
|
||||
return ret
|
||||
return ret
|
||||
|
||||
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
|
||||
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
obj=obj,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks,
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results["booster"]
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results["history"]
|
||||
self._set_evaluation_result(results["history"])
|
||||
return self
|
||||
|
||||
# pylint: disable=missing-docstring, disable=unused-argument
|
||||
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results['booster']
|
||||
|
||||
if not callable(self.objective):
|
||||
self.objective = params["objective"]
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
self._set_evaluation_result(results["history"])
|
||||
return self
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
) -> "DaskXGBClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != 'self'}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
async def _predict_proba_async(
|
||||
self,
|
||||
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
_assert_dask_support()
|
||||
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(
|
||||
return self._client_sync(
|
||||
self._predict_proba_async,
|
||||
X=X,
|
||||
validate_features=validate_features,
|
||||
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
obj=None,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
) -> "DaskXGBRanker":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
||||
fit.__doc__ = XGBRanker.fit.__doc__
|
||||
|
||||
Reference in New Issue
Block a user