[Dask] Asyncio support. (#5862)
This commit is contained in:
@@ -27,8 +27,10 @@ def run_rabit_ops(client, n_workers):
|
||||
from xgboost import rabit
|
||||
|
||||
workers = list(_get_client_workers(client).keys())
|
||||
rabit_args = _get_rabit_args(workers, client)
|
||||
rabit_args = client.sync(_get_rabit_args, workers, client)
|
||||
assert not rabit.is_distributed()
|
||||
n_workers_from_dask = len(workers)
|
||||
assert n_workers == n_workers_from_dask
|
||||
|
||||
def local_test(worker_id):
|
||||
with RabitContext(rabit_args):
|
||||
|
||||
@@ -4,6 +4,7 @@ import xgboost as xgb
|
||||
import sys
|
||||
import numpy as np
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||
@@ -327,3 +328,96 @@ def test_empty_dmatrix_approx():
|
||||
parameters = {'tree_method': 'approx'}
|
||||
run_empty_dmatrix_reg(client, parameters)
|
||||
run_empty_dmatrix_cls(client, parameters)
|
||||
|
||||
|
||||
async def run_from_dask_array_asyncio(scheduler_address):
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
m = await DaskDMatrix(client, X, y)
|
||||
output = await xgb.dask.train(client, {}, dtrain=m)
|
||||
|
||||
with_m = await xgb.dask.predict(client, output, m)
|
||||
with_X = await xgb.dask.predict(client, output, X)
|
||||
inplace = await xgb.dask.inplace_predict(client, output, X)
|
||||
assert isinstance(with_m, da.Array)
|
||||
assert isinstance(with_X, da.Array)
|
||||
assert isinstance(inplace, da.Array)
|
||||
|
||||
np.testing.assert_allclose(await client.compute(with_m),
|
||||
await client.compute(with_X))
|
||||
np.testing.assert_allclose(await client.compute(with_m),
|
||||
await client.compute(inplace))
|
||||
|
||||
client.shutdown()
|
||||
return output
|
||||
|
||||
|
||||
async def run_dask_regressor_asyncio(scheduler_address):
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
regressor = await xgb.dask.DaskXGBRegressor(verbosity=1,
|
||||
n_estimators=2)
|
||||
regressor.set_params(tree_method='hist')
|
||||
regressor.client = client
|
||||
await regressor.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = await regressor.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
history = regressor.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history['validation_0'].keys())[0] == 'rmse'
|
||||
assert len(history['validation_0']['rmse']) == 2
|
||||
|
||||
|
||||
async def run_dask_classifier_asyncio(scheduler_address):
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
X, y = generate_array()
|
||||
y = (y * 10).astype(np.int32)
|
||||
classifier = await xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2)
|
||||
classifier.client = client
|
||||
await classifier.fit(X, y, eval_set=[(X, y)])
|
||||
prediction = await classifier.predict(X)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
history = classifier.evals_result()
|
||||
|
||||
assert isinstance(prediction, da.Array)
|
||||
assert isinstance(history, dict)
|
||||
|
||||
assert list(history.keys())[0] == 'validation_0'
|
||||
assert list(history['validation_0'].keys())[0] == 'merror'
|
||||
assert len(list(history['validation_0'])) == 1
|
||||
assert len(history['validation_0']['merror']) == 2
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
|
||||
# Test with dataframe.
|
||||
X_d = dd.from_dask_array(X)
|
||||
y_d = dd.from_dask_array(y)
|
||||
await classifier.fit(X_d, y_d)
|
||||
|
||||
assert classifier.n_classes_ == 10
|
||||
prediction = await classifier.predict(X_d)
|
||||
|
||||
assert prediction.ndim == 1
|
||||
assert prediction.shape[0] == kRows
|
||||
|
||||
|
||||
def test_with_asyncio():
|
||||
with LocalCluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
address = client.scheduler.address
|
||||
output = asyncio.run(run_from_dask_array_asyncio(address))
|
||||
assert isinstance(output['booster'], xgb.Booster)
|
||||
assert isinstance(output['history'], dict)
|
||||
|
||||
asyncio.run(run_dask_regressor_asyncio(address))
|
||||
asyncio.run(run_dask_classifier_asyncio(address))
|
||||
|
||||
Reference in New Issue
Block a user