From 89a49cf30ea7fbcd9ae31a7cbe53407e6fbd2bb0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 29 May 2021 04:43:12 +0800 Subject: [PATCH] Fix dask predict on `DaskDMatrix` with `iteration_range`. (#7005) --- python-package/xgboost/dask.py | 2 ++ tests/python/test_with_dask.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index b1d791051..d38ebbc50 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -1217,6 +1217,8 @@ async def _predict_async( approx_contribs=approx_contribs, pred_interactions=pred_interactions, validate_features=validate_features, + iteration_range=iteration_range, + strict_shape=strict_shape, ) return predt diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index c84a37d18..1e7b7587b 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -952,6 +952,39 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None: verify_leaf_output(leaf, num_parallel_tree) +def test_dask_iteration_range(client: "Client"): + X, y, _ = generate_array() + n_rounds = 10 + + Xy = xgb.DMatrix(X.compute(), y.compute()) + + dXy = xgb.dask.DaskDMatrix(client, X, y) + booster = xgb.dask.train( + client, {"tree_method": "hist"}, dXy, num_boost_round=n_rounds + )["booster"] + + for i in range(0, n_rounds): + iter_range = (0, i) + native_predt = booster.predict(Xy, iteration_range=iter_range) + + with_dask_dmatrix = xgb.dask.predict( + client, booster, dXy, iteration_range=iter_range + ) + with_dask_collection = xgb.dask.predict( + client, booster, X, iteration_range=iter_range + ) + with_inplace = xgb.dask.inplace_predict( + client, booster, X, iteration_range=iter_range + ) + np.testing.assert_allclose(native_predt, with_dask_dmatrix.compute()) + np.testing.assert_allclose(native_predt, with_dask_collection.compute()) + np.testing.assert_allclose(native_predt, with_inplace.compute()) + + full_predt = xgb.dask.predict(client, booster, X, iteration_range=(0, n_rounds)) + default = xgb.dask.predict(client, booster, X) + np.testing.assert_allclose(full_predt.compute(), default.compute()) + + class TestWithDask: @pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)]) def test_global_config(