[dask] Accept Future of model for prediction. (#6650)
This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably. * Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
@@ -608,28 +608,30 @@ def test_with_asyncio() -> None:
|
||||
asyncio.run(run_dask_classifier_asyncio(address))
|
||||
|
||||
|
||||
def test_predict() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
X, y, _ = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(
|
||||
client, {}, dtrain, num_boost_round=2)['booster']
|
||||
def test_predict(client: "Client") -> None:
|
||||
X, y, _ = generate_array()
|
||||
dtrain = DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]
|
||||
|
||||
pred = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
assert pred.ndim == 1
|
||||
assert pred.shape[0] == kRows
|
||||
predt_0 = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||
assert predt_0.ndim == 1
|
||||
assert predt_0.shape[0] == kRows
|
||||
|
||||
margin = xgb.dask.predict(client, model=booster, data=dtrain,
|
||||
output_margin=True)
|
||||
assert margin.ndim == 1
|
||||
assert margin.shape[0] == kRows
|
||||
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
|
||||
assert margin.ndim == 1
|
||||
assert margin.shape[0] == kRows
|
||||
|
||||
shap = xgb.dask.predict(client, model=booster, data=dtrain,
|
||||
pred_contribs=True)
|
||||
assert shap.ndim == 2
|
||||
assert shap.shape[0] == kRows
|
||||
assert shap.shape[1] == kCols + 1
|
||||
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
|
||||
assert shap.ndim == 2
|
||||
assert shap.shape[0] == kRows
|
||||
assert shap.shape[1] == kCols + 1
|
||||
|
||||
booster_f = client.scatter(booster, broadcast=True)
|
||||
|
||||
predt_1 = xgb.dask.predict(client, booster_f, X).compute()
|
||||
predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute()
|
||||
np.testing.assert_allclose(predt_0, predt_1)
|
||||
np.testing.assert_allclose(predt_0, predt_2)
|
||||
|
||||
|
||||
def test_predict_with_meta(client: "Client") -> None:
|
||||
@@ -1034,7 +1036,7 @@ class TestWithDask:
|
||||
rows = X.shape[0]
|
||||
cols = X.shape[1]
|
||||
|
||||
def assert_shape(shape):
|
||||
def assert_shape(shape: Tuple[int, ...]) -> None:
|
||||
assert shape[0] == rows
|
||||
if "num_class" in params.keys():
|
||||
assert shape[1] == params["num_class"]
|
||||
|
||||
Reference in New Issue
Block a user