Fix dask predict on DaskDMatrix with iteration_range. (#7005)

This commit is contained in:
Jiaming Yuan 2021-05-29 04:43:12 +08:00 committed by GitHub
parent 4cf95a6041
commit 89a49cf30e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 0 deletions

View File

@ -1217,6 +1217,8 @@ async def _predict_async(
approx_contribs=approx_contribs, approx_contribs=approx_contribs,
pred_interactions=pred_interactions, pred_interactions=pred_interactions,
validate_features=validate_features, validate_features=validate_features,
iteration_range=iteration_range,
strict_shape=strict_shape,
) )
return predt return predt

View File

@ -952,6 +952,39 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None:
verify_leaf_output(leaf, num_parallel_tree) verify_leaf_output(leaf, num_parallel_tree)
def test_dask_iteration_range(client: "Client"):
    """Verify that the dask prediction paths respect ``iteration_range``.

    A booster is trained through ``xgb.dask.train`` for ``n_rounds`` rounds.
    For every range ``(0, i)`` with ``i`` in ``[0, n_rounds)``, the output of
    the native ``Booster.predict`` on an in-memory ``DMatrix`` is compared
    against three distributed entry points:

    * ``xgb.dask.predict`` fed with a ``DaskDMatrix``,
    * ``xgb.dask.predict`` fed with the dask collection directly,
    * ``xgb.dask.inplace_predict`` fed with the dask collection.

    Finally, predicting with the explicit full range ``(0, n_rounds)`` must
    agree with predicting without any ``iteration_range`` argument.
    """
    X, y, _ = generate_array()
    n_rounds = 10

    # Materialised copy of the data; serves as the single-node reference.
    local_dmatrix = xgb.DMatrix(X.compute(), y.compute())
    dask_dmatrix = xgb.dask.DaskDMatrix(client, X, y)

    train_output = xgb.dask.train(
        client, {"tree_method": "hist"}, dask_dmatrix, num_boost_round=n_rounds
    )
    booster = train_output["booster"]

    for end in range(n_rounds):
        layers = (0, end)
        expected = booster.predict(local_dmatrix, iteration_range=layers)

        # Build all three distributed predictions before computing any of
        # them, mirroring the order in which the results are checked below.
        from_dask_dmatrix = xgb.dask.predict(
            client, booster, dask_dmatrix, iteration_range=layers
        )
        from_collection = xgb.dask.predict(
            client, booster, X, iteration_range=layers
        )
        from_inplace = xgb.dask.inplace_predict(
            client, booster, X, iteration_range=layers
        )

        for distributed in (from_dask_dmatrix, from_collection, from_inplace):
            np.testing.assert_allclose(expected, distributed.compute())

    # An explicit range covering every round must match the default call.
    explicit_full = xgb.dask.predict(
        client, booster, X, iteration_range=(0, n_rounds)
    )
    implicit_full = xgb.dask.predict(client, booster, X)
    np.testing.assert_allclose(explicit_full.compute(), implicit_full.compute())
class TestWithDask: class TestWithDask:
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)]) @pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
def test_global_config( def test_global_config(