[Dask] Asyncio support. (#5862)
This commit is contained in:
@@ -2,6 +2,7 @@ import sys
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
import asyncio
|
||||
import unittest
|
||||
import xgboost
|
||||
import subprocess
|
||||
@@ -219,7 +220,7 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
workers = list(dxgb._get_client_workers(client).keys())
|
||||
rabit_args = dxgb._get_rabit_args(workers, client)
|
||||
rabit_args = client.sync(dxgb._get_rabit_args, workers, client)
|
||||
futures = client.map(runit,
|
||||
workers,
|
||||
pure=False,
|
||||
@@ -242,3 +243,39 @@ class TestDistributedGPU(unittest.TestCase):
|
||||
@pytest.mark.gtest
|
||||
def test_quantile_same_on_all_workers(self):
|
||||
self.run_quantile('SameOnAllWorkers')
|
||||
|
||||
|
||||
async def run_from_dask_array_asyncio(scheduler_address):
|
||||
async with Client(scheduler_address, asynchronous=True) as client:
|
||||
import cupy as cp
|
||||
X, y = generate_array()
|
||||
X = X.map_blocks(cp.array)
|
||||
y = y.map_blocks(cp.array)
|
||||
|
||||
m = await xgboost.dask.DaskDeviceQuantileDMatrix(client, X, y)
|
||||
output = await xgboost.dask.train(client, {'tree_method': 'gpu_hist'},
|
||||
dtrain=m)
|
||||
|
||||
with_m = await xgboost.dask.predict(client, output, m)
|
||||
with_X = await xgboost.dask.predict(client, output, X)
|
||||
inplace = await xgboost.dask.inplace_predict(client, output, X)
|
||||
assert isinstance(with_m, da.Array)
|
||||
assert isinstance(with_X, da.Array)
|
||||
assert isinstance(inplace, da.Array)
|
||||
|
||||
cp.testing.assert_allclose(await client.compute(with_m),
|
||||
await client.compute(with_X))
|
||||
cp.testing.assert_allclose(await client.compute(with_m),
|
||||
await client.compute(inplace))
|
||||
|
||||
client.shutdown()
|
||||
return output
|
||||
|
||||
|
||||
def test_with_asyncio():
|
||||
with LocalCUDACluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
address = client.scheduler.address
|
||||
output = asyncio.run(run_from_dask_array_asyncio(address))
|
||||
assert isinstance(output['booster'], xgboost.Booster)
|
||||
assert isinstance(output['history'], dict)
|
||||
|
||||
Reference in New Issue
Block a user