[dask] Use distributed.MultiLock (#6743)
* [dask] Use `distributed.MultiLock`. This enables training multiple models in parallel.
* Conditionally import `MultiLock`.
* Use async train directly in the scikit-learn interface.
* Use `worker_client` when available.
parent 19a2c54265
commit 325bc93e16
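With each training wrapped in a per-worker `MultiLock`, several fits that share workers no longer interleave their collective communication, so they can be submitted concurrently. A minimal sketch of the usage this enables (the cluster size, synthetic data, and estimator settings are illustrative only, not part of this commit):

# Sketch only: submit several independent fits from one client.  The MultiLock
# added in this commit keeps trainings that share workers from colliding.
from dask import array as da
from dask.distributed import Client, LocalCluster
import xgboost as xgb

with LocalCluster(n_workers=4) as cluster:
    with Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=(100,))
        futures = [
            client.submit(
                xgb.dask.DaskXGBRegressor(n_estimators=i + 1).fit, X, y, pure=False
            )
            for i in range(4)
        ]
        models = client.gather(futures)  # one fitted estimator per submission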
@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
 """
 import platform
 import logging
+from contextlib import contextmanager
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
@@ -93,6 +94,34 @@ except ImportError:
 LOGGER = logging.getLogger('[xgboost.dask]')
+
+
+def _multi_lock() -> Any:
+    """MultiLock is only available on latest distributed. See:
+
+    https://github.com/dask/distributed/pull/4503
+
+    """
+    try:
+        from distributed import MultiLock
+    except ImportError:
+        class MultiLock:  # type:ignore
+            def __init__(self, *args: Any, **kwargs: Any) -> None:
+                pass
+
+            def __enter__(self) -> "MultiLock":
+                return self
+
+            def __exit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+            async def __aenter__(self) -> "MultiLock":
+                return self
+
+            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+    return MultiLock
+
+
 def _start_tracker(n_workers: int) -> Dict[str, Any]:
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
 def _get_workers_from_data(
     dtrain: DaskDMatrix,
     evals: Optional[List[Tuple[DaskDMatrix, str]]]
-) -> Set[str]:
+) -> List[str]:
     X_worker_map: Set[str] = set(dtrain.worker_map.keys())
     if evals:
         for e in evals:
@@ -780,7 +809,7 @@ def _get_workers_from_data(
                 continue
             worker_map = set(e[0].worker_map.keys())
             X_worker_map = X_worker_map.union(worker_map)
-    return X_worker_map
+    return list(X_worker_map)


 async def _train_async(
@@ -795,9 +824,9 @@ async def _train_async(
     early_stopping_rounds: Optional[int],
     verbose_eval: Union[int, bool],
     xgb_model: Optional[Booster],
-    callbacks: Optional[List[TrainingCallback]]
+    callbacks: Optional[List[TrainingCallback]],
 ) -> Optional[TrainReturnT]:
-    workers = list(_get_workers_from_data(dtrain, evals))
+    workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), client)

     if params.get("booster", None) == "gblinear":
@@ -858,29 +887,32 @@ async def _train_async(
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent. One known exception is gblinear with shotgun updater.
     # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = []
-    for i, worker_addr in enumerate(workers):
-        if evals:
-            # pylint: disable=protected-access
-            evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
-                                for e, name in evals]
-        else:
-            evals_per_worker = []
-        f = client.submit(
-            dispatched_train,
-            worker_addr,
-            _rabit_args,
-            # pylint: disable=protected-access
-            dtrain._create_fn_args(workers[i]),
-            id(dtrain),
-            evals_per_worker,
-            pure=False,
-            workers=[worker_addr]
-        )
-        futures.append(f)
-
-    results = await client.gather(futures)
-    return list(filter(lambda ret: ret is not None, results))[0]
+    async with _multi_lock()(workers, client):
+        futures = []
+        for worker_addr in workers:
+            if evals:
+                # pylint: disable=protected-access
+                evals_per_worker = [
+                    (e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
+                ]
+            else:
+                evals_per_worker = []
+            f = client.submit(
+                dispatched_train,
+                worker_addr,
+                _rabit_args,
+                # pylint: disable=protected-access
+                dtrain._create_fn_args(worker_addr),
+                id(dtrain),
+                evals_per_worker,
+                pure=False,
+                workers=[worker_addr],
+            )
+            futures.append(f)
+
+        results = await client.gather(futures, asynchronous=True)
+
+    return list(filter(lambda ret: ret is not None, results))[0]


 def train(  # pylint: disable=unused-argument
@@ -927,9 +959,8 @@ def train(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
     client = _xgb_get_client(client)
-    # Get global configuration before transferring computation to another thread or
-    # process.
-    return client.sync(_train_async, global_config=config.get_config(), **locals())
+    args = locals()
+    return client.sync(_train_async, global_config=config.get_config(), **args)


 def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1366,6 +1397,9 @@ def inplace_predict(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
    client = _xgb_get_client(client)
+    # When used in asynchronous environment, the `client` object should have
+    # `asynchronous` attribute as True.  When invoked by the skl interface, it's
+    # responsible for setting up the client.
     return client.sync(
         _inplace_predict_async, global_config=config.get_config(), **locals()
     )
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
     return train_dmatrix, awaited


+@contextmanager
+def _set_worker_client(
+    model: "DaskScikitLearnBase", client: "distributed.Client"
+) -> Generator:
+    """Temporarily set the client for sklearn model."""
+    try:
+        model.client = client
+        yield model
+    finally:
+        model.client = None
+
+
 class DaskScikitLearnBase(XGBModel):
     """Base class for implementing scikit-learn interface with Dask"""

@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
         async def _() -> Awaitable[Any]:
             return self

-        return self.client.sync(_).__await__()
+        return self._client_sync(_).__await__()

     def __getstate__(self) -> Dict:
         this = self.__dict__.copy()
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):

     @property
     def client(self) -> "distributed.Client":
-        """The dask client used in this model."""
+        """The dask client used in this model.  The `Client` object can not be serialized for
+        transmission, so if task is launched from a worker instead of directly from the
+        client process, this attribute needs to be set at that worker.
+
+        """
+
         client = _xgb_get_client(self._client)
         return client

     @client.setter
     def client(self, clt: "distributed.Client") -> None:
+        # calling `worker_client' doesn't return the correct `asynchronous` attribute, so
+        # we have to pass it ourselves.
+        self._asynchronous = clt.asynchronous if clt is not None else False
         self._client = clt

+    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
+        """Get the correct client, when method is invoked inside a worker we
+        should use `worker_client' instead of default client.
+
+        """
+        asynchronous = getattr(self, "_asynchronous", False)
+        if self._client is None:
+            try:
+                distributed.get_worker()
+                in_worker = True
+            except ValueError:
+                in_worker = False
+            if in_worker:
+                with distributed.worker_client() as client:
+                    with _set_worker_client(self, client) as this:
+                        ret = this.client.sync(func, **kwargs, asynchronous=asynchronous)
+                        return ret
+                    return ret
+
+        return self.client.sync(func, **kwargs, asynchronous=asynchronous)


 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
-            feval=metric,
             obj=obj,
+            feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
             callbacks=callbacks,
             xgb_model=model,
         )
         self._Booster = results["booster"]
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results["history"]
+        self._set_evaluation_result(results["history"])
         return self

     # pylint: disable=missing-docstring, disable=unused-argument
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)


 @xgboost_model_doc(
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             xgb_model=model,
         )
         self._Booster = results['booster']

         if not callable(self.objective):
             self.objective = params["objective"]
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results['history']
+        self._set_evaluation_result(results["history"])
         return self

     # pylint: disable=unused-argument
-    @_deprecate_positional_args
     def fit(
         self,
         X: _DaskCollection,
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != 'self'}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)

     async def _predict_proba_async(
         self,
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         _assert_dask_support()
         msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
         assert ntree_limit is None, msg
-        return self.client.sync(
+        return self._client_sync(
             self._predict_proba_async,
             X=X,
             validate_features=validate_features,
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
+            obj=None,
             feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     ) -> "DaskXGBRanker":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)

     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__
@@ -293,7 +293,7 @@ class TestDistributedGPU:
         fw = fw - fw.min()
         m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)

-        workers = list(_get_client_workers(client).keys())
+        workers = _get_client_workers(client)
         rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)

         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
@@ -384,7 +384,7 @@ class TestDistributedGPU:
             return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)

         with Client(local_cuda_cluster) as client:
-            workers = list(_get_client_workers(client).keys())
+            workers = _get_client_workers(client)
             rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
             futures = client.map(runit,
                                  workers,
@@ -27,7 +27,7 @@ def run_rabit_ops(client, n_workers):
     from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit

-    workers = list(_get_client_workers(client).keys())
+    workers = _get_client_workers(client)
     rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
@@ -9,6 +9,7 @@ import scipy
 import json
 from typing import List, Tuple, Dict, Optional, Type, Any
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
 import tempfile
 from sklearn.datasets import make_classification
 import sklearn
@@ -43,9 +44,9 @@ kCols = 10
 kWorkers = 5


-def _get_client_workers(client: "Client") -> Dict[str, Dict]:
+def _get_client_workers(client: "Client") -> List[str]:
     workers = client.scheduler_info()['workers']
-    return workers
+    return list(workers.keys())


 def generate_array(
@@ -646,7 +647,7 @@ def test_with_asyncio() -> None:


 async def generate_concurrent_trainings() -> None:
-    async def train():
+    async def train() -> None:
         async with LocalCluster(n_workers=2,
                                 threads_per_worker=1,
                                 asynchronous=True,
@@ -967,7 +968,7 @@ class TestWithDask:

         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 rabit_args = client.sync(
                     xgb.dask._get_rabit_args, len(workers), client)
                 futures = client.map(runit,
@@ -1000,7 +1001,7 @@ class TestWithDask:
     def test_n_workers(self) -> None:
         with LocalCluster(n_workers=2) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 from sklearn.datasets import load_breast_cancer
                 X, y = load_breast_cancer(return_X_y=True)
                 dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
@@ -1090,7 +1091,7 @@ class TestWithDask:
         X, y, _ = generate_array()
         n_partitions = X.npartitions
         m = xgb.dask.DaskDMatrix(client, X, y)
-        workers = list(_get_client_workers(client).keys())
+        workers = _get_client_workers(client)
         rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
         n_workers = len(workers)

@@ -1285,6 +1286,82 @@ def test_dask_unsupported_features(client: "Client") -> None:
     )


+def test_parallel_submits(client: "Client") -> None:
+    """Test for running multiple train simultaneously from single clients."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    futures = []
+    workers = _get_client_workers(client)
+    n_submits = len(workers)
+    for i in range(n_submits):
+        X_, y_ = load_digits(return_X_y=True)
+        X = dd.from_array(X_, chunksize=32)
+        y = dd.from_array(y_, chunksize=32)
+        cls = xgb.dask.DaskXGBClassifier(
+            verbosity=1,
+            n_estimators=i + 1,
+            eval_metric="merror",
+            use_label_encoder=False,
+        )
+        f = client.submit(cls.fit, X, y, pure=False)
+        futures.append(f)
+
+    classifiers = client.gather(futures)
+    assert len(classifiers) == n_submits
+    for i, cls in enumerate(classifiers):
+        assert cls.get_booster().num_boosted_rounds() == i + 1
+
+
+def test_parallel_submit_multi_clients() -> None:
+    """Test for running multiple train simultaneously from multiple clients."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    with LocalCluster(n_workers=4) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+
+        n_submits = len(workers)
+        assert n_submits == 4
+        futures = []
+
+        for i in range(n_submits):
+            client = Client(cluster)
+            X_, y_ = load_digits(return_X_y=True)
+            X_ += 1.0
+            X = dd.from_array(X_, chunksize=32)
+            y = dd.from_array(y_, chunksize=32)
+            cls = xgb.dask.DaskXGBClassifier(
+                verbosity=1,
+                n_estimators=i + 1,
+                eval_metric="merror",
+                use_label_encoder=False,
+            )
+            f = client.submit(cls.fit, X, y, pure=False)
+            futures.append((client, f))
+
+        t_futures = []
+        with ThreadPoolExecutor(max_workers=16) as e:
+            for i in range(n_submits):
+                def _() -> xgb.dask.DaskXGBClassifier:
+                    return futures[i][0].compute(futures[i][1]).result()
+
+                f = e.submit(_)
+                t_futures.append(f)
+
+        for i, f in enumerate(t_futures):
+            assert f.result().get_booster().num_boosted_rounds() == i + 1
+
+
 class TestDaskCallbacks:
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_early_stopping(self, client: "Client") -> None: