[dask] Use distributed.MultiLock (#6743)

* [dask] Use `distributed.MultiLock`

This enables training multiple models in parallel.

* Conditionally import `MultiLock`.
* Use async train directly in scikit learn interface.
* Use `worker_client` when available.
This commit is contained in:
Jiaming Yuan
2021-03-16 14:19:41 +08:00
committed by GitHub
parent 19a2c54265
commit 325bc93e16
4 changed files with 212 additions and 55 deletions

View File

@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
"""
import platform
import logging
from contextlib import contextmanager
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
@@ -93,6 +94,34 @@ except ImportError:
LOGGER = logging.getLogger('[xgboost.dask]')
def _multi_lock() -> Any:
"""MultiLock is only available on latest distributed. See:
https://github.com/dask/distributed/pull/4503
"""
try:
from distributed import MultiLock
except ImportError:
class MultiLock: # type:ignore
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
def __enter__(self) -> "MultiLock":
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
return
async def __aenter__(self) -> "MultiLock":
return self
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return
return MultiLock
def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
def _get_workers_from_data(
dtrain: DaskDMatrix,
evals: Optional[List[Tuple[DaskDMatrix, str]]]
) -> Set[str]:
) -> List[str]:
X_worker_map: Set[str] = set(dtrain.worker_map.keys())
if evals:
for e in evals:
@@ -780,7 +809,7 @@ def _get_workers_from_data(
continue
worker_map = set(e[0].worker_map.keys())
X_worker_map = X_worker_map.union(worker_map)
return X_worker_map
return list(X_worker_map)
async def _train_async(
@@ -795,9 +824,9 @@ async def _train_async(
early_stopping_rounds: Optional[int],
verbose_eval: Union[int, bool],
xgb_model: Optional[Booster],
callbacks: Optional[List[TrainingCallback]]
callbacks: Optional[List[TrainingCallback]],
) -> Optional[TrainReturnT]:
workers = list(_get_workers_from_data(dtrain, evals))
workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client)
if params.get("booster", None) == "gblinear":
@@ -858,29 +887,32 @@ async def _train_async(
# XGBoost is deterministic in most of the cases, which means train function is
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
# We haven't been able to do a full verification so here we keep pure to be False.
futures = []
for i, worker_addr in enumerate(workers):
if evals:
# pylint: disable=protected-access
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
for e, name in evals]
else:
evals_per_worker = []
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr]
)
futures.append(f)
async with _multi_lock()(workers, client):
futures = []
for worker_addr in workers:
if evals:
# pylint: disable=protected-access
evals_per_worker = [
(e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
]
else:
evals_per_worker = []
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(worker_addr),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr],
)
futures.append(f)
results = await client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
results = await client.gather(futures, asynchronous=True)
return list(filter(lambda ret: ret is not None, results))[0]
def train( # pylint: disable=unused-argument
@@ -927,9 +959,8 @@ def train( # pylint: disable=unused-argument
"""
_assert_dask_support()
client = _xgb_get_client(client)
# Get global configuration before transferring computation to another thread or
# process.
return client.sync(_train_async, global_config=config.get_config(), **locals())
args = locals()
return client.sync(_train_async, global_config=config.get_config(), **args)
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1366,6 +1397,9 @@ def inplace_predict( # pylint: disable=unused-argument
"""
_assert_dask_support()
client = _xgb_get_client(client)
# When used in asynchronous environment, the `client` object should have
# `asynchronous` attribute as True. When invoked by the skl interface, it's
# responsible for setting up the client.
return client.sync(
_inplace_predict_async, global_config=config.get_config(), **locals()
)
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
return train_dmatrix, awaited
@contextmanager
def _set_worker_client(
model: "DaskScikitLearnBase", client: "distributed.Client"
) -> Generator:
"""Temporarily set the client for sklearn model."""
try:
model.client = client
yield model
finally:
model.client = None
class DaskScikitLearnBase(XGBModel):
"""Base class for implementing scikit-learn interface with Dask"""
@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
async def _() -> Awaitable[Any]:
return self
return self.client.sync(_).__await__()
return self._client_sync(_).__await__()
def __getstate__(self) -> Dict:
this = self.__dict__.copy()
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
@property
def client(self) -> "distributed.Client":
"""The dask client used in this model."""
"""The dask client used in this model. The `Client` object can not be serialized for
transmission, so if task is launched from a worker instead of directly from the
client process, this attribute needs to be set at that worker.
"""
client = _xgb_get_client(self._client)
return client
@client.setter
def client(self, clt: "distributed.Client") -> None:
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
# we have to pass it ourselves.
self._asynchronous = clt.asynchronous if clt is not None else False
self._client = clt
def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
"""Get the correct client, when method is invoked inside a worker we
should use `worker_client' instead of default client.
"""
asynchronous = getattr(self, "_asynchronous", False)
if self._client is None:
try:
distributed.get_worker()
in_worker = True
except ValueError:
in_worker = False
if in_worker:
with distributed.worker_client() as client:
with _set_worker_client(self, client) as this:
ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
return ret
return ret
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results["history"]
self._set_evaluation_result(results["history"])
return self
# pylint: disable=missing-docstring, disable=unused-argument
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
) -> "DaskXGBRegressor":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
@xgboost_model_doc(
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
xgb_model=model,
)
self._Booster = results['booster']
if not callable(self.objective):
self.objective = params["objective"]
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
self._set_evaluation_result(results["history"])
return self
# pylint: disable=unused-argument
@_deprecate_positional_args
def fit(
self,
X: _DaskCollection,
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
) -> "DaskXGBClassifier":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != 'self'}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
async def _predict_proba_async(
self,
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
_assert_dask_support()
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(
return self._client_sync(
self._predict_proba_async,
X=X,
validate_features=validate_features,
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=None,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
) -> "DaskXGBRanker":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
fit.__doc__ = XGBRanker.fit.__doc__