diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 5fa37dedd..431a80d26 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1606,8 +1606,9 @@ class DaskScikitLearnBase(XGBModel): should use `worker_client' instead of default client. """ - asynchronous = getattr(self, "_asynchronous", False) + if self._client is None: + asynchronous = getattr(self, "_asynchronous", False) try: distributed.get_worker() in_worker = True @@ -1620,7 +1621,7 @@ class DaskScikitLearnBase(XGBModel): return ret return ret - return self.client.sync(func, **kwargs, asynchronous=asynchronous) + return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous) @xgboost_model_doc( diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 343eff97b..f26097a83 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -705,8 +705,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR async def run_dask_regressor_asyncio(scheduler_address: str) -> None: async with Client(scheduler_address, asynchronous=True) as client: X, y, _ = generate_array() - regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, - n_estimators=2) + regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2) regressor.set_params(tree_method='hist') regressor.client = client await regressor.fit(X, y, eval_set=[(X, y)])