Fix mypy error with the latest dask. (#8052)

* Fix mypy error with latest dask.

Dask is adding type hints to its codebase and, as a result, checks in XGBoost can be
performed more rigorously.

- Remove compatibility with old dask versions where `MultiLock` was missing.
- Restrict input of `X` to be non-series.
- Adopt latest definition of `Delayed`.
- Avoid passing optional `host_ip`.
- Avoid deprecated `worker.nthreads`.
This commit is contained in:
Jiaming Yuan 2022-07-09 08:02:42 +08:00 committed by GitHub
parent 210eb471e9
commit a5bc8e2c6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -68,20 +68,20 @@ from .sklearn import xgboost_model_doc
from .sklearn import _cls_predict_proba
from .sklearn import XGBRanker
if TYPE_CHECKING:
from dask import dataframe as dd
from dask import array as da
from dask import delayed as ddelayed
import dask
import distributed
else:
dd = LazyLoader("dd", globals(), "dask.dataframe")
da = LazyLoader("da", globals(), "dask.array")
ddelayed = LazyLoader("Delayed", globals(), "dask.delayed")
dask = LazyLoader("dask", globals(), "dask")
distributed = LazyLoader("distributed", globals(), "dask.distributed")
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
try:
from mypy_extensions import TypedDict
@ -123,8 +123,8 @@ __all__ = [
# - Write everything with async, then use distributed Client sync function to do the
# switch.
# - Use Any for type hint when the return value can be union of Awaitable and plain
# value. This is caused by Client.sync can return both types depending on context.
# Right now there's no good way to silent:
# value. This is caused by Client.sync can return both types depending on
# context. Right now there's no good way to silent:
#
# await train(...)
#
@ -134,34 +134,6 @@ __all__ = [
LOGGER = logging.getLogger("[xgboost.dask]")
def _multi_lock() -> Any:
"""MultiLock is only available on latest distributed. See:
https://github.com/dask/distributed/pull/4503
"""
# Returns the MultiLock *class* (not an instance); the caller instantiates it.
try:
# Use the real implementation when the installed distributed provides it.
from distributed import MultiLock
except ImportError:
# Fallback stand-in with the same name: it accepts any constructor arguments
# and every method is a no-op, so no locking actually takes place.
class MultiLock: # type:ignore
def __init__(self, *args: Any, **kwargs: Any) -> None:
pass
# Synchronous context-manager protocol (`with ...`).
def __enter__(self) -> "MultiLock":
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
return
# Asynchronous context-manager protocol (`async with ...`).
async def __aenter__(self) -> "MultiLock":
return self
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
return
return MultiLock
def _try_start_tracker(
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
@ -286,8 +258,8 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
class DaskDMatrix:
# pylint: disable=missing-docstring, too-many-instance-attributes
"""DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
explicitly if you want to see actual computation of constructing `DaskDMatrix`.
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
accepts only dask collection.
@ -302,8 +274,8 @@ class DaskDMatrix:
Parameters
----------
client :
Specify the dask client used for training. Use default client returned from dask
if it's set to None.
Specify the dask client used for training. Use default client returned from
dask if it's set to None.
"""
@ -311,7 +283,7 @@ class DaskDMatrix:
def __init__(
self,
client: "distributed.Client",
data: _DaskCollection,
data: _DataT,
label: Optional[_DaskCollection] = None,
*,
weight: Optional[_DaskCollection] = None,
@ -352,7 +324,7 @@ class DaskDMatrix:
self._n_cols = data.shape[1]
assert isinstance(self._n_cols, int)
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
self.is_quantile: bool = False
self._init = client.sync(
@ -374,7 +346,7 @@ class DaskDMatrix:
async def _map_local_data(
self,
client: "distributed.Client",
data: _DaskCollection,
data: _DataT,
label: Optional[_DaskCollection] = None,
weights: Optional[_DaskCollection] = None,
base_margin: Optional[_DaskCollection] = None,
@ -384,6 +356,7 @@ class DaskDMatrix:
label_upper_bound: Optional[_DaskCollection] = None,
) -> "DaskDMatrix":
"""Obtain references to local data."""
from dask.delayed import Delayed
def inconsistent(
left: List[Any], left_name: str, right: List[Any], right_name: str
@ -404,7 +377,7 @@ class DaskDMatrix:
" chunks=(partition_size, X.shape[1])"
)
def to_delayed(d: _DaskCollection) -> List[ddelayed.Delayed]:
def to_delayed(d: _DaskCollection) -> List[Delayed]:
"""Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
downgrades high-level objects into numpy or pandas equivalents .
@ -414,17 +387,15 @@ class DaskDMatrix:
if isinstance(delayed_obj, numpy.ndarray):
# da.Array returns an array to delayed objects
check_columns(delayed_obj)
delayed_list: List[ddelayed.Delayed] = delayed_obj.flatten().tolist()
delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
else:
# dd.DataFrame
delayed_list = delayed_obj
return delayed_list
OpDelayed = TypeVar("OpDelayed", _DaskCollection, None)
def flatten_meta(meta: OpDelayed) -> OpDelayed:
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
if meta is not None:
meta_parts: List[ddelayed.Delayed] = to_delayed(meta)
meta_parts: List[Delayed] = to_delayed(meta)
return meta_parts
return None
@ -436,9 +407,9 @@ class DaskDMatrix:
ll_parts = flatten_meta(label_lower_bound)
lu_parts = flatten_meta(label_upper_bound)
parts: Dict[str, List[ddelayed.Delayed]] = {"data": X_parts}
parts: Dict[str, List[Delayed]] = {"data": X_parts}
def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None:
def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
if m_parts is not None:
assert len(X_parts) == len(m_parts), inconsistent(
X_parts, "X", m_parts, name
@ -455,16 +426,16 @@ class DaskDMatrix:
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
# turn into list of dictionaries.
packed_parts: List[Dict[str, ddelayed.Delayed]] = []
packed_parts: List[Dict[str, Delayed]] = []
for i in range(len(X_parts)):
part_dict: Dict[str, ddelayed.Delayed] = {}
part_dict: Dict[str, Delayed] = {}
for key, value in parts.items():
part_dict[key] = value[i]
packed_parts.append(part_dict)
# delay the zipped result
# pylint: disable=no-member
delayed_parts: List[ddelayed.Delayed] = list(map(dask.delayed, packed_parts))
delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
# At this point, the mental model should look like:
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
@ -662,12 +633,12 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
class DaskDeviceQuantileDMatrix(DaskDMatrix):
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce the
memory usage by eliminating data copies. Internally the all partitions/chunks of data
are merged by weighted GK sketching. So the number of partitions from dask may affect
training accuracy as GK generates bounded error for each merge. See doc string for
:py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
parameters.
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce
the memory usage by eliminating data copies. Internally the all partitions/chunks
of data are merged by weighted GK sketching. So the number of partitions from dask
may affect training accuracy as GK generates bounded error for each merge. See doc
string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
other parameters.
.. versionadded:: 1.2.0
@ -681,7 +652,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
def __init__(
self,
client: "distributed.Client",
data: _DaskCollection,
data: _DataT,
label: Optional[_DaskCollection] = None,
*,
weight: Optional[_DaskCollection] = None,
@ -845,6 +816,7 @@ async def _get_rabit_args(
if k not in valid_config:
raise ValueError(f"Unknown configuration: {k}")
host_ip = dconfig.get("scheduler_address", None)
if host_ip is not None:
try:
host_ip, port = distributed.comm.get_address_host_port(host_ip)
except ValueError:
@ -932,16 +904,18 @@ async def _train_async(
worker = distributed.get_worker()
local_param = parameters.copy()
n_threads = 0
# dask worker nthreads, "state" is available in 2022.6.1
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
for p in ["nthread", "n_jobs"]:
if (
local_param.get(p, None) is not None
and local_param.get(p, worker.nthreads) != worker.nthreads
and local_param.get(p, dwnt) != dwnt
):
LOGGER.info("Overriding `nthreads` defined in dask worker.")
n_threads = local_param[p]
break
if n_threads == 0 or n_threads is None:
n_threads = worker.nthreads
n_threads = dwnt
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
local_history: TrainingCallback.EvalsLog = {}
with RabitContext(rabit_args), config.config_context(**global_config):
@ -977,7 +951,7 @@ async def _train_async(
ret = None
return ret
async with _multi_lock()(workers, client):
async with distributed.MultiLock(workers, client):
if evals is not None:
evals_data = [d for d, n in evals]
evals_name = [n for d, n in evals]
@ -1030,8 +1004,8 @@ def train( # pylint: disable=unused-argument
Parameters
----------
client :
Specify the dask client used for training. Use default client returned from dask
if it's set to None.
Specify the dask client used for training. Use default client returned from
dask if it's set to None.
Returns
-------
@ -1068,8 +1042,8 @@ def _maybe_dataframe(
if _can_output_df(is_df, prediction.shape):
# Need to preserve the index for dataframe.
# See issue: https://github.com/dmlc/xgboost/issues/6939
# In older versions of dask, the partition is actually a numpy array when input is
# dataframe.
# In older versions of dask, the partition is actually a numpy array when input
# is dataframe.
index = getattr(data, "index", None)
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
import cudf
@ -1093,7 +1067,7 @@ def _maybe_dataframe(
async def _direct_predict_impl( # pylint: disable=too-many-branches
mapped_predict: Callable,
booster: "distributed.Future",
data: _DaskCollection,
data: _DataT,
base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...],
meta: Dict[int, str],
@ -1111,7 +1085,9 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
# Easier for map_partitions
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
base_margin_df: Optional[
Union[dd.DataFrame, dd.Series]
] = base_margin.to_dask_dataframe()
else:
base_margin_df = base_margin
predictions = dd.map_partitions(
@ -1149,6 +1125,9 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
# Somehow dask fail to infer output shape change for 2-dim prediction, and
# `chunks = (None, output_shape[1])` doesn't work due to None is not
# supported in map_blocks.
# data must be an array here as dataframe + 2-dim output predict will return
# a dataframe instead.
chunks: Optional[List[Tuple]] = list(data.chunks)
assert isinstance(chunks, list)
chunks[1] = (output_shape[1],)
@ -1200,9 +1179,10 @@ async def _get_model_future(
booster = await client.scatter(model["booster"], broadcast=True)
elif isinstance(model, distributed.Future):
booster = model
if booster.type is not Booster:
t = booster.type
if t is not Booster:
raise TypeError(
f"Underlying type of model future should be `Booster`, got {booster.type}"
f"Underlying type of model future should be `Booster`, got {t}"
)
else:
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
@ -1214,7 +1194,7 @@ async def _predict_async(
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
data: _DataT,
output_margin: bool,
missing: float,
pred_leaf: bool,
@ -1236,7 +1216,7 @@ async def _predict_async(
m = DMatrix(
data=partition,
missing=missing,
enable_categorical=_has_categorical(booster, partition)
enable_categorical=_has_categorical(booster, partition),
)
predt = booster.predict(
data=m,
@ -1358,9 +1338,9 @@ async def _predict_async(
def predict( # pylint: disable=unused-argument
client: "distributed.Client",
client: Optional["distributed.Client"],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: Union[DaskDMatrix, _DaskCollection],
data: Union[DaskDMatrix, _DataT],
output_margin: bool = False,
missing: float = numpy.nan,
pred_leaf: bool = False,
@ -1375,10 +1355,10 @@ def predict( # pylint: disable=unused-argument
.. note::
Using ``inplace_predict`` might be faster when some features are not needed. See
:py:meth:`xgboost.Booster.predict` for details on various parameters. When output
has more than 2 dimensions (shap value, leaf with strict_shape), input should be
``da.Array`` or ``DaskDMatrix``.
Using ``inplace_predict`` might be faster when some features are not needed.
See :py:meth:`xgboost.Booster.predict` for details on various parameters. When
output has more than 2 dimensions (shap value, leaf with strict_shape), input
should be ``da.Array`` or ``DaskDMatrix``.
.. versionadded:: 1.0.0
@ -1400,8 +1380,8 @@ def predict( # pylint: disable=unused-argument
Returns
-------
prediction: dask.array.Array/dask.dataframe.Series
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
array, when input data is ``dask.dataframe.DataFrame``, return value can be
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
an array, when input data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
shape.
@ -1415,7 +1395,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
client: "distributed.Client",
global_config: Dict[str, Any],
model: Union[Booster, Dict, "distributed.Future"],
data: _DaskCollection,
data: _DataT,
iteration_range: Tuple[int, int],
predict_type: str,
missing: float,
@ -1471,9 +1451,9 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
client: Optional["distributed.Client"],
model: Union[TrainReturnT, Booster, "distributed.Future"],
data: _DaskCollection,
data: _DataT,
iteration_range: Tuple[int, int] = (0, 0),
predict_type: str = "value",
missing: float = numpy.nan,
@ -1481,7 +1461,8 @@ def inplace_predict( # pylint: disable=unused-argument
base_margin: Optional[_DaskCollection] = None,
strict_shape: bool = False,
) -> Any:
"""Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
"""Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for
details.
.. versionadded:: 1.1.0
@ -1514,8 +1495,8 @@ def inplace_predict( # pylint: disable=unused-argument
Returns
-------
prediction :
When input data is ``dask.array.Array``, the return value is an array, when input
data is ``dask.dataframe.DataFrame``, return value can be
When input data is ``dask.array.Array``, the return value is an array, when
input data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
shape.
@ -1531,7 +1512,7 @@ def inplace_predict( # pylint: disable=unused-argument
async def _async_wrap_evaluation_matrices(
client: "distributed.Client", **kwargs: Any
client: Optional["distributed.Client"], **kwargs: Any
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
"""A switch function for async environment."""
@ -1561,7 +1542,7 @@ def _set_worker_client(
model.client = client
yield model
finally:
model.client = None
model.client = None # type:ignore
class DaskScikitLearnBase(XGBModel):
@ -1571,7 +1552,7 @@ class DaskScikitLearnBase(XGBModel):
async def _predict_async(
self,
data: _DaskCollection,
data: _DataT,
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
@ -1597,7 +1578,7 @@ class DaskScikitLearnBase(XGBModel):
data=data,
base_margin=base_margin,
missing=self.missing,
feature_types=self.feature_types
feature_types=self.feature_types,
)
predts = await predict(
self.client,
@ -1611,7 +1592,7 @@ class DaskScikitLearnBase(XGBModel):
def predict(
self,
X: _DaskCollection,
X: _DataT,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
@ -1632,12 +1613,15 @@ class DaskScikitLearnBase(XGBModel):
async def _apply_async(
self,
X: _DaskCollection,
X: _DataT,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(
self.client, data=X, missing=self.missing, feature_types=self.feature_types,
self.client,
data=X,
missing=self.missing,
feature_types=self.feature_types,
)
predts = await predict(
self.client,
@ -1650,7 +1634,7 @@ class DaskScikitLearnBase(XGBModel):
def apply(
self,
X: _DaskCollection,
X: _DataT,
ntree_limit: Optional[int] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
@ -1685,8 +1669,8 @@ class DaskScikitLearnBase(XGBModel):
@client.setter
def client(self, clt: "distributed.Client") -> None:
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
# we have to pass it ourselves.
# calling `worker_client' doesn't return the correct `asynchronous` attribute,
# so we have to pass it ourselves.
self._asynchronous = clt.asynchronous if clt is not None else False
self._client = clt
@ -1720,9 +1704,10 @@ class DaskScikitLearnBase(XGBModel):
)
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
"""dummy doc string to workaround pylint, replaced by the decorator."""
async def _fit_async(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
@ -1789,7 +1774,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
@_deprecate_positional_args
def fit(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
*,
sample_weight: Optional[_DaskCollection] = None,
@ -1817,7 +1802,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=missing-class-docstring
async def _fit_async(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection],
base_margin: Optional[_DaskCollection],
@ -1898,7 +1883,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
# pylint: disable=unused-argument
def fit(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
*,
sample_weight: Optional[_DaskCollection] = None,
@ -1919,7 +1904,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
async def _predict_proba_async(
self,
X: _DaskCollection,
X: _DataT,
validate_features: bool,
base_margin: Optional[_DaskCollection],
iteration_range: Optional[Tuple[int, int]],
@ -1965,7 +1950,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
async def _predict_async(
self,
data: _DaskCollection,
data: _DataT,
output_margin: bool,
validate_features: bool,
base_margin: Optional[_DaskCollection],
@ -2014,7 +1999,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
async def _fit_async(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
group: Optional[_DaskCollection],
qid: Optional[_DaskCollection],
@ -2090,7 +2075,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
@_deprecate_positional_args
def fit(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
*,
group: Optional[_DaskCollection] = None,
@ -2113,7 +2098,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
return self._client_sync(self._fit_async, **args)
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
# FIXME(trivialfis): arguments differ due to additional parameters like group and
# qid.
fit.__doc__ = XGBRanker.fit.__doc__
@ -2159,7 +2145,7 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
# pylint: disable=unused-argument
def fit(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
*,
sample_weight: Optional[_DaskCollection] = None,
@ -2223,7 +2209,7 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
# pylint: disable=unused-argument
def fit(
self,
X: _DaskCollection,
X: _DataT,
y: _DaskCollection,
*,
sample_weight: Optional[_DaskCollection] = None,