[dask] Fix asyncio. (#7508)
This commit is contained in:
parent
01152f89ee
commit
05497a9141
@ -1622,8 +1622,9 @@ class DaskScikitLearnBase(XGBModel):
|
||||
should use `worker_client' instead of default client.
|
||||
|
||||
"""
|
||||
asynchronous = getattr(self, "_asynchronous", False)
|
||||
|
||||
if self._client is None:
|
||||
asynchronous = getattr(self, "_asynchronous", False)
|
||||
try:
|
||||
distributed.get_worker()
|
||||
in_worker = True
|
||||
@ -1636,7 +1637,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
return ret
|
||||
return ret
|
||||
|
||||
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
|
||||
return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)
|
||||
|
||||
|
||||
@xgboost_model_doc(
|
||||
|
||||
@ -751,8 +751,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
|
||||
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y, _ = generate_array()
|
||||
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
|
||||
n_estimators=2)
|
||||
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
await regressor.fit(X, y, eval_set=[(X, y)])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user