Fix dask predict on DaskDMatrix with iteration_range. (#7005)
This commit is contained in:
parent
4cf95a6041
commit
89a49cf30e
@ -1217,6 +1217,8 @@ async def _predict_async(
|
|||||||
approx_contribs=approx_contribs,
|
approx_contribs=approx_contribs,
|
||||||
pred_interactions=pred_interactions,
|
pred_interactions=pred_interactions,
|
||||||
validate_features=validate_features,
|
validate_features=validate_features,
|
||||||
|
iteration_range=iteration_range,
|
||||||
|
strict_shape=strict_shape,
|
||||||
)
|
)
|
||||||
return predt
|
return predt
|
||||||
|
|
||||||
|
|||||||
@ -952,6 +952,39 @@ def test_dask_predict_leaf(booster: str, client: "Client") -> None:
|
|||||||
verify_leaf_output(leaf, num_parallel_tree)
|
verify_leaf_output(leaf, num_parallel_tree)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dask_iteration_range(client: "Client"):
|
||||||
|
X, y, _ = generate_array()
|
||||||
|
n_rounds = 10
|
||||||
|
|
||||||
|
Xy = xgb.DMatrix(X.compute(), y.compute())
|
||||||
|
|
||||||
|
dXy = xgb.dask.DaskDMatrix(client, X, y)
|
||||||
|
booster = xgb.dask.train(
|
||||||
|
client, {"tree_method": "hist"}, dXy, num_boost_round=n_rounds
|
||||||
|
)["booster"]
|
||||||
|
|
||||||
|
for i in range(0, n_rounds):
|
||||||
|
iter_range = (0, i)
|
||||||
|
native_predt = booster.predict(Xy, iteration_range=iter_range)
|
||||||
|
|
||||||
|
with_dask_dmatrix = xgb.dask.predict(
|
||||||
|
client, booster, dXy, iteration_range=iter_range
|
||||||
|
)
|
||||||
|
with_dask_collection = xgb.dask.predict(
|
||||||
|
client, booster, X, iteration_range=iter_range
|
||||||
|
)
|
||||||
|
with_inplace = xgb.dask.inplace_predict(
|
||||||
|
client, booster, X, iteration_range=iter_range
|
||||||
|
)
|
||||||
|
np.testing.assert_allclose(native_predt, with_dask_dmatrix.compute())
|
||||||
|
np.testing.assert_allclose(native_predt, with_dask_collection.compute())
|
||||||
|
np.testing.assert_allclose(native_predt, with_inplace.compute())
|
||||||
|
|
||||||
|
full_predt = xgb.dask.predict(client, booster, X, iteration_range=(0, n_rounds))
|
||||||
|
default = xgb.dask.predict(client, booster, X)
|
||||||
|
np.testing.assert_allclose(full_predt.compute(), default.compute())
|
||||||
|
|
||||||
|
|
||||||
class TestWithDask:
|
class TestWithDask:
|
||||||
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
@pytest.mark.parametrize('config_key,config_value', [('verbosity', 0), ('use_rmm', True)])
|
||||||
def test_global_config(
|
def test_global_config(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user