diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index a83079010..789a66e7f 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -184,6 +184,9 @@ def test_dask_predict_shape_infer(client: "Client") -> None: def run_boost_from_prediction( X: xgb.dask._DaskCollection, y: xgb.dask._DaskCollection, tree_method: str, client: "Client" ) -> None: + X = client.persist(X) + y = client.persist(y) + model_0 = xgb.dask.DaskXGBClassifier( learning_rate=0.3, random_state=0, n_estimators=4, tree_method=tree_method)