From fcae6301eced44e77b72a8f8920906366f839c78 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 25 Jul 2024 04:16:34 +0800 Subject: [PATCH] [dask] Disable `broadcast` in the `scatter` call. (#10632) --- python-package/xgboost/dask/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index bd23df21b..099d28122 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -1223,12 +1223,14 @@ def _infer_predict_output( async def _get_model_future( client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"] ) -> "distributed.Future": - # See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for - # the use of hash. + # See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for the use + # of hash. + # https://github.com/dask/distributed/pull/8796 Don't use broadcast in the `scatter` + # call, otherwise, the predict function might hang. if isinstance(model, Booster): - booster = await client.scatter(model, broadcast=True, hash=False) + booster = await client.scatter(model, hash=False) elif isinstance(model, dict): - booster = await client.scatter(model["booster"], broadcast=True, hash=False) + booster = await client.scatter(model["booster"], hash=False) elif isinstance(model, distributed.Future): booster = model t = booster.type