diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index 60c7ae290..4bdeb49e5 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -17,6 +17,7 @@ https://github.com/dask/dask-xgboost
 """
 import platform
 import logging
+from contextlib import contextmanager
 from collections import defaultdict
 from collections.abc import Sequence
 from threading import Thread
@@ -93,6 +94,34 @@ except ImportError:
 LOGGER = logging.getLogger('[xgboost.dask]')
 
 
+def _multi_lock() -> Any:
+    """MultiLock is only available on the latest distributed.  See:
+
+    https://github.com/dask/distributed/pull/4503
+
+    """
+    try:
+        from distributed import MultiLock
+    except ImportError:
+        class MultiLock:  # type:ignore
+            def __init__(self, *args: Any, **kwargs: Any) -> None:
+                pass
+
+            def __enter__(self) -> "MultiLock":
+                return self
+
+            def __exit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+            async def __aenter__(self) -> "MultiLock":
+                return self
+
+            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+                return
+
+    return MultiLock
+
+
 def _start_tracker(n_workers: int) -> Dict[str, Any]:
     """Start Rabit tracker """
     env = {'DMLC_NUM_WORKER': n_workers}
@@ -770,7 +799,7 @@ async def _get_rabit_args(n_workers: int, client: "distributed.Client") -> List[
 
 def _get_workers_from_data(
     dtrain: DaskDMatrix, evals: Optional[List[Tuple[DaskDMatrix, str]]]
-) -> Set[str]:
+) -> List[str]:
     X_worker_map: Set[str] = set(dtrain.worker_map.keys())
     if evals:
         for e in evals:
@@ -780,7 +809,7 @@ def _get_workers_from_data(
                 continue
             worker_map = set(e[0].worker_map.keys())
             X_worker_map = X_worker_map.union(worker_map)
-    return X_worker_map
+    return list(X_worker_map)
 
 
 async def _train_async(
@@ -795,9 +824,9 @@
     early_stopping_rounds: Optional[int],
     verbose_eval: Union[int, bool],
     xgb_model: Optional[Booster],
-    callbacks: Optional[List[TrainingCallback]]
+    callbacks: Optional[List[TrainingCallback]],
 ) -> Optional[TrainReturnT]:
-    workers = list(_get_workers_from_data(dtrain, evals))
+    workers = _get_workers_from_data(dtrain, evals)
     _rabit_args = await _get_rabit_args(len(workers), client)
 
     if params.get("booster", None) == "gblinear":
@@ -858,29 +887,32 @@ async def _train_async(
     # XGBoost is deterministic in most of the cases, which means train function is
     # supposed to be idempotent. One known exception is gblinear with shotgun updater.
    # We haven't been able to do a full verification so here we keep pure to be False.
-    futures = []
-    for i, worker_addr in enumerate(workers):
-        if evals:
-            # pylint: disable=protected-access
-            evals_per_worker = [(e._create_fn_args(worker_addr), name, id(e))
-                                for e, name in evals]
-        else:
-            evals_per_worker = []
-        f = client.submit(
-            dispatched_train,
-            worker_addr,
-            _rabit_args,
-            # pylint: disable=protected-access
-            dtrain._create_fn_args(workers[i]),
-            id(dtrain),
-            evals_per_worker,
-            pure=False,
-            workers=[worker_addr]
-        )
-        futures.append(f)
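+    # Hold a lock on all participating workers for the duration of training so that
+    # concurrent training sessions sharing these workers are serialized instead of
+    # interleaving their rabit calls.  With older versions of distributed the stub
+    # returned by `_multi_lock` is a no-op and provides no such protection.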
+    async with _multi_lock()(workers, client):
+        futures = []
+        for worker_addr in workers:
+            if evals:
+                # pylint: disable=protected-access
+                evals_per_worker = [
+                    (e._create_fn_args(worker_addr), name, id(e)) for e, name in evals
+                ]
+            else:
+                evals_per_worker = []
+            f = client.submit(
+                dispatched_train,
+                worker_addr,
+                _rabit_args,
+                # pylint: disable=protected-access
+                dtrain._create_fn_args(worker_addr),
+                id(dtrain),
+                evals_per_worker,
+                pure=False,
+                workers=[worker_addr],
+            )
+            futures.append(f)
 
-    results = await client.gather(futures)
-    return list(filter(lambda ret: ret is not None, results))[0]
+        results = await client.gather(futures, asynchronous=True)
+
+        return list(filter(lambda ret: ret is not None, results))[0]
 
 
 def train(  # pylint: disable=unused-argument
@@ -927,9 +959,8 @@ def train(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
     client = _xgb_get_client(client)
-    # Get global configuration before transferring computation to another thread or
-    # process.
-    return client.sync(_train_async, global_config=config.get_config(), **locals())
+    args = locals()
+    return client.sync(_train_async, global_config=config.get_config(), **args)
 
 
 def _can_output_df(is_df: bool, output_shape: Tuple) -> bool:
@@ -1366,6 +1397,9 @@ def inplace_predict(  # pylint: disable=unused-argument
     """
     _assert_dask_support()
     client = _xgb_get_client(client)
+    # When used in an asynchronous environment, the `client` object should have its
+    # `asynchronous` attribute set to True.  When invoked through the skl interface,
+    # the interface is responsible for setting up the client.
     return client.sync(
         _inplace_predict_async, global_config=config.get_config(), **locals()
     )
@@ -1393,6 +1427,18 @@ async def _async_wrap_evaluation_matrices(
     return train_dmatrix, awaited
 
 
+@contextmanager
+def _set_worker_client(
+    model: "DaskScikitLearnBase", client: "distributed.Client"
+) -> Generator:
+    """Temporarily set the client for the sklearn model."""
+    try:
+        model.client = client
+        yield model
+    finally:
+        model.client = None
+
+
 class DaskScikitLearnBase(XGBModel):
     """Base class for implementing scikit-learn interface with Dask"""
 
@@ -1487,7 +1533,7 @@ class DaskScikitLearnBase(XGBModel):
         async def _() -> Awaitable[Any]:
             return self
 
-        return self.client.sync(_).__await__()
+        return self._client_sync(_).__await__()
 
     def __getstate__(self) -> Dict:
         this = self.__dict__.copy()
@@ -1497,14 +1543,43 @@ class DaskScikitLearnBase(XGBModel):
 
     @property
     def client(self) -> "distributed.Client":
-        """The dask client used in this model."""
+        """The dask client used in this model.  The `Client` object cannot be serialized
+        for transmission, so if the task is launched from a worker instead of directly
+        from the client process, this attribute needs to be set at that worker.
+
+        """
+        client = _xgb_get_client(self._client)
         return client
 
     @client.setter
     def client(self, clt: "distributed.Client") -> None:
+        # calling `worker_client' doesn't return the correct `asynchronous` attribute,
+        # so we have to pass it ourselves.
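+        # `clt` can be None, e.g. when `_set_worker_client` detaches the temporary
+        # worker client again; in that case fall back to synchronous behaviour.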
+        self._asynchronous = clt.asynchronous if clt is not None else False
         self._client = clt
 
+    def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
+        """Get the correct client.  When the method is invoked inside a worker, we
+        should use `worker_client' instead of the default client.
+
+        """
+        asynchronous = getattr(self, "_asynchronous", False)
+        if self._client is None:
+            try:
+                distributed.get_worker()
+                in_worker = True
+            except ValueError:
+                in_worker = False
+            if in_worker:
+                with distributed.worker_client() as client:
+                    with _set_worker_client(self, client) as this:
+                        ret = this.client.sync(
+                            func, **kwargs, asynchronous=asynchronous
+                        )
+                        return ret
+                    return ret
+
+        return self.client.sync(func, **kwargs, asynchronous=asynchronous)
+
 
 @xgboost_model_doc(
     """Implementation of the Scikit-Learn API for XGBoost.""",
     ["estimators", "model"]
 )
@@ -1552,22 +1627,24 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
-            feval=metric,
             obj=obj,
+            feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
             callbacks=callbacks,
             xgb_model=model,
         )
         self._Booster = results["booster"]
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results["history"]
+        self._set_evaluation_result(results["history"])
         return self
 
     # pylint: disable=missing-docstring, disable=unused-argument
@@ -1591,7 +1668,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     ) -> "DaskXGBRegressor":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)
 
 
 @xgboost_model_doc(
@@ -1651,8 +1728,11 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
@@ -1665,16 +1745,12 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             xgb_model=model,
         )
         self._Booster = results['booster']
-
         if not callable(self.objective):
             self.objective = params["objective"]
-
-        # pylint: disable=attribute-defined-outside-init
-        self.evals_result_ = results['history']
+        self._set_evaluation_result(results["history"])
         return self
 
     # pylint: disable=unused-argument
-    @_deprecate_positional_args
     def fit(
         self,
         X: _DaskCollection,
@@ -1694,7 +1770,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     ) -> "DaskXGBClassifier":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != 'self'}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)
 
     async def _predict_proba_async(
         self,
@@ -1728,7 +1804,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         _assert_dask_support()
         msg = "`ntree_limit` is not supported on dask, use `iteration_range` instead."
         assert ntree_limit is None, msg
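+        # `_client_sync` resolves the client via `worker_client` when the call is made
+        # from inside a dask worker, e.g. when the estimator is used through
+        # `client.submit`, and falls back to the regular client otherwise.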
-        return self.client.sync(
+        return self._client_sync(
             self._predict_proba_async,
             X=X,
             validate_features=validate_features,
@@ -1838,12 +1914,16 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         model, metric, params = self._configure_fit(
             booster=xgb_model, eval_metric=eval_metric, params=params
         )
-        results = await train(
+        results = await self.client.sync(
+            _train_async,
+            asynchronous=True,
             client=self.client,
+            global_config=config.get_config(),
             params=params,
             dtrain=dtrain,
             num_boost_round=self.get_num_boosting_rounds(),
             evals=evals,
+            obj=None,
             feval=metric,
             verbose_eval=verbose,
             early_stopping_rounds=early_stopping_rounds,
@@ -1879,7 +1959,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     ) -> "DaskXGBRanker":
         _assert_dask_support()
         args = {k: v for k, v in locals().items() if k != "self"}
-        return self.client.sync(self._fit_async, **args)
+        return self._client_sync(self._fit_async, **args)
 
     # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
     fit.__doc__ = XGBRanker.fit.__doc__
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index 4851dc512..3633efd19 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -293,7 +293,7 @@ class TestDistributedGPU:
         fw = fw - fw.min()
         m = dxgb.DaskDMatrix(client, X, y, feature_weights=fw)
 
-        workers = list(_get_client_workers(client).keys())
+        workers = _get_client_workers(client)
         rabit_args = client.sync(dxgb._get_rabit_args, len(workers), client)
 
         def worker_fn(worker_addr: str, data_ref: Dict) -> None:
@@ -384,7 +384,7 @@ class TestDistributedGPU:
             return subprocess.run([str(exe), test], env=env, stdout=subprocess.PIPE)
 
         with Client(local_cuda_cluster) as client:
-            workers = list(_get_client_workers(client).keys())
+            workers = _get_client_workers(client)
             rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
             futures = client.map(runit,
                                  workers,
diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py
index 93a62ea56..a6490f50c 100644
--- a/tests/python/test_tracker.py
+++ b/tests/python/test_tracker.py
@@ -27,7 +27,7 @@ def run_rabit_ops(client, n_workers):
     from xgboost.dask import RabitContext, _get_rabit_args
     from xgboost import rabit
 
-    workers = list(_get_client_workers(client).keys())
+    workers = _get_client_workers(client)
     rabit_args = client.sync(_get_rabit_args, len(workers), client)
     assert not rabit.is_distributed()
     n_workers_from_dask = len(workers)
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index 6bd7c5dcf..5c6418a4c 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -9,6 +9,7 @@ import scipy
 import json
 from typing import List, Tuple, Dict, Optional, Type, Any
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
 import tempfile
 from sklearn.datasets import make_classification
 import sklearn
@@ -43,9 +44,9 @@ kCols = 10
 kWorkers = 5
 
 
-def _get_client_workers(client: "Client") -> Dict[str, Dict]:
+def _get_client_workers(client: "Client") -> List[str]:
     workers = client.scheduler_info()['workers']
-    return workers
+    return list(workers.keys())
 
 
 def generate_array(
@@ -646,7 +647,7 @@ def test_with_asyncio() -> None:
 
 
 async def generate_concurrent_trainings() -> None:
-    async def train():
+    async def train() -> None:
         async with LocalCluster(n_workers=2,
                                 threads_per_worker=1,
                                 asynchronous=True,
@@ -967,7 +968,7 @@ class TestWithDask:
         with LocalCluster(n_workers=4) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 rabit_args = client.sync(
                     xgb.dask._get_rabit_args, len(workers), client)
                 futures = client.map(runit,
@@ -1000,7 +1001,7 @@ class TestWithDask:
     def test_n_workers(self) -> None:
         with LocalCluster(n_workers=2) as cluster:
             with Client(cluster) as client:
-                workers = list(_get_client_workers(client).keys())
+                workers = _get_client_workers(client)
                 from sklearn.datasets import load_breast_cancer
                 X, y = load_breast_cancer(return_X_y=True)
                 dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
@@ -1090,7 +1091,7 @@ class TestWithDask:
         X, y, _ = generate_array()
         n_partitions = X.npartitions
         m = xgb.dask.DaskDMatrix(client, X, y)
-        workers = list(_get_client_workers(client).keys())
+        workers = _get_client_workers(client)
         rabit_args = client.sync(xgb.dask._get_rabit_args, len(workers), client)
         n_workers = len(workers)
 
@@ -1285,6 +1286,82 @@ def test_dask_unsupported_features(client: "Client") -> None:
         )
 
 
+def test_parallel_submits(client: "Client") -> None:
+    """Test running multiple training jobs simultaneously from a single client."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    futures = []
+    workers = _get_client_workers(client)
+    n_submits = len(workers)
+    for i in range(n_submits):
+        X_, y_ = load_digits(return_X_y=True)
+        X = dd.from_array(X_, chunksize=32)
+        y = dd.from_array(y_, chunksize=32)
+        cls = xgb.dask.DaskXGBClassifier(
+            verbosity=1,
+            n_estimators=i + 1,
+            eval_metric="merror",
+            use_label_encoder=False,
+        )
+        f = client.submit(cls.fit, X, y, pure=False)
+        futures.append(f)
+
+    classifiers = client.gather(futures)
+    assert len(classifiers) == n_submits
+    for i, cls in enumerate(classifiers):
+        assert cls.get_booster().num_boosted_rounds() == i + 1
+
+
+def test_parallel_submit_multi_clients() -> None:
+    """Test running multiple training jobs simultaneously from multiple clients."""
+    try:
+        from distributed import MultiLock  # NOQA
+    except ImportError:
+        pytest.skip("`distributed.MultiLock' is not available")
+
+    from sklearn.datasets import load_digits
+
+    with LocalCluster(n_workers=4) as cluster:
+        with Client(cluster) as client:
+            workers = _get_client_workers(client)
+
+        n_submits = len(workers)
+        assert n_submits == 4
+        futures = []
+
+        for i in range(n_submits):
+            client = Client(cluster)
+            X_, y_ = load_digits(return_X_y=True)
+            X_ += 1.0
+            X = dd.from_array(X_, chunksize=32)
+            y = dd.from_array(y_, chunksize=32)
+            cls = xgb.dask.DaskXGBClassifier(
+                verbosity=1,
+                n_estimators=i + 1,
+                eval_metric="merror",
+                use_label_encoder=False,
+            )
+            f = client.submit(cls.fit, X, y, pure=False)
+            futures.append((client, f))
+
+        t_futures = []
+        with ThreadPoolExecutor(max_workers=16) as e:
+            for i in range(n_submits):
+                # Bind `i` as a default argument so each closure keeps its own index
+                # even if the thread runs after the loop variable has moved on.
+                def _(i: int = i) -> xgb.dask.DaskXGBClassifier:
+                    return futures[i][0].compute(futures[i][1]).result()
+
+                f = e.submit(_)
+                t_futures.append(f)
+
+        for i, f in enumerate(t_futures):
+            assert f.result().get_booster().num_boosted_rounds() == i + 1
+
+
 class TestDaskCallbacks:
     @pytest.mark.skipif(**tm.no_sklearn())
     def test_early_stopping(self, client: "Client") -> None: