[dask] Use client to persist collections (#6722)
Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
parent
9b530e5697
commit
b6167cd2ff
@@ -303,11 +303,11 @@ class DaskDMatrix:
|
|||||||
' of columns for your dask Array explicitly. e.g.' \
|
' of columns for your dask Array explicitly. e.g.' \
|
||||||
' chunks=(partition_size, X.shape[1])'
|
' chunks=(partition_size, X.shape[1])'
|
||||||
|
|
||||||
data = data.persist()
|
data = client.persist(data)
|
||||||
for meta in [label, weights, base_margin, label_lower_bound,
|
for meta in [label, weights, base_margin, label_lower_bound,
|
||||||
label_upper_bound]:
|
label_upper_bound]:
|
||||||
if meta is not None:
|
if meta is not None:
|
||||||
meta = meta.persist()
|
meta = client.persist(meta)
|
||||||
# Breaking data into partitions, a trick borrowed from dask_xgboost.
|
# Breaking data into partitions, a trick borrowed from dask_xgboost.
|
||||||
|
|
||||||
# `to_delayed` downgrades high-level objects into numpy or pandas
|
# `to_delayed` downgrades high-level objects into numpy or pandas
|
||||||
|
|||||||
@@ -563,8 +563,6 @@ async def run_from_dask_array_asyncio(scheduler_address: str) -> xgb.dask.TrainR
|
|||||||
await client.compute(with_X))
|
await client.compute(with_X))
|
||||||
np.testing.assert_allclose(await client.compute(with_m),
|
np.testing.assert_allclose(await client.compute(with_m),
|
||||||
await client.compute(inplace))
|
await client.compute(inplace))
|
||||||
|
|
||||||
client.shutdown()
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@@ -647,6 +645,25 @@ def test_with_asyncio() -> None:
|
|||||||
asyncio.run(run_dask_classifier_asyncio(address))
|
asyncio.run(run_dask_classifier_asyncio(address))
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_concurrent_trainings() -> None:
|
||||||
|
async def train():
|
||||||
|
async with LocalCluster(n_workers=2,
|
||||||
|
threads_per_worker=1,
|
||||||
|
asynchronous=True,
|
||||||
|
dashboard_address=0) as cluster:
|
||||||
|
async with Client(cluster, asynchronous=True) as client:
|
||||||
|
X, y, w = generate_array(with_weights=True)
|
||||||
|
dtrain = await DaskDMatrix(client, X, y, weight=w)
|
||||||
|
dvalid = await DaskDMatrix(client, X, y, weight=w)
|
||||||
|
output = await xgb.dask.train(client, {}, dtrain=dtrain)
|
||||||
|
await xgb.dask.predict(client, output, data=dvalid)
|
||||||
|
await asyncio.gather(train(), train())
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_trainings() -> None:
|
||||||
|
asyncio.run(generate_concurrent_trainings())
|
||||||
|
|
||||||
|
|
||||||
def test_predict(client: "Client") -> None:
|
def test_predict(client: "Client") -> None:
|
||||||
X, y, _ = generate_array()
|
X, y, _ = generate_array()
|
||||||
dtrain = DaskDMatrix(client, X, y)
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user