parent
afb9dfd421
commit
3e2d7519a6
@ -1606,8 +1606,9 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
should use `worker_client' instead of default client.
|
should use `worker_client' instead of default client.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
asynchronous = getattr(self, "_asynchronous", False)
|
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
|
asynchronous = getattr(self, "_asynchronous", False)
|
||||||
try:
|
try:
|
||||||
distributed.get_worker()
|
distributed.get_worker()
|
||||||
in_worker = True
|
in_worker = True
|
||||||
@ -1620,7 +1621,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
return ret
|
return ret
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
return self.client.sync(func, **kwargs, asynchronous=asynchronous)
|
return self.client.sync(func, **kwargs, asynchronous=self.client.asynchronous)
|
||||||
|
|
||||||
|
|
||||||
@xgboost_model_doc(
|
@xgboost_model_doc(
|
||||||
|
|||||||
@ -705,8 +705,7 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
|
|||||||
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
|
async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
|
||||||
async with Client(scheduler_address, asynchronous=True) as client:
|
async with Client(scheduler_address, asynchronous=True) as client:
|
||||||
X, y, _ = generate_array()
|
X, y, _ = generate_array()
|
||||||
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
|
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1, n_estimators=2)
|
||||||
n_estimators=2)
|
|
||||||
regressor.set_params(tree_method='hist')
|
regressor.set_params(tree_method='hist')
|
||||||
regressor.client = client
|
regressor.client = client
|
||||||
await regressor.fit(X, y, eval_set=[(X, y)])
|
await regressor.fit(X, y, eval_set=[(X, y)])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user