Fix mypy error with the latest dask. (#8052)
* Fix mypy error with the latest dask. Dask is adding type hints to its codebase and, as a result, checks in XGBoost can be performed more rigorously.

- Remove compatibility with old dask versions where `MultiLock` was missing.
- Restrict the input `X` to be non-series.
- Adopt the latest definition of `Delayed`.
- Avoid passing an optional `host_ip`.
- Avoid the deprecated `worker.nthreads`.
This commit is contained in:
parent 210eb471e9
commit a5bc8e2c6a
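
For orientation, a minimal standalone sketch (not part of the diff) of the two aliases the changes below revolve around; the module itself spells the unions with string forward references because dask is lazily loaded:

    from typing import Union

    from dask import array as da
    from dask import dataframe as dd

    # Anything dask-shaped: fine for labels, weights, margins.
    _DaskCollection = Union[da.Array, dd.DataFrame, dd.Series]
    # Predictor X only: a one-dimensional series has no feature columns.
    _DataT = Union[da.Array, dd.DataFrame]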
@@ -68,20 +68,20 @@ from .sklearn import xgboost_model_doc
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker

if TYPE_CHECKING:
    from dask import dataframe as dd
    from dask import array as da
-    from dask import delayed as ddelayed
    import dask
    import distributed
else:
    dd = LazyLoader("dd", globals(), "dask.dataframe")
    da = LazyLoader("da", globals(), "dask.array")
-    ddelayed = LazyLoader("Delayed", globals(), "dask.delayed")
    dask = LazyLoader("dask", globals(), "dask")
    distributed = LazyLoader("distributed", globals(), "dask.distributed")

_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
+_DataT = Union["da.Array", "dd.DataFrame"]  # do not use series as predictor

try:
    from mypy_extensions import TypedDict
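The `LazyLoader` calls above defer the heavy dask imports until first use. A minimal sketch of the pattern (XGBoost's actual helper lives in `xgboost.compat` and may differ in detail):

    import importlib
    import types


    class LazyLoader(types.ModuleType):
        """Import the named module only on first attribute access."""

        def __init__(self, local_name: str, parent_globals: dict, name: str) -> None:
            self._local_name = local_name
            self._parent_globals = parent_globals
            super().__init__(name)

        def _load(self) -> types.ModuleType:
            module = importlib.import_module(self.__name__)
            # Replace the placeholder so later lookups hit the real module.
            self._parent_globals[self._local_name] = module
            self.__dict__.update(module.__dict__)
            return module

        def __getattr__(self, item: str):
            return getattr(self._load(), item)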
@@ -123,8 +123,8 @@ __all__ = [
# - Write everything with async, then use distributed Client sync function to do the
#   switch.
# - Use Any for type hint when the return value can be union of Awaitable and plain
-#   value. This is caused by Client.sync can return both types depending on context.
-#   Right now there's no good way to silent:
+#   value. This is caused by Client.sync can return both types depending on
+#   context. Right now there's no good way to silent:
#
#   await train(...)
#
@@ -134,34 +134,6 @@ __all__ = [
LOGGER = logging.getLogger("[xgboost.dask]")


-def _multi_lock() -> Any:
-    """MultiLock is only available on latest distributed. See:
-
-    https://github.com/dask/distributed/pull/4503
-    """
-    try:
-        from distributed import MultiLock
-    except ImportError:
-
-        class MultiLock:  # type:ignore
-            def __init__(self, *args: Any, **kwargs: Any) -> None:
-                pass
-
-            def __enter__(self) -> "MultiLock":
-                return self
-
-            def __exit__(self, *args: Any, **kwargs: Any) -> None:
-                return
-
-            async def __aenter__(self) -> "MultiLock":
-                return self
-
-            async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
-                return
-
-    return MultiLock
-
-
def _try_start_tracker(
    n_workers: int,
    addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
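With the shim gone, XGBoost simply requires a distributed release that ships `MultiLock` (added in the PR linked above). A hypothetical usage sketch; `reserve_and_train` and its arguments are made up:

    from typing import Awaitable, Callable, List

    from distributed import Client, MultiLock


    async def reserve_and_train(
        client: Client, workers: List[str], job: Callable[[], Awaitable[None]]
    ) -> None:
        # One named lock per worker address: concurrent training jobs cannot
        # grab overlapping sets of workers.
        async with MultiLock(names=workers, client=client):
            await job()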
@@ -286,8 +258,8 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Client":
class DaskDMatrix:
    # pylint: disable=missing-docstring, too-many-instance-attributes
    """DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
-    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
-    explicitly if you want to see actual computation of constructing `DaskDMatrix`.
+    `DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
+    data explicitly if you want to see actual computation of constructing `DaskDMatrix`.

    See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
    accepts only dask collection.
@@ -302,8 +274,8 @@ class DaskDMatrix:
    Parameters
    ----------
    client :
-        Specify the dask client used for training. Use default client returned from dask
-        if it's set to None.
+        Specify the dask client used for training. Use default client returned from
+        dask if it's set to None.

    """
@@ -311,7 +283,7 @@ class DaskDMatrix:
    def __init__(
        self,
        client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
@@ -352,7 +324,7 @@ class DaskDMatrix:

        self._n_cols = data.shape[1]
        assert isinstance(self._n_cols, int)
-        self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
+        self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
        self.is_quantile: bool = False

        self._init = client.sync(
@@ -374,7 +346,7 @@ class DaskDMatrix:
    async def _map_local_data(
        self,
        client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
        label: Optional[_DaskCollection] = None,
        weights: Optional[_DaskCollection] = None,
        base_margin: Optional[_DaskCollection] = None,
@@ -384,6 +356,7 @@ class DaskDMatrix:
        label_upper_bound: Optional[_DaskCollection] = None,
    ) -> "DaskDMatrix":
        """Obtain references to local data."""
+        from dask.delayed import Delayed

        def inconsistent(
            left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -404,7 +377,7 @@ class DaskDMatrix:
                "  chunks=(partition_size, X.shape[1])"
            )

-        def to_delayed(d: _DaskCollection) -> List[ddelayed.Delayed]:
+        def to_delayed(d: _DaskCollection) -> List[Delayed]:
            """Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
            downgrades high-level objects into numpy or pandas equivalents.
@@ -414,17 +387,15 @@ class DaskDMatrix:
            if isinstance(delayed_obj, numpy.ndarray):
                # da.Array returns an array to delayed objects
                check_columns(delayed_obj)
-                delayed_list: List[ddelayed.Delayed] = delayed_obj.flatten().tolist()
+                delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
            else:
                # dd.DataFrame
                delayed_list = delayed_obj
            return delayed_list

-        OpDelayed = TypeVar("OpDelayed", _DaskCollection, None)
-
-        def flatten_meta(meta: OpDelayed) -> OpDelayed:
+        def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
            if meta is not None:
-                meta_parts: List[ddelayed.Delayed] = to_delayed(meta)
+                meta_parts: List[Delayed] = to_delayed(meta)
                return meta_parts
            return None

@@ -436,9 +407,9 @@ class DaskDMatrix:
        ll_parts = flatten_meta(label_lower_bound)
        lu_parts = flatten_meta(label_upper_bound)

-        parts: Dict[str, List[ddelayed.Delayed]] = {"data": X_parts}
+        parts: Dict[str, List[Delayed]] = {"data": X_parts}

-        def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None:
+        def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
            if m_parts is not None:
                assert len(X_parts) == len(m_parts), inconsistent(
                    X_parts, "X", m_parts, name
@@ -455,16 +426,16 @@ class DaskDMatrix:
        # [(x0, x1, ..), (y0, y1, ..), ..] in delayed form

        # turn into list of dictionaries.
-        packed_parts: List[Dict[str, ddelayed.Delayed]] = []
+        packed_parts: List[Dict[str, Delayed]] = []
        for i in range(len(X_parts)):
-            part_dict: Dict[str, ddelayed.Delayed] = {}
+            part_dict: Dict[str, Delayed] = {}
            for key, value in parts.items():
                part_dict[key] = value[i]
            packed_parts.append(part_dict)

        # delay the zipped result
        # pylint: disable=no-member
-        delayed_parts: List[ddelayed.Delayed] = list(map(dask.delayed, packed_parts))
+        delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
        # At this point, the mental model should look like:
        # [(x0, y0, ..), (x1, y1, ..), ..] in delayed form

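`Delayed` is now imported from `dask.delayed` directly instead of going through the removed `ddelayed` lazy loader. A small self-contained sketch of the two behaviours the code above relies on:

    import dask
    import numpy
    from dask import array as da
    from dask.delayed import Delayed

    arr = da.ones((4, 2), chunks=(2, 2))
    blocks = arr.to_delayed()          # numpy.ndarray of Delayed, one per chunk
    assert isinstance(blocks, numpy.ndarray)
    flat = blocks.flatten().tolist()   # the List[Delayed] built by to_delayed()
    assert all(isinstance(b, Delayed) for b in flat)

    packed = dask.delayed({"data": flat[0]})  # a packed dict becomes one Delayed
    assert isinstance(packed, Delayed)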
@@ -662,12 +633,12 @@ class DaskPartitionIter(DataIter):  # pylint: disable=R0902


class DaskDeviceQuantileDMatrix(DaskDMatrix):
-    """Specialized data type for `gpu_hist` tree method. This class is used to reduce the
-    memory usage by eliminating data copies. Internally the all partitions/chunks of data
-    are merged by weighted GK sketching. So the number of partitions from dask may affect
-    training accuracy as GK generates bounded error for each merge. See doc string for
-    :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
-    parameters.
+    """Specialized data type for `gpu_hist` tree method. This class is used to reduce
+    the memory usage by eliminating data copies. Internally the all partitions/chunks
+    of data are merged by weighted GK sketching. So the number of partitions from dask
+    may affect training accuracy as GK generates bounded error for each merge. See doc
+    string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
+    other parameters.

    .. versionadded:: 1.2.0

@@ -681,7 +652,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
    def __init__(
        self,
        client: "distributed.Client",
-        data: _DaskCollection,
+        data: _DataT,
        label: Optional[_DaskCollection] = None,
        *,
        weight: Optional[_DaskCollection] = None,
@@ -845,10 +816,11 @@ async def _get_rabit_args(
        if k not in valid_config:
            raise ValueError(f"Unknown configuration: {k}")
    host_ip = dconfig.get("scheduler_address", None)
-    try:
-        host_ip, port = distributed.comm.get_address_host_port(host_ip)
-    except ValueError:
-        pass
+    if host_ip is not None:
+        try:
+            host_ip, port = distributed.comm.get_address_host_port(host_ip)
+        except ValueError:
+            pass
    if host_ip is not None:
        user_addr = (host_ip, port)
    else:
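The new guard exists because the tightened signature of `get_address_host_port` expects an address string, so a missing `scheduler_address` must be filtered out before the call. An illustration with a made-up address:

    from distributed.comm import get_address_host_port

    host, port = get_address_host_port("tcp://192.168.1.10:8786")
    assert (host, port) == ("192.168.1.10", 8786)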
@@ -932,16 +904,18 @@ async def _train_async(
        worker = distributed.get_worker()
        local_param = parameters.copy()
        n_threads = 0
+        # dask worker nthreads, "state" is available in 2022.6.1
+        dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
        for p in ["nthread", "n_jobs"]:
            if (
                local_param.get(p, None) is not None
-                and local_param.get(p, worker.nthreads) != worker.nthreads
+                and local_param.get(p, dwnt) != dwnt
            ):
                LOGGER.info("Overriding `nthreads` defined in dask worker.")
                n_threads = local_param[p]
                break
        if n_threads == 0 or n_threads is None:
-            n_threads = worker.nthreads
+            n_threads = dwnt
        local_param.update({"nthread": n_threads, "n_jobs": n_threads})
        local_history: TrainingCallback.EvalsLog = {}
        with RabitContext(rabit_args), config.config_context(**global_config):
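The same fallback isolated as a helper, in case the intent is clearer that way; `worker.state.nthreads` is the spelling from distributed 2022.6.1 onward, while older workers only expose the deprecated `worker.nthreads`:

    from typing import Any


    def _worker_nthreads(worker: Any) -> int:
        # Prefer WorkerState.nthreads (distributed >= 2022.6.1); fall back
        # to the deprecated Worker.nthreads attribute on older releases.
        state = getattr(worker, "state", None)
        return state.nthreads if state is not None else worker.nthreads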
@@ -977,7 +951,7 @@ async def _train_async(
            ret = None
        return ret

-    async with _multi_lock()(workers, client):
+    async with distributed.MultiLock(workers, client):
        if evals is not None:
            evals_data = [d for d, n in evals]
            evals_name = [n for d, n in evals]
@@ -1030,8 +1004,8 @@ def train(  # pylint: disable=unused-argument
    Parameters
    ----------
    client :
-        Specify the dask client used for training. Use default client returned from dask
-        if it's set to None.
+        Specify the dask client used for training. Use default client returned from
+        dask if it's set to None.

    Returns
    -------
@@ -1068,8 +1042,8 @@ def _maybe_dataframe(
    if _can_output_df(is_df, prediction.shape):
        # Need to preserve the index for dataframe.
        # See issue: https://github.com/dmlc/xgboost/issues/6939
-        # In older versions of dask, the partition is actually a numpy array when input is
-        # dataframe.
+        # In older versions of dask, the partition is actually a numpy array when input
+        # is dataframe.
        index = getattr(data, "index", None)
        if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
            import cudf
@@ -1093,7 +1067,7 @@ def _maybe_dataframe(
async def _direct_predict_impl(  # pylint: disable=too-many-branches
    mapped_predict: Callable,
    booster: "distributed.Future",
-    data: _DaskCollection,
+    data: _DataT,
    base_margin: Optional[_DaskCollection],
    output_shape: Tuple[int, ...],
    meta: Dict[int, str],
@@ -1111,7 +1085,9 @@ async def _direct_predict_impl(  # pylint: disable=too-many-branches
    if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
        if base_margin is not None and isinstance(base_margin, da.Array):
            # Easier for map_partitions
-            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
+            base_margin_df: Optional[
+                Union[dd.DataFrame, dd.Series]
+            ] = base_margin.to_dask_dataframe()
        else:
            base_margin_df = base_margin
        predictions = dd.map_partitions(
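The widened annotation matches what `to_dask_dataframe` can actually return: a `dd.Series` for a one-dimensional array and a `dd.DataFrame` for two dimensions. A quick check:

    from dask import array as da
    from dask import dataframe as dd

    assert isinstance(da.zeros(8, chunks=4).to_dask_dataframe(), dd.Series)
    assert isinstance(da.zeros((8, 2), chunks=(4, 2)).to_dask_dataframe(), dd.DataFrame)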
@@ -1149,6 +1125,9 @@ async def _direct_predict_impl(  # pylint: disable=too-many-branches
        # Somehow dask fail to infer output shape change for 2-dim prediction, and
        # `chunks = (None, output_shape[1])` doesn't work due to None is not
        # supported in map_blocks.
+
+        # data must be an array here as dataframe + 2-dim output predict will return
+        # a dataframe instead.
        chunks: Optional[List[Tuple]] = list(data.chunks)
        assert isinstance(chunks, list)
        chunks[1] = (output_shape[1],)
@@ -1200,9 +1179,10 @@ async def _get_model_future(
        booster = await client.scatter(model["booster"], broadcast=True)
    elif isinstance(model, distributed.Future):
        booster = model
-        if booster.type is not Booster:
+        t = booster.type
+        if t is not Booster:
            raise TypeError(
-                f"Underlying type of model future should be `Booster`, got {booster.type}"
+                f"Underlying type of model future should be `Booster`, got {t}"
            )
    else:
        raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
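Binding `booster.type` to a local `t` keeps the lines short and gives the type checker a single name to reason about. For context, `Future.type` records the class of a scattered object without fetching it back; a hypothetical sketch:

    from distributed import Client, Future
    from xgboost import Booster


    async def model_future(client: Client, model: Booster) -> Future:
        fut = await client.scatter(model, broadcast=True)
        t = fut.type  # tracked in the future's metadata; no round-trip needed
        if t is not Booster:
            raise TypeError(f"expected Booster, got {t}")
        return fut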
@@ -1214,7 +1194,7 @@ async def _predict_async(
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
    output_margin: bool,
    missing: float,
    pred_leaf: bool,
@@ -1236,7 +1216,7 @@ async def _predict_async(
        m = DMatrix(
            data=partition,
            missing=missing,
-            enable_categorical=_has_categorical(booster, partition)
+            enable_categorical=_has_categorical(booster, partition),
        )
        predt = booster.predict(
            data=m,
@@ -1358,9 +1338,9 @@ async def _predict_async(


def predict(  # pylint: disable=unused-argument
-    client: "distributed.Client",
+    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
-    data: Union[DaskDMatrix, _DaskCollection],
+    data: Union[DaskDMatrix, _DataT],
    output_margin: bool = False,
    missing: float = numpy.nan,
    pred_leaf: bool = False,
@@ -1375,10 +1355,10 @@ def predict(  # pylint: disable=unused-argument

    .. note::

-        Using ``inplace_predict`` might be faster when some features are not needed. See
-        :py:meth:`xgboost.Booster.predict` for details on various parameters. When output
-        has more than 2 dimensions (shap value, leaf with strict_shape), input should be
-        ``da.Array`` or ``DaskDMatrix``.
+        Using ``inplace_predict`` might be faster when some features are not needed.
+        See :py:meth:`xgboost.Booster.predict` for details on various parameters. When
+        output has more than 2 dimensions (shap value, leaf with strict_shape), input
+        should be ``da.Array`` or ``DaskDMatrix``.

    .. versionadded:: 1.0.0

@@ -1400,8 +1380,8 @@ def predict(  # pylint: disable=unused-argument
    Returns
    -------
    prediction: dask.array.Array/dask.dataframe.Series
-        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
-        array, when input data is ``dask.dataframe.DataFrame``, return value can be
+        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
+        an array, when input data is ``dask.dataframe.DataFrame``, return value can be
        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
        shape.

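Read together with the signature change above: `data` must be an array, a dataframe, or a `DaskDMatrix`, never a series. A hypothetical end-to-end sketch (cluster size and parameters are made up):

    import xgboost as xgb
    from dask import array as da
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2, threads_per_worker=1) as cluster:
        with Client(cluster) as client:
            X = da.random.random((1000, 10), chunks=(100, 10))  # valid _DataT
            y = da.random.random(1000, chunks=100)              # labels may be 1-dim
            dtrain = xgb.dask.DaskDMatrix(client, X, y)
            output = xgb.dask.train(
                client, {"tree_method": "hist"}, dtrain, num_boost_round=4
            )
            predt = xgb.dask.predict(client, output, X)  # dask array of predictions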
@@ -1415,7 +1395,7 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches
    client: "distributed.Client",
    global_config: Dict[str, Any],
    model: Union[Booster, Dict, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
    iteration_range: Tuple[int, int],
    predict_type: str,
    missing: float,
@@ -1471,9 +1451,9 @@ async def _inplace_predict_async(  # pylint: disable=too-many-branches


def inplace_predict(  # pylint: disable=unused-argument
-    client: "distributed.Client",
+    client: Optional["distributed.Client"],
    model: Union[TrainReturnT, Booster, "distributed.Future"],
-    data: _DaskCollection,
+    data: _DataT,
    iteration_range: Tuple[int, int] = (0, 0),
    predict_type: str = "value",
    missing: float = numpy.nan,
@@ -1481,7 +1461,8 @@ def inplace_predict(  # pylint: disable=unused-argument
    base_margin: Optional[_DaskCollection] = None,
    strict_shape: bool = False,
) -> Any:
-    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
+    """Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for
+    details.

    .. versionadded:: 1.1.0

@@ -1514,8 +1495,8 @@ def inplace_predict(  # pylint: disable=unused-argument
    Returns
    -------
    prediction :
-        When input data is ``dask.array.Array``, the return value is an array, when input
-        data is ``dask.dataframe.DataFrame``, return value can be
+        When input data is ``dask.array.Array``, the return value is an array, when
+        input data is ``dask.dataframe.DataFrame``, return value can be
        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
        shape.

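`inplace_predict` skips the `DMatrix` construction entirely and feeds partitions to the booster directly, which is the speed-up the note above refers to. A minimal hedged sketch:

    import xgboost as xgb
    from dask import array as da
    from distributed import Client


    def fast_predict(client: Client, booster: xgb.Booster, X: da.Array) -> da.Array:
        # No DaskDMatrix is built; chunks of X are scored in place.
        return xgb.dask.inplace_predict(client, booster, X)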
@@ -1531,7 +1512,7 @@ def inplace_predict(  # pylint: disable=unused-argument


async def _async_wrap_evaluation_matrices(
-    client: "distributed.Client", **kwargs: Any
+    client: Optional["distributed.Client"], **kwargs: Any
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
    """A switch function for async environment."""

@@ -1561,7 +1542,7 @@ def _set_worker_client(
        model.client = client
        yield model
    finally:
-        model.client = None
+        model.client = None  # type:ignore


class DaskScikitLearnBase(XGBModel):
@@ -1571,7 +1552,7 @@ class DaskScikitLearnBase(XGBModel):

    async def _predict_async(
        self,
-        data: _DaskCollection,
+        data: _DataT,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
@@ -1597,7 +1578,7 @@ class DaskScikitLearnBase(XGBModel):
            data=data,
            base_margin=base_margin,
            missing=self.missing,
-            feature_types=self.feature_types
+            feature_types=self.feature_types,
        )
        predts = await predict(
            self.client,
@@ -1611,7 +1592,7 @@ class DaskScikitLearnBase(XGBModel):

    def predict(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        output_margin: bool = False,
        ntree_limit: Optional[int] = None,
        validate_features: bool = True,
@@ -1632,12 +1613,15 @@ class DaskScikitLearnBase(XGBModel):

    async def _apply_async(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> Any:
        iteration_range = self._get_iteration_range(iteration_range)
        test_dmatrix = await DaskDMatrix(
-            self.client, data=X, missing=self.missing, feature_types=self.feature_types,
+            self.client,
+            data=X,
+            missing=self.missing,
+            feature_types=self.feature_types,
        )
        predts = await predict(
            self.client,
@@ -1650,7 +1634,7 @@ class DaskScikitLearnBase(XGBModel):

    def apply(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        ntree_limit: Optional[int] = None,
        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> Any:
@@ -1685,8 +1669,8 @@ class DaskScikitLearnBase(XGBModel):

    @client.setter
    def client(self, clt: "distributed.Client") -> None:
-        # calling `worker_client' doesn't return the correct `asynchronous` attribute, so
-        # we have to pass it ourselves.
+        # calling `worker_client' doesn't return the correct `asynchronous` attribute,
+        # so we have to pass it ourselves.
        self._asynchronous = clt.asynchronous if clt is not None else False
        self._client = clt
@@ -1720,9 +1704,10 @@ class DaskScikitLearnBase(XGBModel):
)
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
    """dummy doc string to workaround pylint, replaced by the decorator."""

    async def _fit_async(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
@@ -1789,7 +1774,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
    @_deprecate_positional_args
    def fit(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
@@ -1817,7 +1802,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
    # pylint: disable=missing-class-docstring
    async def _fit_async(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        sample_weight: Optional[_DaskCollection],
        base_margin: Optional[_DaskCollection],
@@ -1898,7 +1883,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
    # pylint: disable=unused-argument
    def fit(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
@@ -1919,7 +1904,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):

    async def _predict_proba_async(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
        iteration_range: Optional[Tuple[int, int]],
@@ -1965,7 +1950,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):

    async def _predict_async(
        self,
-        data: _DaskCollection,
+        data: _DataT,
        output_margin: bool,
        validate_features: bool,
        base_margin: Optional[_DaskCollection],
@@ -2014,7 +1999,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):

    async def _fit_async(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        group: Optional[_DaskCollection],
        qid: Optional[_DaskCollection],
@@ -2090,7 +2075,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
    @_deprecate_positional_args
    def fit(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        *,
        group: Optional[_DaskCollection] = None,
@@ -2113,7 +2098,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
        args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
        return self._client_sync(self._fit_async, **args)

-    # FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
+    # FIXME(trivialfis): arguments differ due to additional parameters like group and
+    # qid.
    fit.__doc__ = XGBRanker.fit.__doc__

@@ -2159,7 +2145,7 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
    # pylint: disable=unused-argument
    def fit(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,
@@ -2223,7 +2209,7 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
    # pylint: disable=unused-argument
    def fit(
        self,
-        X: _DaskCollection,
+        X: _DataT,
        y: _DaskCollection,
        *,
        sample_weight: Optional[_DaskCollection] = None,