[dask] Work around the tokenizer by changing the scatter function. (#10419)

---------

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2024-06-15 19:10:00 +08:00 committed by GitHub
parent 601f2067c7
commit bbff74d2ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 9 deletions

View File

@ -7,7 +7,6 @@ import json
import os import os
import re import re
import sys import sys
import uuid
import warnings import warnings
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -3144,9 +3143,3 @@ class Booster:
UserWarning, UserWarning,
) )
return nph_stacked return nph_stacked
def __dask_tokenize__(self) -> uuid.UUID:
# TODO: Implement proper tokenization to avoid unnecessary re-computation in
# Dask. However, default tokenization causes problems after
# https://github.com/dask/dask/pull/10883
return uuid.uuid4()

View File

@ -1237,10 +1237,12 @@ def _infer_predict_output(
async def _get_model_future( async def _get_model_future(
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"] client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
) -> "distributed.Future": ) -> "distributed.Future":
# See https://github.com/dask/dask/issues/11179#issuecomment-2168094529 for
# the use of hash.
if isinstance(model, Booster): if isinstance(model, Booster):
booster = await client.scatter(model, broadcast=True) booster = await client.scatter(model, broadcast=True, hash=False)
elif isinstance(model, dict): elif isinstance(model, dict):
booster = await client.scatter(model["booster"], broadcast=True) booster = await client.scatter(model["booster"], broadcast=True, hash=False)
elif isinstance(model, distributed.Future): elif isinstance(model, distributed.Future):
booster = model booster = model
t = booster.type t = booster.type