[dask] Disable broadcast in the scatter call. (#10632)
This commit is contained in:
parent
411c8466bd
commit
fcae6301ec
@ -1223,12 +1223,14 @@ def _infer_predict_output(
|
|||||||
async def _get_model_future(
|
async def _get_model_future(
|
||||||
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
|
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
|
||||||
) -> "distributed.Future":
|
) -> "distributed.Future":
|
||||||
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for
|
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for the use
|
||||||
# the use of hash.
|
# of hash.
|
||||||
|
# https://github.com/dask/distributed/pull/8796 Don't use broadcast in the `scatter`
|
||||||
|
# call, otherwise, the predict function might hang.
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = await client.scatter(model, broadcast=True, hash=False)
|
booster = await client.scatter(model, hash=False)
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
booster = await client.scatter(model["booster"], broadcast=True, hash=False)
|
booster = await client.scatter(model["booster"], hash=False)
|
||||||
elif isinstance(model, distributed.Future):
|
elif isinstance(model, distributed.Future):
|
||||||
booster = model
|
booster = model
|
||||||
t = booster.type
|
t = booster.type
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user