[dask] Use distributed.MultiLock (#6743)
* [dask] Use `distributed.MultiLock` This enables training multiple models in parallel. * Conditionally import `MultiLock`. * Use async train directly in scikit learn interface. * Use `worker_client` when available.
This commit is contained in:
parent
19a2c54265
commit
325bc93e16
@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
|
||||
"""
|
||||
import platform
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from threading import Thread
|
||||
@ -93,6 +94,34 @@ except ImportError:
|
||||
LOGGER = logging.getLogger('[xgboost.dask]')
|
||||
|
||||
|
||||
def _multi_lock() -> Any:
|
||||
"""MultiLock is only available on latest distributed. See:
|
||||
|
||||
https://github.com/dask/distributed/pull/4503
|
||||
|
||||
"""
|
||||
try:
|
||||
from distributed import MultiLock
|
||||
except ImportError:
|
||||
class MultiLock: # type:ignore
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
def __enter__(self) -> "MultiLock":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
async def __aenter__(self) -> "MultiLock":
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
return MultiLock
|
||||
|
||||
|
||||
def _start_tracker(n_workers: int) -> Dict[str, Any]:
|
||||
"""Start Rabit tracker """
|
||||
env = {'DMLC_NUM_WORKER': n_workers}
|
||||
@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
|
||||
def _get_workers_from_data(
|
||||
dtrain: DaskDMatrix,
|
||||
evals: Optional[List[Tuple[DaskDMatrix, str]]]
|
||||
) -> Set[str]:
|
||||
) -> List[str]:
|
||||
X_worker_map: Set[str] = set(dtrain.worker_map.keys())
|
||||
if evals:
|
||||
for e in evals:
|
||||
@ -780,7 +809,7 @@ def _get_workers_from_data(
|
||||
continue
|
||||
worker_map = set(e[0].worker_map.keys())
|
||||
X_worker_map = X_worker_map.union(worker_map)
|
||||
return X_worker_map
|
||||
return list(X_worker_map)
|
||||
|
||||
|
||||
async def _train_async(
|
||||
@ -795,9 +824,9 @@ async def _train_async(
|
||||
early_stopping_rounds: Optional[int],
|
||||
verbose_eval: Union[int, bool],
|
||||
xgb_model: Optional[Booster],
|
||||
callbacks: Optional[List[TrainingCallback]]
|
||||
callbacks: Optional[List[TrainingCallback]],
|
||||
) -> Optional[TrainReturnT]:
|
||||
workers = list(_get_workers_from_data(dtrain, evals))
|
||||
workers = _get_workers_from_data(dtrain, evals)
|
||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||
|
||||
if params.get("booster", None) == "gblinear":
|
||||
@ -858,29 +887,32 @@ async def _train_async(
|
||||
# XGBoost is deterministic in most of the cases, which means train function is
|
||||
# supposed to be idempotent. One known exception is gblinear with shotgun updater.
|
||||
# We haven't been able to do a full verification so here we keep pure to be False.
|
||||
futures = []
|
||||
for i, worker_addr in enumerate(workers):
|
||||
if evals:
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
|
||||
for e, name in evals]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(workers[i]),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr]
|
||||
)
|
||||
futures.append(f)
|
||||
async with _multi_lock()(workers, client):
|
||||
futures = []
|
||||
for worker_addr in workers:
|
||||
if evals:
|
||||
# pylint: disable=protected-access
|
||||
evals_per_worker = [
|
||||
(e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
|
||||
]
|
||||
else:
|
||||
evals_per_worker = []
|
||||
f = client.submit(
|
||||
dispatched_train,
|
||||
worker_addr,
|
||||
_rabit_args,
|
||||
# pylint: disable=protected-access
|
||||
dtrain._create_fn_args(worker_addr),
|
||||
id(dtrain),
|
||||
evals_per_worker,
|
||||
pure=False,
|
||||
workers=[worker_addr],
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
results = await client.gather(futures)
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
results = await client.gather(futures, asynchronous=True)
|
||||
|
||||
return list(filter(lambda ret: ret is not None, results))[0]
|
||||
|
||||
|
||||
def train( # pylint: disable=unused-argument
|
||||
@ -927,9 +959,8 @@ def train( # pylint: disable=unused-argument
|
||||
"""
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
# Get global configuration before transferring computation to another thread or
|
||||
# process.
|
||||
return client.sync(_train_async, global_config=config.get_config(), **locals())
|
||||
args = locals()
|
||||
return client.sync(_train_async, global_config=config.get_config(), **args)
|
||||
|
||||
|
||||
def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
|
||||
@ -1366,6 +1397,9 @@ def inplace_predict( # pylint: disable=unused-argument
|
||||
"""
|
||||
_assert_dask_support()
|
||||
client = _xgb_get_client(client)
|
||||
# When used in asynchronous environment, the `client` object should have
|
||||
# `asynchronous` attribute as True. When invoked by the skl interface, it's
|
||||
# responsible for setting up the client.
|
||||
return client.sync(
|
||||
_inplace_predict_async, global_config=config.get_config(), **locals()
|
||||
)
|
||||
@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
|
||||
return train_dmatrix, awaited
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _set_worker_client(
|
||||
model: "DaskScikitLearnBase", client: "distributed.Client"
|
||||
) -> Generator:
|
||||
"""Temporarily set the client for sklearn model."""
|
||||
try:
|
||||
model.client = client
|
||||
yield model
|
||||
finally:
|
||||
model.client = None
|
||||
|
||||
|
||||
class DaskScikitLearnBase(XGBModel):
|
||||
"""Base class for implementing scikit-learn interface with Dask"""
|
||||
|
||||
@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
async def _() -> Awaitable[Any]:
|
||||
return self
|
||||
|
||||
return self.client.sync(_).__await__()
|
||||
return self._client_sync(_).__await__()
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
this = self.__dict__.copy()
|
||||
@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
|
||||
|
||||
@property
|
||||
def client(self) -> "distributed.Client":
|
||||
"""The dask client used in this model."""
|
||||
"""The dask client used in this model. The `Client` object can not be serialized for
|
||||
transmission, so if task is launched from a worker instead of directly from the
|
||||
client process, this attribute needs to be set at that worker.
|
||||
|
||||
"""
|
||||
|
||||
client = _xgb_get_client(self._client)
|
||||
return client
|
||||
|
||||
@client.setter
|
||||
def client(self, clt: "distributed.Client") -> None:
|
||||
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
|
||||
# we have to pass it ourselves.
|
||||
self._asynchronous = clt.asynchronous if clt is not None else False
|
||||
self._client = clt
|
||||
|
||||
def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
|
||||
"""Get the correct client, when method is invoked inside a worker we
|
||||
should use `worker_client' instead of default client.
|
||||
|
||||
"""
|
||||
asynchronous = getattr(self, "_asynchronous", False)
|
||||
if self._client is None:
|
||||
try:
|
||||
distributed.get_worker()
|
||||
in_worker = True
|
||||
except ValueError:
|
||||
in_worker = False
|
||||
if in_worker:
|
||||
with distributed.worker_client() as client:
|
||||
with _set_worker_client(self, client) as this:
|
||||
ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
|
||||
return ret
|
||||
return ret
|
||||
|
||||
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
|
||||
@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
feval=metric,
|
||||
obj=obj,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
callbacks=callbacks,
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results["booster"]
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results["history"]
|
||||
self._set_evaluation_result(results["history"])
|
||||
return self
|
||||
|
||||
# pylint: disable=missing-docstring, disable=unused-argument
|
||||
@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
) -> "DaskXGBRegressor":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
xgb_model=model,
|
||||
)
|
||||
self._Booster = results['booster']
|
||||
|
||||
if not callable(self.objective):
|
||||
self.objective = params["objective"]
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
self.evals_result_ = results['history']
|
||||
self._set_evaluation_result(results["history"])
|
||||
return self
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
@_deprecate_positional_args
|
||||
def fit(
|
||||
self,
|
||||
X: _DaskCollection,
|
||||
@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
) -> "DaskXGBClassifier":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != 'self'}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
async def _predict_proba_async(
|
||||
self,
|
||||
@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
_assert_dask_support()
|
||||
msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
|
||||
assert ntree_limit is None, msg
|
||||
return self.client.sync(
|
||||
return self._client_sync(
|
||||
self._predict_proba_async,
|
||||
X=X,
|
||||
validate_features=validate_features,
|
||||
@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
model, metric, params = self._configure_fit(
|
||||
booster=xgb_model, eval_metric=eval_metric, params=params
|
||||
)
|
||||
results = await train(
|
||||
results = await self.client.sync(
|
||||
_train_async,
|
||||
asynchronous=True,
|
||||
client=self.client,
|
||||
global_config=config.get_config(),
|
||||
params=params,
|
||||
dtrain=dtrain,
|
||||
num_boost_round=self.get_num_boosting_rounds(),
|
||||
evals=evals,
|
||||
obj=None,
|
||||
feval=metric,
|
||||
verbose_eval=verbose,
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
) -> "DaskXGBRanker":
|
||||
_assert_dask_support()
|
||||
args = {k: v for k, v in locals().items() if k != "self"}
|
||||
return self.client.sync(self._fit_async, **args)
|
||||
return self._client_sync(self._fit_async, **args)
|
||||
|
||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
||||
fit.__doc__ = XGBRanker.fit.__doc__
|
||||
|
||||
@ -293,7 +293,7 @@ class TestDistributedGPU:
|
||||
fw = fw - fw.min()
|
||||
m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
|
||||
|
||||
def worker_fn(worker_addr: str, data_ref: Dict) -> None:
|
||||
@ -384,7 +384,7 @@ class TestDistributedGPU:
|
||||
return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
|
||||
|
||||
with Client(local_cuda_cluster) as client:
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
|
||||
futures = client.map(runit,
|
||||
workers,
|
||||
|
||||
@ -27,7 +27,7 @@ def run_rabit_ops(client, n_workers):
|
||||
from xgboost.dask import RabitContext, _get_rabit_args
|
||||
from xgboost import rabit
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(_get_rabit_args, len(workers), client)
|
||||
assert not rabit.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
|
||||
@ -9,6 +9,7 @@ import scipy
|
||||
import json
|
||||
from typing import List, Tuple, Dict, Optional, Type, Any
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import tempfile
|
||||
from sklearn.datasets import make_classification
|
||||
import sklearn
|
||||
@ -43,9 +44,9 @@ kCols = 10
|
||||
kWorkers = 5
|
||||
|
||||
|
||||
def _get_client_workers(client: "Client") -> Dict[str, Dict]:
|
||||
def _get_client_workers(client: "Client") -> List[str]:
|
||||
workers = client.scheduler_info()['workers']
|
||||
return workers
|
||||
return list(workers.keys())
|
||||
|
||||
|
||||
def generate_array(
|
||||
@ -646,7 +647,7 @@ def test_with_asyncio() -> None:
|
||||
|
||||
|
||||
async def generate_concurrent_trainings() -> None:
|
||||
async def train():
|
||||
async def train() -> None:
|
||||
async with LocalCluster(n_workers=2,
|
||||
threads_per_worker=1,
|
||||
asynchronous=True,
|
||||
@ -967,7 +968,7 @@ class TestWithDask:
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(
|
||||
xgb.dask._get_rabit_args, len(workers), client)
|
||||
futures = client.map(runit,
|
||||
@ -1000,7 +1001,7 @@ class TestWithDask:
|
||||
def test_n_workers(self) -> None:
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
from sklearn.datasets import load_breast_cancer
|
||||
X, y = load_breast_cancer(return_X_y=True)
|
||||
dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
|
||||
@ -1090,7 +1091,7 @@ class TestWithDask:
|
||||
X, y, _ = generate_array()
|
||||
n_partitions = X.npartitions
|
||||
m = xgb.dask.DaskDMatrix(client, X, y)
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
workers = _get_client_workers(client)
|
||||
rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
|
||||
n_workers = len(workers)
|
||||
|
||||
@ -1285,6 +1286,82 @@ def test_dask_unsupported_features(client: "Client") -> None:
|
||||
)
|
||||
|
||||
|
||||
def test_parallel_submits(client: "Client") -> None:
|
||||
"""Test for running multiple train simultaneously from single clients."""
|
||||
try:
|
||||
from distributed import MultiLock # NOQA
|
||||
except ImportError:
|
||||
pytest.skip("`distributed.MultiLock' is not available")
|
||||
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
futures = []
|
||||
workers = _get_client_workers(client)
|
||||
n_submits = len(workers)
|
||||
for i in range(n_submits):
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
X = dd.from_array(X_, chunksize=32)
|
||||
y = dd.from_array(y_, chunksize=32)
|
||||
cls = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1,
|
||||
n_estimators=i + 1,
|
||||
eval_metric="merror",
|
||||
use_label_encoder=False,
|
||||
)
|
||||
f = client.submit(cls.fit, X, y, pure=False)
|
||||
futures.append(f)
|
||||
|
||||
classifiers = client.gather(futures)
|
||||
assert len(classifiers) == n_submits
|
||||
for i, cls in enumerate(classifiers):
|
||||
assert cls.get_booster().num_boosted_rounds() == i + 1
|
||||
|
||||
|
||||
def test_parallel_submit_multi_clients() -> None:
|
||||
"""Test for running multiple train simultaneously from multiple clients."""
|
||||
try:
|
||||
from distributed import MultiLock # NOQA
|
||||
except ImportError:
|
||||
pytest.skip("`distributed.MultiLock' is not available")
|
||||
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
with LocalCluster(n_workers=4) as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = _get_client_workers(client)
|
||||
|
||||
n_submits = len(workers)
|
||||
assert n_submits == 4
|
||||
futures = []
|
||||
|
||||
for i in range(n_submits):
|
||||
client = Client(cluster)
|
||||
X_, y_ = load_digits(return_X_y=True)
|
||||
X_ += 1.0
|
||||
X = dd.from_array(X_, chunksize=32)
|
||||
y = dd.from_array(y_, chunksize=32)
|
||||
cls = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1,
|
||||
n_estimators=i + 1,
|
||||
eval_metric="merror",
|
||||
use_label_encoder=False,
|
||||
)
|
||||
f = client.submit(cls.fit, X, y, pure=False)
|
||||
futures.append((client, f))
|
||||
|
||||
t_futures = []
|
||||
with ThreadPoolExecutor(max_workers=16) as e:
|
||||
for i in range(n_submits):
|
||||
def _() -> xgb.dask.DaskXGBClassifier:
|
||||
return futures[i][0].compute(futures[i][1]).result()
|
||||
|
||||
f = e.submit(_)
|
||||
t_futures.append(f)
|
||||
|
||||
for i, f in enumerate(t_futures):
|
||||
assert f.result().get_booster().num_boosted_rounds() == i + 1
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client: "Client") -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user