[dask] Fix asyncio. (#7508)

This commit is contained in:
Jiaming Yuan 2021-12-13 01:48:25 +08:00 committed by GitHub
parent 01152f89ee
commit 05497a9141
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 4 deletions

View File

@ -1622,8 +1622,9 @@ class DaskScikitLearnBase(XGBModel):
should use `worker_client' instead of default client.
"""
asynchronous = getattr(self, "_asynchronous", False)
if self._client is None:
asynchronous = getattr(self, "_asynchronous", False)
try:
distributed.get_worker()
in_worker = True
@ -1636,7 +1637,7 @@ class DaskScikitLearnBase(XGBModel):
return ret
return ret
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)
@xgboost_model_doc(

View File

@ -751,8 +751,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
async with Client(scheduler_address, asynchronous=True) as client:
X, y, _ = generate_array()
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
n_estimators=2)
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
regressor.set_params(tree_method='hist')
regressor.client = client
await regressor.fit(X, y, eval_set=[(X, y)])