Fix mypy error with the latest dask. (#8052)

* Fix mypy error with latest dask.

Dask is adding type hints to its codebase and, as a result, checks in XGBoost can be
performed more rigorously.

- Remove the compatibility shim for old dask versions where `MultiLock` was missing.
- Restrict the input `X` to be non-series (see the sketch after this list).
- Adopt the latest definition of `Delayed`.
- Avoid passing an optional `host_ip`.
- Avoid the deprecated `worker.nthreads`.
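For illustration, a minimal sketch of the narrowed predictor alias. Only the
`_DaskCollection` and `_DataT` names come from the module; the `fit` stub is
hypothetical:

    from typing import Union

    import dask.array as da
    import dask.dataframe as dd

    # General alias: series are still fine for labels, weights, margins.
    _DaskCollection = Union[da.Array, dd.DataFrame, dd.Series]
    # Predictor alias: mypy now rejects a 1-dim dd.Series passed as `X`.
    _DataT = Union[da.Array, dd.DataFrame]  # do not use series as predictor

    def fit(X: _DataT, y: _DaskCollection) -> None:
        ...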
Jiaming Yuan 2022-07-09 08:02:42 +08:00 committed by GitHub
parent 210eb471e9
commit a5bc8e2c6a

@@ -68,20 +68,20 @@ from .sklearn import xgboost_model_doc
 from .sklearn import _cls_predict_proba
 from .sklearn import XGBRanker

 if TYPE_CHECKING:
     from dask import dataframe as dd
     from dask import array as da
-    from dask import delayed as ddelayed
     import dask
     import distributed
 else:
     dd = LazyLoader("dd", globals(), "dask.dataframe")
     da = LazyLoader("da", globals(), "dask.array")
-    ddelayed = LazyLoader("Delayed", globals(), "dask.delayed")
     dask = LazyLoader("dask", globals(), "dask")
     distributed = LazyLoader("distributed", globals(), "dask.distributed")

 _DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
+_DataT = Union["da.Array", "dd.DataFrame"]  # do not use series as predictor

 try:
     from mypy_extensions import TypedDict
@@ -123,8 +123,8 @@ __all__ = [
 # - Write everything with async, then use distributed Client sync function to do the
 #   switch.
 # - Use Any for type hint when the return value can be union of Awaitable and plain
-#   value. This is caused by Client.sync can return both types depending on context.
-#   Right now there's no good way to silent:
+#   value. This is caused by Client.sync can return both types depending on
+#   context. Right now there's no good way to silent:
 #
 #     await train(...)
 #
@@ -134,34 +134,6 @@ __all__ = [
 LOGGER = logging.getLogger("[xgboost.dask]")

-def _multi_lock() -> Any:
-    """MultiLock is only available on latest distributed. See:
-
-    https://github.com/dask/distributed/pull/4503
-
-    """
-    try:
-        from distributed import MultiLock
-    except ImportError:
-
-        class MultiLock:  # type:ignore
-            def __init__(self, *args: Any, **kwargs: Any) -> None:
-                pass
-
-            def __enter__(self) -> "MultiLock":
-                return self
-
-            def __exit__(self, *args: Any, **kwargs: Any) -> None:
-                return
-
-            async def __aenter__(self) -> "MultiLock":
-                return self
-
-            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
-                return
-
-    return MultiLock
-
 def _try_start_tracker(
     n_workers: int,
     addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
@@ -286,8 +258,8 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
 class DaskDMatrix:
     # pylint: disable=missing-docstring, too-many-instance-attributes
     """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
-    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
-    explicitly if you want to see actual computation of constructing `DaskDMatrix`.
+    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
+    data explicitly if you want to see actual computation of constructing `DaskDMatrix`.

     See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
     accepts only dask collection.
@@ -302,8 +274,8 @@ class DaskDMatrix:
     Parameters
     ----------
     client :
-        Specify the dask client used for training. Use default client returned from dask
-        if it's set to None.
+        Specify the dask client used for training. Use default client returned from
+        dask if it's set to None.

     """
@@ -311,7 +283,7 @@ class DaskDMatrix:
     def __init__(
         self,
         client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
         label: Optional[_DaskCollection] = None,
         *,
         weight: Optional[_DaskCollection] = None,
@@ -352,7 +324,7 @@ class DaskDMatrix:
         self._n_cols = data.shape[1]
         assert isinstance(self._n_cols, int)

-        self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
+        self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
         self.is_quantile: bool = False

         self._init = client.sync(
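A minimal standalone sketch of why this annotation changed: `defaultdict(list)`
maps each key to a *list* of futures, so the old `Dict[str, Future]` hint did not
match what the code stores (illustrative types only; `int` stands in for a future):

    from collections import defaultdict
    from typing import Dict, List

    worker_map: Dict[str, List[int]] = defaultdict(list)
    worker_map["tcp://worker-0"].append(42)  # value type must be a list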
@@ -374,7 +346,7 @@ class DaskDMatrix:
     async def _map_local_data(
         self,
         client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
         label: Optional[_DaskCollection] = None,
         weights: Optional[_DaskCollection] = None,
         base_margin: Optional[_DaskCollection] = None,
@@ -384,6 +356,7 @@ class DaskDMatrix:
         label_upper_bound: Optional[_DaskCollection] = None,
     ) -> "DaskDMatrix":
         """Obtain references to local data."""
+        from dask.delayed import Delayed

         def inconsistent(
             left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -404,7 +377,7 @@ class DaskDMatrix:
                     " chunks=(partition_size, X.shape[1])"
                 )

-        def to_delayed(d: _DaskCollection) -> List[ddelayed.Delayed]:
+        def to_delayed(d: _DaskCollection) -> List[Delayed]:
            """Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
            downgrades high-level objects into numpy or pandas equivalents.
            """
@@ -414,17 +387,15 @@ class DaskDMatrix:
             if isinstance(delayed_obj, numpy.ndarray):
                 # da.Array returns an array to delayed objects
                 check_columns(delayed_obj)
-                delayed_list: List[ddelayed.Delayed] = delayed_obj.flatten().tolist()
+                delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
             else:
                 # dd.DataFrame
                 delayed_list = delayed_obj
             return delayed_list

-        OpDelayed = TypeVar("OpDelayed", _DaskCollection, None)
-
-        def flatten_meta(meta: OpDelayed) -> OpDelayed:
+        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
             if meta is not None:
-                meta_parts: List[ddelayed.Delayed] = to_delayed(meta)
+                meta_parts: List[Delayed] = to_delayed(meta)
                 return meta_parts
             return None
@@ -436,9 +407,9 @@ class DaskDMatrix:
         ll_parts = flatten_meta(label_lower_bound)
         lu_parts = flatten_meta(label_upper_bound)

-        parts: Dict[str, List[ddelayed.Delayed]] = {"data": X_parts}
+        parts: Dict[str, List[Delayed]] = {"data": X_parts}

-        def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None:
+        def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
             if m_parts is not None:
                 assert len(X_parts) == len(m_parts), inconsistent(
                     X_parts, "X", m_parts, name
@@ -455,16 +426,16 @@ class DaskDMatrix:
         # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form

         # turn into list of dictionaries.
-        packed_parts: List[Dict[str, ddelayed.Delayed]] = []
+        packed_parts: List[Dict[str, Delayed]] = []
         for i in range(len(X_parts)):
-            part_dict: Dict[str, ddelayed.Delayed] = {}
+            part_dict: Dict[str, Delayed] = {}
             for key, value in parts.items():
                 part_dict[key] = value[i]
             packed_parts.append(part_dict)

         # delay the zipped result
         # pylint: disable=no-member
-        delayed_parts: List[ddelayed.Delayed] = list(map(dask.delayed, packed_parts))
+        delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))

         # At this point, the mental model should look like:
         # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
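A standalone illustration of the new annotation style (names other than `Delayed`
are made up): `dask.delayed.Delayed` is a concrete class usable directly in hints,
replacing the old `ddelayed.Delayed` module-attribute spelling:

    from typing import Dict, List

    import dask
    from dask.delayed import Delayed

    @dask.delayed
    def add(a: int, b: int) -> int:
        return a + b

    parts: List[Delayed] = [add(1, 2), add(3, 4)]  # Delayed as a plain class
    packed: Dict[str, Delayed] = {"s0": parts[0], "s1": parts[1]}
    print(dask.compute(*packed.values()))  # (3, 7)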
@@ -662,12 +633,12 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902

 class DaskDeviceQuantileDMatrix(DaskDMatrix):
-    """Specialized data type for `gpu_hist` tree method. This class is used to reduce the
-    memory usage by eliminating data copies. Internally the all partitions/chunks of data
-    are merged by weighted GK sketching. So the number of partitions from dask may affect
-    training accuracy as GK generates bounded error for each merge. See doc string for
-    :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
-    parameters.
+    """Specialized data type for `gpu_hist` tree method. This class is used to reduce
+    the memory usage by eliminating data copies. Internally the all partitions/chunks
+    of data are merged by weighted GK sketching. So the number of partitions from dask
+    may affect training accuracy as GK generates bounded error for each merge. See doc
+    string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
+    other parameters.

     .. versionadded:: 1.2.0

@@ -681,7 +652,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     def __init__(
         self,
         client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
         label: Optional[_DaskCollection] = None,
         *,
         weight: Optional[_DaskCollection] = None,
@@ -845,10 +816,11 @@ async def _get_rabit_args(
         if k not in valid_config:
             raise ValueError(f"Unknown configuration: {k}")
     host_ip = dconfig.get("scheduler_address", None)
-    try:
-        host_ip, port = distributed.comm.get_address_host_port(host_ip)
-    except ValueError:
-        pass
+    if host_ip is not None:
+        try:
+            host_ip, port = distributed.comm.get_address_host_port(host_ip)
+        except ValueError:
+            pass

     if host_ip is not None:
         user_addr = (host_ip, port)
     else:
@@ -932,16 +904,18 @@ async def _train_async(
         worker = distributed.get_worker()
         local_param = parameters.copy()
         n_threads = 0
+        # dask worker nthreads, "state" is available in 2022.6.1
+        dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
         for p in ["nthread", "n_jobs"]:
             if (
                 local_param.get(p, None) is not None
-                and local_param.get(p, worker.nthreads) != worker.nthreads
+                and local_param.get(p, dwnt) != dwnt
             ):
                 LOGGER.info("Overriding `nthreads` defined in dask worker.")
                 n_threads = local_param[p]
                 break
         if n_threads == 0 or n_threads is None:
-            n_threads = worker.nthreads
+            n_threads = dwnt
         local_param.update({"nthread": n_threads, "n_jobs": n_threads})
         local_history: TrainingCallback.EvalsLog = {}
         with RabitContext(rabit_args), config.config_context(**global_config):
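The same version-tolerant pattern in isolation, as a sketch: distributed 2022.6.1
moved `nthreads` onto `worker.state`, and `hasattr` keeps the code working on both
sides of that move. The `OldWorker`/`NewWorker` classes below are stand-ins for a
real `distributed.Worker`:

    class OldWorker:
        nthreads = 4  # pre-2022.6.1 layout

    class WorkerState:
        nthreads = 8

    class NewWorker:
        state = WorkerState()  # 2022.6.1+ layout

    def worker_nthreads(worker) -> int:
        # Prefer the new location, fall back to the deprecated attribute.
        return worker.state.nthreads if hasattr(worker, "state") else worker.nthreads

    print(worker_nthreads(OldWorker()), worker_nthreads(NewWorker()))  # 4 8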
@@ -977,7 +951,7 @@ async def _train_async(
             ret = None
         return ret

-    async with _multi_lock()(workers, client):
+    async with distributed.MultiLock(workers, client):
         if evals is not None:
             evals_data = [d for d, n in evals]
             evals_name = [n for d, n in evals]
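With the shim gone, `distributed.MultiLock` is used directly. A hedged usage
sketch (the in-process cluster and lock names are illustrative only; in XGBoost
the lock names are worker addresses):

    import asyncio
    from distributed import Client, MultiLock

    async def main() -> None:
        async with Client(processes=False, asynchronous=True) as client:
            # Acquire all named locks atomically; release on exit.
            async with MultiLock(["worker-0", "worker-1"], client):
                pass  # critical section, e.g. one training job at a time

    asyncio.run(main())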
@@ -1030,8 +1004,8 @@ def train(  # pylint: disable=unused-argument
     Parameters
     ----------
     client :
-        Specify the dask client used for training. Use default client returned from dask
-        if it's set to None.
+        Specify the dask client used for training. Use default client returned from
+        dask if it's set to None.

     Returns
     -------
@@ -1068,8 +1042,8 @@ def _maybe_dataframe(
     if _can_output_df(is_df, prediction.shape):
         # Need to preserve the index for dataframe.
         # See issue: https://github.com/dmlc/xgboost/issues/6939
-        # In older versions of dask, the partition is actually a numpy array when input is
-        # dataframe.
+        # In older versions of dask, the partition is actually a numpy array when input
+        # is dataframe.
         index = getattr(data, "index", None)
         if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
             import cudf
@@ -1093,7 +1067,7 @@ def _maybe_dataframe(
 async def _direct_predict_impl(  # pylint: disable=too-many-branches
     mapped_predict: Callable,
     booster: "distributed.Future",
-    data: _DaskCollection,
+    data: _DataT,
     base_margin: Optional[_DaskCollection],
     output_shape: Tuple[int, ...],
     meta: Dict[int, str],
@@ -1111,7 +1085,9 @@ async def _direct_predict_impl(  # pylint: disable=too-many-branches
     if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
         if base_margin is not None and isinstance(base_margin, da.Array):
             # Easier for map_partitions
-            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
+            base_margin_df: Optional[
+                Union[dd.DataFrame, dd.Series]
+            ] = base_margin.to_dask_dataframe()
         else:
             base_margin_df = base_margin
         predictions = dd.map_partitions(
@@ -1149,6 +1125,9 @@ async def _direct_predict_impl(  # pylint: disable=too-many-branches
         # Somehow dask fail to infer output shape change for 2-dim prediction, and
         # `chunks = (None, output_shape[1])` doesn't work due to None is not
         # supported in map_blocks.
+        # data must be an array here as dataframe + 2-dim output predict will return
+        # a dataframe instead.
         chunks: Optional[List[Tuple]] = list(data.chunks)
         assert isinstance(chunks, list)
         chunks[1] = (output_shape[1],)
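A self-contained sketch of why explicit chunks are needed here (array shapes and
the column-doubling function are made up): when a blockwise function changes the
second dimension, `map_blocks` cannot infer the new chunking, so it is spelled out:

    import dask.array as da
    import numpy as np

    x = da.ones((6, 3), chunks=(2, 3))  # second dim will change below

    def widen(block: np.ndarray) -> np.ndarray:
        return np.hstack([block, block])  # 3 columns -> 6 columns

    chunks = list(x.chunks)
    chunks[1] = (6,)  # tell dask the new width explicitly
    y = x.map_blocks(widen, chunks=tuple(chunks), dtype=x.dtype)
    print(y.shape)  # (6, 6)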
@@ -1200,9 +1179,10 @@ async def _get_model_future(
         booster = await client.scatter(model["booster"], broadcast=True)
     elif isinstance(model, distributed.Future):
         booster = model
-        if booster.type is not Booster:
+        t = booster.type
+        if t is not Booster:
             raise TypeError(
-                f"Underlying type of model future should be `Booster`, got {booster.type}"
+                f"Underlying type of model future should be `Booster`, got {t}"
             )
     else:
         raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
@@ -1214,7 +1194,7 @@ async def _predict_async(
     client: "distributed.Client",
     global_config: Dict[str, Any],
     model: Union[Booster, Dict, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
     output_margin: bool,
     missing: float,
     pred_leaf: bool,
@@ -1236,7 +1216,7 @@ async def _predict_async(
             m = DMatrix(
                 data=partition,
                 missing=missing,
-                enable_categorical=_has_categorical(booster, partition)
+                enable_categorical=_has_categorical(booster, partition),
             )
             predt = booster.predict(
                 data=m,
@@ -1358,9 +1338,9 @@ async def _predict_async(

 def predict(  # pylint: disable=unused-argument
-    client: "distributed.Client",
+    client: Optional["distributed.Client"],
     model: Union[TrainReturnT, Booster, "distributed.Future"],
-    data: Union[DaskDMatrix, _DaskCollection],
+    data: Union[DaskDMatrix, _DataT],
     output_margin: bool = False,
     missing: float = numpy.nan,
     pred_leaf: bool = False,
@@ -1375,10 +1355,10 @@ def predict(  # pylint: disable=unused-argument

     .. note::

-        Using ``inplace_predict`` might be faster when some features are not needed. See
-        :py:meth:`xgboost.Booster.predict` for details on various parameters. When output
-        has more than 2 dimensions (shap value, leaf with strict_shape), input should be
-        ``da.Array`` or ``DaskDMatrix``.
+        Using ``inplace_predict`` might be faster when some features are not needed.
+        See :py:meth:`xgboost.Booster.predict` for details on various parameters. When
+        output has more than 2 dimensions (shap value, leaf with strict_shape), input
+        should be ``da.Array`` or ``DaskDMatrix``.

     .. versionadded:: 1.0.0
@@ -1400,8 +1380,8 @@ def predict(  # pylint: disable=unused-argument
     Returns
     -------
     prediction: dask.array.Array/dask.dataframe.Series
-        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
-        array, when input data is ``dask.dataframe.DataFrame``, return value can be
+        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
+        an array, when input data is ``dask.dataframe.DataFrame``, return value can be
         ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
         shape.
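With `client` now typed `Optional`, passing `None` resolves to the default client.
An end-to-end sketch, hedged: cluster size, data shapes, and parameters below are
arbitrary:

    from dask import array as da
    from distributed import Client
    import xgboost as xgb

    if __name__ == "__main__":
        with Client(n_workers=2) as client:
            X = da.random.random((1000, 10), chunks=(250, 10))
            y = da.random.random(1000, chunks=(250,))
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            output = xgb.dask.train(
                client, {"tree_method": "hist"}, dtrain, num_boost_round=4
            )
            # `None` picks up the current default client.
            predt = xgb.dask.predict(None, output, X)
            print(predt.compute()[:5])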
@@ -1415,7 +1395,7 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches
     client: "distributed.Client",
     global_config: Dict[str, Any],
     model: Union[Booster, Dict, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
     iteration_range: Tuple[int, int],
     predict_type: str,
     missing: float,
@@ -1471,9 +1451,9 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches

 def inplace_predict(  # pylint: disable=unused-argument
-    client: "distributed.Client",
+    client: Optional["distributed.Client"],
     model: Union[TrainReturnT, Booster, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
     iteration_range: Tuple[int, int] = (0, 0),
     predict_type: str = "value",
     missing: float = numpy.nan,
@@ -1481,7 +1461,8 @@ def inplace_predict(  # pylint: disable=unused-argument
     base_margin: Optional[_DaskCollection] = None,
     strict_shape: bool = False,
 ) -> Any:
-    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
+    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for
+    details.

     .. versionadded:: 1.1.0
@@ -1514,8 +1495,8 @@ def inplace_predict(  # pylint: disable=unused-argument
     Returns
     -------
     prediction :
-        When input data is ``dask.array.Array``, the return value is an array, when input
-        data is ``dask.dataframe.DataFrame``, return value can be
+        When input data is ``dask.array.Array``, the return value is an array, when
+        input data is ``dask.dataframe.DataFrame``, return value can be
         ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
         shape.
@@ -1531,7 +1512,7 @@ def inplace_predict(  # pylint: disable=unused-argument

 async def _async_wrap_evaluation_matrices(
-    client: "distributed.Client", **kwargs: Any
+    client: Optional["distributed.Client"], **kwargs: Any
 ) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
     """A switch function for async environment."""
@@ -1561,7 +1542,7 @@ def _set_worker_client(
         model.client = client
         yield model
     finally:
-        model.client = None
+        model.client = None  # type:ignore

 class DaskScikitLearnBase(XGBModel):
@@ -1571,7 +1552,7 @@ class DaskScikitLearnBase(XGBModel):
     async def _predict_async(
         self,
-        data: _DaskCollection,
+        data: _DataT,
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
@@ -1597,7 +1578,7 @@ class DaskScikitLearnBase(XGBModel):
                 data=data,
                 base_margin=base_margin,
                 missing=self.missing,
-                feature_types=self.feature_types
+                feature_types=self.feature_types,
             )
             predts = await predict(
                 self.client,
@@ -1611,7 +1592,7 @@ class DaskScikitLearnBase(XGBModel):
     def predict(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         output_margin: bool = False,
         ntree_limit: Optional[int] = None,
         validate_features: bool = True,
@@ -1632,12 +1613,15 @@ class DaskScikitLearnBase(XGBModel):
     async def _apply_async(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
         test_dmatrix = await DaskDMatrix(
-            self.client, data=X, missing=self.missing, feature_types=self.feature_types,
+            self.client,
+            data=X,
+            missing=self.missing,
+            feature_types=self.feature_types,
         )
         predts = await predict(
             self.client,
@@ -1650,7 +1634,7 @@ class DaskScikitLearnBase(XGBModel):
     def apply(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         ntree_limit: Optional[int] = None,
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> Any:
@@ -1685,8 +1669,8 @@ class DaskScikitLearnBase(XGBModel):
     @client.setter
     def client(self, clt: "distributed.Client") -> None:
-        # calling `worker_client' doesn't return the correct `asynchronous` attribute, so
-        # we have to pass it ourselves.
+        # calling `worker_client' doesn't return the correct `asynchronous` attribute,
+        # so we have to pass it ourselves.
         self._asynchronous = clt.asynchronous if clt is not None else False
         self._client = clt
@@ -1720,9 +1704,10 @@ class DaskScikitLearnBase(XGBModel):
 )
 class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     """dummy doc string to workaround pylint, replaced by the decorator."""
+
     async def _fit_async(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
@@ -1789,7 +1774,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
     @_deprecate_positional_args
     def fit(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         *,
         sample_weight: Optional[_DaskCollection] = None,
@@ -1817,7 +1802,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=missing-class-docstring
     async def _fit_async(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         sample_weight: Optional[_DaskCollection],
         base_margin: Optional[_DaskCollection],
@@ -1898,7 +1883,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     # pylint: disable=unused-argument
     def fit(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         *,
         sample_weight: Optional[_DaskCollection] = None,
@@ -1919,7 +1904,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     async def _predict_proba_async(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
         iteration_range: Optional[Tuple[int, int]],
@@ -1965,7 +1950,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
     async def _predict_async(
         self,
-        data: _DaskCollection,
+        data: _DataT,
         output_margin: bool,
         validate_features: bool,
         base_margin: Optional[_DaskCollection],
@@ -2014,7 +1999,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     async def _fit_async(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         group: Optional[_DaskCollection],
         qid: Optional[_DaskCollection],
@@ -2090,7 +2075,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
     @_deprecate_positional_args
     def fit(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         *,
         group: Optional[_DaskCollection] = None,
@@ -2113,7 +2098,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
         return self._client_sync(self._fit_async, **args)

-    # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
+    # FIXME(trivialfis): arguments differ due to additional parameters like group and
+    # qid.
     fit.__doc__ = XGBRanker.fit.__doc__
@@ -2159,7 +2145,7 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
     # pylint: disable=unused-argument
     def fit(
         self,
-        X: _DaskCollection,
+        X: _DataT,
         y: _DaskCollection,
         *,
         sample_weight: Optional[_DaskCollection] = None,
@@ -2223,7 +2209,7 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
    # pylint: disable=unused-argument
    def fit(
        self,
-       X: _DaskCollection,
+       X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,