[dask] Use distributed.MultiLock (#6743)

* [dask] Use `distributed.MultiLock`

This enables training multiple models in parallel; a usage sketch follows the list below.

* Conditionally import `MultiLock`.
* Use the async train function directly in the scikit-learn interface.
* Use `worker_client` when available.
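
For illustration, a condensed sketch of the usage this enables, adapted from the new `test_parallel_submits` test below (the `client`, `X`, and `y` objects are assumed to already exist):

    import xgboost as xgb

    # Submit several fits at once.  The MultiLock taken inside _train_async
    # makes each job acquire all of its workers before dispatching, so
    # concurrent jobs cannot deadlock by each holding only part of the cluster.
    futures = []
    for i in range(4):
        cls = xgb.dask.DaskXGBClassifier(n_estimators=i + 1)
        futures.append(client.submit(cls.fit, X, y, pure=False))
    models = client.gather(futures)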
Jiaming Yuan 2021-03-16 14:19:41 +08:00 committed by GitHub
parent 19a2c54265
commit 325bc93e16
4 changed files with 212 additions and 55 deletions

View File

@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
 """
 import platform
 import logging
+from contextlib import contextmanager
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
@@ -93,6 +94,34 @@ except ImportError:
 LOGGER = logging.getLogger('[xgboost.dask]')


+def _multi_lock() -> Any:
+    """MultiLock is only available on latest distributed.  See:
+
+    https://github.com/dask/distributed/pull/4503
+    """
+    try:
+        from distributed import MultiLock
+    except ImportError:
+        class MultiLock:  # type:ignore
+            def __init__(self, *args: Any, **kwargs: Any) -> None:
+                pass
+
+            def __enter__(self) -> "MultiLock":
+                return self
+
+            def __exit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+            async def __aenter__(self) -> "MultiLock":
+                return self
+
+            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+    return MultiLock
+
+
 def _start_tracker(n_workers: int) -> Dict[str, Any]:
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
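
For reference, a hedged sketch of the `distributed.MultiLock` semantics relied on here (requires a distributed release containing the pull request linked above and an active default `Client`; the worker addresses are hypothetical):

    from distributed import MultiLock

    with MultiLock(names=["tcp://10.0.0.1:40331", "tcp://10.0.0.2:40331"]):
        # All named locks are acquired together.  Another MultiLock over an
        # overlapping set of names blocks until this block exits, which is
        # what keeps concurrent training jobs from deadlocking on shared
        # workers.
        ...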
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
 def _get_workers_from_data(
     dtrain: DaskDMatrix,
     evals: Optional[List[Tuple[DaskDMatrix, str]]]
-) -> Set[str]:
+) -> List[str]:
     X_worker_map: Set[str] = set(dtrain.worker_map.keys())
     if evals:
         for e in evals:
@@ -780,7 +809,7 @@ def _get_workers_from_data(
                 continue
             worker_map = set(e[0].worker_map.keys())
             X_worker_map = X_worker_map.union(worker_map)
-    return X_worker_map
+    return list(X_worker_map)


 async def _train_async(
@@ -795,9 +824,9 @@ async def _train_async(
     early_stopping_rounds: Optional[int],
     verbose_eval: Union[int, bool],
     xgb_model: Optional[Booster],
-    callbacks: Optional[List[TrainingCallback]]
+    callbacks: Optional[List[TrainingCallback]],
 ) -> Optional[TrainReturnT]:
-    workers = list(_get_workers_from_data(dtrain, evals))
+    workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), client)

     if params.get("booster", None) == "gblinear":
@@ -858,29 +887,32 @@ async def _train_async(
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent.  One known exception is gblinear with shotgun updater.
    # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = []
-    for i, worker_addr in enumerate(workers):
-        if evals:
-            # pylint: disable=protected-access
-            evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
-                                for e, name in evals]
-        else:
-            evals_per_worker = []
-        f = client.submit(
-            dispatched_train,
-            worker_addr,
-            _rabit_args,
-            # pylint: disable=protected-access
-            dtrain._create_fn_args(workers[i]),
-            id(dtrain),
-            evals_per_worker,
-            pure=False,
-            workers=[worker_addr]
-        )
-        futures.append(f)
-
-    results = await client.gather(futures)
-    return list(filter(lambda ret: ret is not None, results))[0]
+    async with _multi_lock()(workers, client):
+        futures = []
+        for worker_addr in workers:
+            if evals:
+                # pylint: disable=protected-access
+                evals_per_worker = [
+                    (e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
+                ]
+            else:
+                evals_per_worker = []
+            f = client.submit(
+                dispatched_train,
+                worker_addr,
+                _rabit_args,
+                # pylint: disable=protected-access
+                dtrain._create_fn_args(worker_addr),
+                id(dtrain),
+                evals_per_worker,
+                pure=False,
+                workers=[worker_addr],
+            )
+            futures.append(f)
+
+        results = await client.gather(futures, asynchronous=True)
+        return list(filter(lambda ret: ret is not None, results))[0]


 def train(  # pylint: disable=unused-argument
@@ -927,9 +959,8 @@ def train(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
     client = _xgb_get_client(client)
-    # Get global configuration before transferring computation to another thread or
-    # process.
-    return client.sync(_train_async, global_config=config.get_config(), **locals())
+    args = locals()
+    return client.sync(_train_async, global_config=config.get_config(), **args)


 def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1366,6 +1397,9 @@ def inplace_predict(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
     client = _xgb_get_client(client)
+    # When used in asynchronous environment, the `client` object should have
+    # `asynchronous` attribute as True.  When invoked by the skl interface, it's
+    # responsible for setting up the client.
     return client.sync(
         _inplace_predict_async, global_config=config.get_config(), **locals()
     )
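
As a hedged illustration of the asynchronous path described in the new comment (a sketch only; `X` and `y` are assumed to be existing dask collections):

    import xgboost as xgb
    from distributed import Client, LocalCluster

    async def main() -> None:
        async with LocalCluster(n_workers=2, asynchronous=True) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                # With an asynchronous client, the functional interface
                # returns awaitables instead of blocking.
                m = await xgb.dask.DaskDMatrix(client, X, y)
                output = await xgb.dask.train(client, {}, m)
                predt = await xgb.dask.inplace_predict(client, output, X)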
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
     return train_dmatrix, awaited


+@contextmanager
+def _set_worker_client(
+    model: "DaskScikitLearnBase", client: "distributed.Client"
+) -> Generator:
+    """Temporarily set the client for sklearn model."""
+    try:
+        model.client = client
+        yield model
+    finally:
+        model.client = None
+
+
 class DaskScikitLearnBase(XGBModel):
     """Base class for implementing scikit-learn interface with Dask"""
@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
         async def _() -> Awaitable[Any]:
             return self

-        return self.client.sync(_).__await__()
+        return self._client_sync(_).__await__()

     def __getstate__(self) -> Dict:
         this = self.__dict__.copy()
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
     @property
     def client(self) -> "distributed.Client":
-        """The dask client used in this model."""
+        """The dask client used in this model.  The `Client` object can not be
+        serialized for transmission, so if the task is launched from a worker
+        instead of directly from the client process, this attribute needs to be
+        set at that worker.
+        """
         client = _xgb_get_client(self._client)
         return client

     @client.setter
     def client(self, clt: "distributed.Client") -> None:
+        # Calling `worker_client' doesn't return the correct `asynchronous`
+        # attribute, so we have to pass it ourselves.
+        self._asynchronous = clt.asynchronous if clt is not None else False
         self._client = clt

+    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
+        """Get the correct client.  When a method is invoked inside a worker, we
+        should use `worker_client' instead of the default client.
+        """
+        asynchronous = getattr(self, "_asynchronous", False)
+
+        if self._client is None:
+            try:
+                distributed.get_worker()
+                in_worker = True
+            except ValueError:
+                in_worker = False
+            if in_worker:
+                with distributed.worker_client() as client:
+                    with _set_worker_client(self, client) as this:
+                        ret = this.client.sync(
+                            func, **kwargs, asynchronous=asynchronous
+                        )
+                        return ret
+
+        return self.client.sync(func, **kwargs, asynchronous=asynchronous)
+

 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
-            feval=metric,
             obj=obj,
+            feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
             callbacks=callbacks,
             xgb_model=model,
         )
         self._Booster = results["booster"]
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results["history"]
+        self._set_evaluation_result(results["history"])
         return self

     # pylint: disable=missing-docstring, disable=unused-argument
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)


 @xgboost_model_doc(
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             xgb_model=model,
         )
         self._Booster = results['booster']
-
         if not callable(self.objective):
             self.objective = params["objective"]
-
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results['history']
+        self._set_evaluation_result(results["history"])
         return self

     # pylint: disable=unused-argument
-    @_deprecate_positional_args
     def fit(
         self,
         X: _DaskCollection,
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != 'self'}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)

     async def _predict_proba_async(
         self,
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         _assert_dask_support()
         msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
         assert ntree_limit is None, msg
-        return self.client.sync(
+        return self._client_sync(
             self._predict_proba_async,
             X=X,
             validate_features=validate_features,
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
+            obj=None,
             feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     ) -> "DaskXGBRanker":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)

     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__

View File

@@ -293,7 +293,7 @@ class TestDistributedGPU:
             fw = fw - fw.min()
             m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)

-            workers = list(_get_client_workers(client).keys())
+            workers = _get_client_workers(client)
             rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)

             def worker_fn(worker_addr: str, data_ref: Dict) -> None:
@@ -384,7 +384,7 @@ class TestDistributedGPU:
             return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)

         with Client(local_cuda_cluster) as client:
-            workers = list(_get_client_workers(client).keys())
+            workers = _get_client_workers(client)
             rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
             futures = client.map(runit,
                                  workers,

View File

@@ -27,7 +27,7 @@ def run_rabit_ops(client, n_workers):
     from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit

-    workers = list(_get_client_workers(client).keys())
+    workers = _get_client_workers(client)
     rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)

View File

@@ -9,6 +9,7 @@ import scipy
 import json
 from typing import List, Tuple, Dict, Optional, Type, Any
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
 import tempfile
 from sklearn.datasets import make_classification
 import sklearn
@@ -43,9 +44,9 @@ kCols = 10
 kWorkers = 5


-def _get_client_workers(client: "Client") -> Dict[str, Dict]:
+def _get_client_workers(client: "Client") -> List[str]:
     workers = client.scheduler_info()['workers']
-    return workers
+    return list(workers.keys())


 def generate_array(
@@ -646,7 +647,7 @@ def test_with_asyncio() -> None:
 async def generate_concurrent_trainings() -> None:
-    async def train():
+    async def train() -> None:
         async with LocalCluster(n_workers=2,
                                 threads_per_worker=1,
                                 asynchronous=True,
@@ -967,7 +968,7 @@ class TestWithDask:
         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 rabit_args = client.sync(
                     xgb.dask._get_rabit_args, len(workers), client)
                 futures = client.map(runit,
@@ -1000,7 +1001,7 @@ class TestWithDask:
     def test_n_workers(self) -> None:
         with LocalCluster(n_workers=2) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 from sklearn.datasets import load_breast_cancer
                 X, y = load_breast_cancer(return_X_y=True)
                 dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
@@ -1090,7 +1091,7 @@ class TestWithDask:
         X, y, _ = generate_array()
         n_partitions = X.npartitions
         m = xgb.dask.DaskDMatrix(client, X, y)
-        workers = list(_get_client_workers(client).keys())
+        workers = _get_client_workers(client)
         rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
         n_workers = len(workers)
@@ -1285,6 +1286,82 @@ def test_dask_unsupported_features(client: "Client") -> None:
     )


+def test_parallel_submits(client: "Client") -> None:
+    """Test running multiple training jobs simultaneously from a single client."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    futures = []
+    workers = _get_client_workers(client)
+    n_submits = len(workers)
+    for i in range(n_submits):
+        X_, y_ = load_digits(return_X_y=True)
+        X = dd.from_array(X_, chunksize=32)
+        y = dd.from_array(y_, chunksize=32)
+        cls = xgb.dask.DaskXGBClassifier(
+            verbosity=1,
+            n_estimators=i + 1,
+            eval_metric="merror",
+            use_label_encoder=False,
+        )
+        f = client.submit(cls.fit, X, y, pure=False)
+        futures.append(f)
+
+    classifiers = client.gather(futures)
+    assert len(classifiers) == n_submits
+    for i, cls in enumerate(classifiers):
+        assert cls.get_booster().num_boosted_rounds() == i + 1
+
+
+def test_parallel_submit_multi_clients() -> None:
+    """Test running multiple training jobs simultaneously from multiple clients."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    with LocalCluster(n_workers=4) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+            n_submits = len(workers)
+            assert n_submits == 4
+
+            futures = []
+            for i in range(n_submits):
+                client = Client(cluster)
+                X_, y_ = load_digits(return_X_y=True)
+                X_ += 1.0
+                X = dd.from_array(X_, chunksize=32)
+                y = dd.from_array(y_, chunksize=32)
+                cls = xgb.dask.DaskXGBClassifier(
+                    verbosity=1,
+                    n_estimators=i + 1,
+                    eval_metric="merror",
+                    use_label_encoder=False,
+                )
+                f = client.submit(cls.fit, X, y, pure=False)
+                futures.append((client, f))
+
+            t_futures = []
+            with ThreadPoolExecutor(max_workers=16) as e:
+                for i in range(n_submits):
+                    # Bind `i` by value to avoid the late-binding closure
+                    # pitfall when the thread runs after the loop advances.
+                    def _(i: int = i) -> xgb.dask.DaskXGBClassifier:
+                        return futures[i][0].compute(futures[i][1]).result()
+
+                    f = e.submit(_)
+                    t_futures.append(f)
+
+            for i, f in enumerate(t_futures):
+                assert f.result().get_booster().num_boosted_rounds() == i + 1
+
+
 class TestDaskCallbacks:
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_early_stopping(self, client: "Client") -> None: