[dask] Use distributed.MultiLock (#6743)

* [dask] Use `distributed.MultiLock`

This enables training multiple models in parallel; a usage sketch follows below.

* Conditionally import `MultiLock`.
* Use async train directly in the scikit-learn interface.
* Use `worker_client` when available.
Jiaming Yuan 2021-03-16 14:19:41 +08:00 committed by GitHub
parent 19a2c54265
commit 325bc93e16
4 changed files with 212 additions and 55 deletions
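
Concretely, the lock makes it safe to submit several `fit` calls from one client at the same time. The snippet below is a usage sketch only, not part of the diff: it is modelled on the `test_parallel_submits` test added in this commit and assumes xgboost built with the Dask interface plus a `distributed` release that provides `MultiLock`.

    import dask.dataframe as dd
    from distributed import Client, LocalCluster
    from sklearn.datasets import load_digits
    from xgboost import dask as dxgb

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            X_, y_ = load_digits(return_X_y=True)
            X = dd.from_array(X_, chunksize=32)
            y = dd.from_array(y_, chunksize=32)
            futures = []
            for n_rounds in (1, 2):
                clf = dxgb.DaskXGBClassifier(
                    n_estimators=n_rounds, eval_metric="merror", use_label_encoder=False
                )
                # Each fit runs as a task on a worker; the MultiLock taken inside
                # _train_async keeps the concurrent Rabit rings from starting on
                # overlapping workers.
                futures.append(client.submit(clf.fit, X, y, pure=False))
            models = client.gather(futures)
            assert [m.get_booster().num_boosted_rounds() for m in models] == [1, 2]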

View File

@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
"""
import platform
import logging
from contextlib import contextmanager
from collections import defaultdict
from collections.abc import Sequence
from threading import Thread
@@ -93,6 +94,34 @@ except ImportError:
LOGGER = logging.getLogger('[xgboost.dask]')
def _multi_lock() -> Any:
"""MultiLock is only available on latest distributed. See:
https://github.com/dask/distributed/pull/4503
"""
try:
from distributed import MultiLock
except ImportError:
class MultiLock: # type:ignore
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
def __enter__(self) -> "MultiLock":
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
return
async def __aenter__(self) -> "MultiLock":
return self
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return
return MultiLock
def _start_tracker(n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker """
env = {'DMLC_NUM_WORKER': n_workers}
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
def _get_workers_from_data(
dtrain: DaskDMatrix,
evals: Optional[List[Tuple[DaskDMatrix, str]]]
) -> Set[str]:
) -> List[str]:
X_worker_map: Set[str] = set(dtrain.worker_map.keys())
if evals:
for e in evals:
@@ -780,7 +809,7 @@ def _get_workers_from_data(
continue
worker_map = set(e[0].worker_map.keys())
X_worker_map = X_worker_map.union(worker_map)
return X_worker_map
return list(X_worker_map)
async def _train_async(
@@ -795,9 +824,9 @@ async def _train_async(
early_stopping_rounds: Optional[int],
verbose_eval: Union[int, bool],
xgb_model: Optional[Booster],
callbacks: Optional[List[TrainingCallback]]
callbacks: Optional[List[TrainingCallback]],
) -> Optional[TrainReturnT]:
workers = list(_get_workers_from_data(dtrain, evals))
workers = _get_workers_from_data(dtrain, evals)
_rabit_args = await _get_rabit_args(len(workers), client)
if params.get("booster", None) == "gblinear":
@@ -858,29 +887,32 @@ async def _train_async(
# XGBoost is deterministic in most of the cases, which means train function is
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
# We haven't been able to do a full verification so here we keep pure to be False.
futures = []
for i, worker_addr in enumerate(workers):
if evals:
# pylint: disable=protected-access
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
for e, name in evals]
else:
evals_per_worker = []
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(workers[i]),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr]
)
futures.append(f)
async with _multi_lock()(workers, client):
futures = []
for worker_addr in workers:
if evals:
# pylint: disable=protected-access
evals_per_worker = [
(e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
]
else:
evals_per_worker = []
f = client.submit(
dispatched_train,
worker_addr,
_rabit_args,
# pylint: disable=protected-access
dtrain._create_fn_args(worker_addr),
id(dtrain),
evals_per_worker,
pure=False,
workers=[worker_addr],
)
futures.append(f)
results = await client.gather(futures)
return list(filter(lambda ret: ret is not None, results))[0]
results = await client.gather(futures, asynchronous=True)
return list(filter(lambda ret: ret is not None, results))[0]
def train( # pylint: disable=unused-argument
@@ -927,9 +959,8 @@ def train( # pylint: disable=unused-argument
"""
_assert_dask_support()
client = _xgb_get_client(client)
# Get global configuration before transferring computation to another thread or
# process.
return client.sync(_train_async, global_config=config.get_config(), **locals())
args = locals()
return client.sync(_train_async, global_config=config.get_config(), **args)
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1366,6 +1397,9 @@ def inplace_predict( # pylint: disable=unused-argument
"""
_assert_dask_support()
client = _xgb_get_client(client)
# When used in an asynchronous environment, the `client` object should have its
# `asynchronous` attribute set to True. When invoked by the skl interface, the skl
# interface is responsible for setting up the client.
return client.sync(
_inplace_predict_async, global_config=config.get_config(), **locals()
)
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
return train_dmatrix, awaited
@contextmanager
def _set_worker_client(
model: "DaskScikitLearnBase", client: "distributed.Client"
) -> Generator:
"""Temporarily set the client for sklearn model."""
try:
model.client = client
yield model
finally:
model.client = None
class DaskScikitLearnBase(XGBModel):
"""Base class for implementing scikit-learn interface with Dask"""
@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
async def _() -> Awaitable[Any]:
return self
return self.client.sync(_).__await__()
return self._client_sync(_).__await__()
def __getstate__(self) -> Dict:
this = self.__dict__.copy()
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
@property
def client(self) -> "distributed.Client":
"""The dask client used in this model."""
"""The dask client used in this model. The `Client` object can not be serialized for
transmission, so if task is launched from a worker instead of directly from the
client process, this attribute needs to be set at that worker.
"""
client = _xgb_get_client(self._client)
return client
@client.setter
def client(self, clt: "distributed.Client") -> None:
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
# we have to pass it ourselves.
self._asynchronous = clt.asynchronous if clt is not None else False
self._client = clt
def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
"""Get the correct client, when method is invoked inside a worker we
should use `worker_client' instead of default client.
"""
asynchronous = getattr(self, "_asynchronous", False)
if self._client is None:
try:
distributed.get_worker()
in_worker = True
except ValueError:
in_worker = False
if in_worker:
with distributed.worker_client() as client:
with _set_worker_client(self, client) as this:
ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
return ret
return ret
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
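
For context, `_client_sync` above falls back to `distributed.worker_client()` when the model is used inside a worker task. The sketch below is not part of the diff, and `run_inside_worker` is a hypothetical helper; it only illustrates the underlying `worker_client` pattern that makes it safe to submit `fit`/`predict` themselves as tasks.

    import distributed
    from distributed import Client, LocalCluster

    def run_inside_worker() -> int:
        # Executed as a task on a worker.  worker_client() secedes from the
        # worker's thread pool and returns a client connected to the same
        # scheduler, so the task can submit further work without deadlocking.
        with distributed.worker_client() as client:
            return client.submit(sum, [1, 2, 3]).result()

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            assert client.submit(run_inside_worker, pure=False).result() == 6
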
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
feval=metric,
obj=obj,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
callbacks=callbacks,
xgb_model=model,
)
self._Booster = results["booster"]
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results["history"]
self._set_evaluation_result(results["history"])
return self
# pylint: disable=missing-docstring, disable=unused-argument
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
) -> "DaskXGBRegressor":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
@xgboost_model_doc(
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
xgb_model=model,
)
self._Booster = results['booster']
if not callable(self.objective):
self.objective = params["objective"]
# pylint: disable=attribute-defined-outside-init
self.evals_result_ = results['history']
self._set_evaluation_result(results["history"])
return self
# pylint: disable=unused-argument
@_deprecate_positional_args
def fit(
self,
X: _DaskCollection,
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
) -> "DaskXGBClassifier":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != 'self'}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
async def _predict_proba_async(
self,
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
_assert_dask_support()
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
assert ntree_limit is None, msg
return self.client.sync(
return self._client_sync(
self._predict_proba_async,
X=X,
validate_features=validate_features,
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
model, metric, params = self._configure_fit(
booster=xgb_model, eval_metric=eval_metric, params=params
)
results = await train(
results = await self.client.sync(
_train_async,
asynchronous=True,
client=self.client,
global_config=config.get_config(),
params=params,
dtrain=dtrain,
num_boost_round=self.get_num_boosting_rounds(),
evals=evals,
obj=None,
feval=metric,
verbose_eval=verbose,
early_stopping_rounds=early_stopping_rounds,
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
) -> "DaskXGBRanker":
_assert_dask_support()
args = {k: v for k, v in locals().items() if k != "self"}
return self.client.sync(self._fit_async, **args)
return self._client_sync(self._fit_async, **args)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
fit.__doc__ = XGBRanker.fit.__doc__
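
To recap the locking scheme added in this file: `_train_async` now wraps its per-worker `client.submit` loop in `async with _multi_lock()(workers, client)`. The example below is a sketch only, not part of the diff; it assumes a `distributed` release that ships dask/distributed#4503, and the two placeholder lock names stand in for the worker addresses returned by `_get_workers_from_data`.

    import asyncio
    from distributed import Client, LocalCluster, MultiLock

    async def main() -> None:
        async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                # A second job that needs any of these workers waits here until
                # all of its own locks are acquired, so two Rabit rings never
                # start on overlapping workers at the same time.
                async with MultiLock(["tcp://worker-0", "tcp://worker-1"], client):
                    await asyncio.sleep(0)  # per-worker training would be dispatched here

    if __name__ == "__main__":
        asyncio.run(main())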

View File

@@ -293,7 +293,7 @@ class TestDistributedGPU:
fw = fw - fw.min()
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
@@ -384,7 +384,7 @@ class TestDistributedGPU:
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
with Client(local_cuda_cluster) as client:
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
futures = client.map(runit,
workers,

View File

@@ -27,7 +27,7 @@ def run_rabit_ops(client, n_workers):
from xgboost.dask import RabitContext, _get_rabit_args
from xgboost import rabit
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
rabit_args = client.sync(_get_rabit_args, len(workers), client)
assert not rabit.is_distributed()
n_workers_from_dask = len(workers)

View File

@@ -9,6 +9,7 @@ import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor
import tempfile
from sklearn.datasets import make_classification
import sklearn
@@ -43,9 +44,9 @@ kCols = 10
kWorkers = 5
def _get_client_workers(client: "Client") -> Dict[str, Dict]:
def _get_client_workers(client: "Client") -> List[str]:
workers = client.scheduler_info()['workers']
return workers
return list(workers.keys())
def generate_array(
@@ -646,7 +647,7 @@ def test_with_asyncio() -> None:
async def generate_concurrent_trainings() -> None:
async def train():
async def train() -> None:
async with LocalCluster(n_workers=2,
threads_per_worker=1,
asynchronous=True,
@@ -967,7 +968,7 @@ class TestWithDask:
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), client)
futures = client.map(runit,
@@ -1000,7 +1001,7 @@ class TestWithDask:
def test_n_workers(self) -> None:
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client:
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
@@ -1090,7 +1091,7 @@ class TestWithDask:
X, y, _ = generate_array()
n_partitions = X.npartitions
m = xgb.dask.DaskDMatrix(client, X, y)
workers = list(_get_client_workers(client).keys())
workers = _get_client_workers(client)
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
n_workers = len(workers)
@@ -1285,6 +1286,82 @@ def test_dask_unsupported_features(client: "Client") -> None:
)
def test_parallel_submits(client: "Client") -> None:
"""Test for running multiple train simultaneously from single clients."""
try:
from distributed import MultiLock # NOQA
except ImportError:
pytest.skip("`distributed.MultiLock' is not available")
from sklearn.datasets import load_digits
futures = []
workers = _get_client_workers(client)
n_submits = len(workers)
for i in range(n_submits):
X_, y_ = load_digits(return_X_y=True)
X = dd.from_array(X_, chunksize=32)
y = dd.from_array(y_, chunksize=32)
cls = xgb.dask.DaskXGBClassifier(
verbosity=1,
n_estimators=i + 1,
eval_metric="merror",
use_label_encoder=False,
)
f = client.submit(cls.fit, X, y, pure=False)
futures.append(f)
classifiers = client.gather(futures)
assert len(classifiers) == n_submits
for i, cls in enumerate(classifiers):
assert cls.get_booster().num_boosted_rounds() == i + 1
def test_parallel_submit_multi_clients() -> None:
"""Test for running multiple train simultaneously from multiple clients."""
try:
from distributed import MultiLock # NOQA
except ImportError:
pytest.skip("`distributed.MultiLock' is not available")
from sklearn.datasets import load_digits
with LocalCluster(n_workers=4) as cluster:
with Client(cluster) as client:
workers = _get_client_workers(client)
n_submits = len(workers)
assert n_submits == 4
futures = []
for i in range(n_submits):
client = Client(cluster)
X_, y_ = load_digits(return_X_y=True)
X_ += 1.0
X = dd.from_array(X_, chunksize=32)
y = dd.from_array(y_, chunksize=32)
cls = xgb.dask.DaskXGBClassifier(
verbosity=1,
n_estimators=i + 1,
eval_metric="merror",
use_label_encoder=False,
)
f = client.submit(cls.fit, X, y, pure=False)
futures.append((client, f))
t_futures = []
with ThreadPoolExecutor(max_workers=16) as e:
for i in range(n_submits):
def _() -> xgb.dask.DaskXGBClassifier:
return futures[i][0].compute(futures[i][1]).result()
f = e.submit(_)
t_futures.append(f)
for i, f in enumerate(t_futures):
assert f.result().get_booster().num_boosted_rounds() == i + 1
class TestDaskCallbacks:
@pytest.mark.skipif(**tm.no_sklearn())
def test_early_stopping(self, client: "Client") -> None: