[dask] Accept Future of model for prediction. (#6650)

This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably.

* Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
Jiaming Yuan
2021-02-02 08:45:52 +08:00
committed by GitHub
parent a9ec0ea6da
commit 87ab1ad607
4 changed files with 150 additions and 78 deletions

View File

@@ -608,28 +608,30 @@ def test_with_asyncio() -> None:
asyncio.run(run_dask_classifier_asyncio(address))
def test_predict() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y, _ = generate_array()
dtrain = DaskDMatrix(client, X, y)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=2)['booster']
def test_predict(client: "Client") -> None:
X, y, _ = generate_array()
dtrain = DaskDMatrix(client, X, y)
booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]
pred = xgb.dask.predict(client, model=booster, data=dtrain)
assert pred.ndim == 1
assert pred.shape[0] == kRows
predt_0 = xgb.dask.predict(client, model=booster, data=dtrain)
assert predt_0.ndim == 1
assert predt_0.shape[0] == kRows
margin = xgb.dask.predict(client, model=booster, data=dtrain,
output_margin=True)
assert margin.ndim == 1
assert margin.shape[0] == kRows
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
assert margin.ndim == 1
assert margin.shape[0] == kRows
shap = xgb.dask.predict(client, model=booster, data=dtrain,
pred_contribs=True)
assert shap.ndim == 2
assert shap.shape[0] == kRows
assert shap.shape[1] == kCols + 1
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
assert shap.ndim == 2
assert shap.shape[0] == kRows
assert shap.shape[1] == kCols + 1
booster_f = client.scatter(booster, broadcast=True)
predt_1 = xgb.dask.predict(client, booster_f, X).compute()
predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute()
np.testing.assert_allclose(predt_0, predt_1)
np.testing.assert_allclose(predt_0, predt_2)
def test_predict_with_meta(client: "Client") -> None:
@@ -1034,7 +1036,7 @@ class TestWithDask:
rows = X.shape[0]
cols = X.shape[1]
def assert_shape(shape):
def assert_shape(shape: Tuple[int, ...]) -> None:
assert shape[0] == rows
if "num_class" in params.keys():
assert shape[1] == params["num_class"]