Fix mypy error with the latest dask. (#8052)
* Fix mypy error with latest dask.

Dask is adding type hints to its codebase and, as a result, checks in XGBoost can be performed more rigorously.

- Remove compatibility with old dask versions where `MultiLock` was missing.
- Restrict input of `X` to be non-series.
- Adopt the latest definition of `Delayed`.
- Avoid passing an optional `host_ip`.
- Avoid the deprecated `worker.nthreads`.
This commit is contained in:
parent
210eb471e9
commit
a5bc8e2c6a
@ -68,20 +68,20 @@ from .sklearn import xgboost_model_doc
|
|||||||
from .sklearn import _cls_predict_proba
|
from .sklearn import _cls_predict_proba
|
||||||
from .sklearn import XGBRanker
|
from .sklearn import XGBRanker
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from dask import dataframe as dd
|
from dask import dataframe as dd
|
||||||
from dask import array as da
|
from dask import array as da
|
||||||
from dask import delayed as ddelayed
|
|
||||||
import dask
|
import dask
|
||||||
import distributed
|
import distributed
|
||||||
else:
|
else:
|
||||||
dd = LazyLoader("dd", globals(), "dask.dataframe")
|
dd = LazyLoader("dd", globals(), "dask.dataframe")
|
||||||
da = LazyLoader("da", globals(), "dask.array")
|
da = LazyLoader("da", globals(), "dask.array")
|
||||||
ddelayed = LazyLoader("Delayed", globals(), "dask.delayed")
|
|
||||||
dask = LazyLoader("dask", globals(), "dask")
|
dask = LazyLoader("dask", globals(), "dask")
|
||||||
distributed = LazyLoader("distributed", globals(), "dask.distributed")
|
distributed = LazyLoader("distributed", globals(), "dask.distributed")
|
||||||
|
|
||||||
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
|
_DaskCollection = Union["da.Array", "dd.DataFrame", "dd.Series"]
|
||||||
|
_DataT = Union["da.Array", "dd.DataFrame"] # do not use series as predictor
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from mypy_extensions import TypedDict
|
from mypy_extensions import TypedDict
|
||||||
@ -123,8 +123,8 @@ __all__ = [
|
|||||||
# - Write everything with async, then use distributed Client sync function to do the
|
# - Write everything with async, then use distributed Client sync function to do the
|
||||||
# switch.
|
# switch.
|
||||||
# - Use Any for type hint when the return value can be union of Awaitable and plain
|
# - Use Any for type hint when the return value can be union of Awaitable and plain
|
||||||
# value. This is because Client.sync can return both types depending on context.
|
# value. This is because Client.sync can return both types depending on
|
||||||
# Right now there's no good way to silence:
|
# context. Right now there's no good way to silence:
|
||||||
#
|
#
|
||||||
# await train(...)
|
# await train(...)
|
||||||
#
|
#
|
||||||
@ -134,34 +134,6 @@ __all__ = [
|
|||||||
LOGGER = logging.getLogger("[xgboost.dask]")
|
LOGGER = logging.getLogger("[xgboost.dask]")
|
||||||
|
|
||||||
|
|
||||||
def _multi_lock() -> Any:
|
|
||||||
"""MultiLock is only available on latest distributed. See:
|
|
||||||
|
|
||||||
https://github.com/dask/distributed/pull/4503
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from distributed import MultiLock
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
class MultiLock: # type:ignore
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __enter__(self) -> "MultiLock":
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "MultiLock":
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
|
|
||||||
return
|
|
||||||
|
|
||||||
return MultiLock
|
|
||||||
|
|
||||||
|
|
||||||
def _try_start_tracker(
|
def _try_start_tracker(
|
||||||
n_workers: int,
|
n_workers: int,
|
||||||
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
|
||||||
@ -286,8 +258,8 @@ def _xgb_get_client(client: Optional["distributed.Client"]) -> "distributed.Clie
|
|||||||
class DaskDMatrix:
|
class DaskDMatrix:
|
||||||
# pylint: disable=missing-docstring, too-many-instance-attributes
|
# pylint: disable=missing-docstring, too-many-instance-attributes
|
||||||
"""DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
|
"""DMatrix holding on references to Dask DataFrame or Dask Array. Constructing a
|
||||||
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input data
|
`DaskDMatrix` forces all lazy computation to be carried out. Wait for the input
|
||||||
explicitly if you want to see actual computation of constructing `DaskDMatrix`.
|
data explicitly if you want to see actual computation of constructing `DaskDMatrix`.
|
||||||
|
|
||||||
See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
|
See doc for :py:obj:`xgboost.DMatrix` constructor for other parameters. DaskDMatrix
|
||||||
accepts only dask collection.
|
accepts only dask collection.
|
||||||
@ -302,8 +274,8 @@ class DaskDMatrix:
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
client :
|
client :
|
||||||
Specify the dask client used for training. Use default client returned from dask
|
Specify the dask client used for training. Use default client returned from
|
||||||
if it's set to None.
|
dask if it's set to None.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -311,7 +283,7 @@ class DaskDMatrix:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
label: Optional[_DaskCollection] = None,
|
label: Optional[_DaskCollection] = None,
|
||||||
*,
|
*,
|
||||||
weight: Optional[_DaskCollection] = None,
|
weight: Optional[_DaskCollection] = None,
|
||||||
@ -352,7 +324,7 @@ class DaskDMatrix:
|
|||||||
|
|
||||||
self._n_cols = data.shape[1]
|
self._n_cols = data.shape[1]
|
||||||
assert isinstance(self._n_cols, int)
|
assert isinstance(self._n_cols, int)
|
||||||
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
|
self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
|
||||||
self.is_quantile: bool = False
|
self.is_quantile: bool = False
|
||||||
|
|
||||||
self._init = client.sync(
|
self._init = client.sync(
|
||||||
@ -374,7 +346,7 @@ class DaskDMatrix:
|
|||||||
async def _map_local_data(
|
async def _map_local_data(
|
||||||
self,
|
self,
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
label: Optional[_DaskCollection] = None,
|
label: Optional[_DaskCollection] = None,
|
||||||
weights: Optional[_DaskCollection] = None,
|
weights: Optional[_DaskCollection] = None,
|
||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
@ -384,6 +356,7 @@ class DaskDMatrix:
|
|||||||
label_upper_bound: Optional[_DaskCollection] = None,
|
label_upper_bound: Optional[_DaskCollection] = None,
|
||||||
) -> "DaskDMatrix":
|
) -> "DaskDMatrix":
|
||||||
"""Obtain references to local data."""
|
"""Obtain references to local data."""
|
||||||
|
from dask.delayed import Delayed
|
||||||
|
|
||||||
def inconsistent(
|
def inconsistent(
|
||||||
left: List[Any], left_name: str, right: List[Any], right_name: str
|
left: List[Any], left_name: str, right: List[Any], right_name: str
|
||||||
@ -404,7 +377,7 @@ class DaskDMatrix:
|
|||||||
" chunks=(partition_size, X.shape[1])"
|
" chunks=(partition_size, X.shape[1])"
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_delayed(d: _DaskCollection) -> List[ddelayed.Delayed]:
|
def to_delayed(d: _DaskCollection) -> List[Delayed]:
|
||||||
"""Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
|
"""Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
|
||||||
downgrades high-level objects into numpy or pandas equivalents.
|
downgrades high-level objects into numpy or pandas equivalents.
|
||||||
|
|
||||||
@ -414,17 +387,15 @@ class DaskDMatrix:
|
|||||||
if isinstance(delayed_obj, numpy.ndarray):
|
if isinstance(delayed_obj, numpy.ndarray):
|
||||||
# da.Array returns an array of delayed objects
|
# da.Array returns an array of delayed objects
|
||||||
check_columns(delayed_obj)
|
check_columns(delayed_obj)
|
||||||
delayed_list: List[ddelayed.Delayed] = delayed_obj.flatten().tolist()
|
delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
|
||||||
else:
|
else:
|
||||||
# dd.DataFrame
|
# dd.DataFrame
|
||||||
delayed_list = delayed_obj
|
delayed_list = delayed_obj
|
||||||
return delayed_list
|
return delayed_list
|
||||||
|
|
||||||
OpDelayed = TypeVar("OpDelayed", _DaskCollection, None)
|
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
|
||||||
|
|
||||||
def flatten_meta(meta: OpDelayed) -> OpDelayed:
|
|
||||||
if meta is not None:
|
if meta is not None:
|
||||||
meta_parts: List[ddelayed.Delayed] = to_delayed(meta)
|
meta_parts: List[Delayed] = to_delayed(meta)
|
||||||
return meta_parts
|
return meta_parts
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -436,9 +407,9 @@ class DaskDMatrix:
|
|||||||
ll_parts = flatten_meta(label_lower_bound)
|
ll_parts = flatten_meta(label_lower_bound)
|
||||||
lu_parts = flatten_meta(label_upper_bound)
|
lu_parts = flatten_meta(label_upper_bound)
|
||||||
|
|
||||||
parts: Dict[str, List[ddelayed.Delayed]] = {"data": X_parts}
|
parts: Dict[str, List[Delayed]] = {"data": X_parts}
|
||||||
|
|
||||||
def append_meta(m_parts: Optional[List[ddelayed.Delayed]], name: str) -> None:
|
def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
|
||||||
if m_parts is not None:
|
if m_parts is not None:
|
||||||
assert len(X_parts) == len(m_parts), inconsistent(
|
assert len(X_parts) == len(m_parts), inconsistent(
|
||||||
X_parts, "X", m_parts, name
|
X_parts, "X", m_parts, name
|
||||||
@ -455,16 +426,16 @@ class DaskDMatrix:
|
|||||||
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
|
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
|
||||||
|
|
||||||
# turn into list of dictionaries.
|
# turn into list of dictionaries.
|
||||||
packed_parts: List[Dict[str, ddelayed.Delayed]] = []
|
packed_parts: List[Dict[str, Delayed]] = []
|
||||||
for i in range(len(X_parts)):
|
for i in range(len(X_parts)):
|
||||||
part_dict: Dict[str, ddelayed.Delayed] = {}
|
part_dict: Dict[str, Delayed] = {}
|
||||||
for key, value in parts.items():
|
for key, value in parts.items():
|
||||||
part_dict[key] = value[i]
|
part_dict[key] = value[i]
|
||||||
packed_parts.append(part_dict)
|
packed_parts.append(part_dict)
|
||||||
|
|
||||||
# delay the zipped result
|
# delay the zipped result
|
||||||
# pylint: disable=no-member
|
# pylint: disable=no-member
|
||||||
delayed_parts: List[ddelayed.Delayed] = list(map(dask.delayed, packed_parts))
|
delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
|
||||||
# At this point, the mental model should look like:
|
# At this point, the mental model should look like:
|
||||||
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
|
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
|
||||||
|
|
||||||
@ -662,12 +633,12 @@ class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
|||||||
|
|
||||||
|
|
||||||
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||||
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce the
|
"""Specialized data type for `gpu_hist` tree method. This class is used to reduce
|
||||||
memory usage by eliminating data copies. Internally all the partitions/chunks of data
|
the memory usage by eliminating data copies. Internally all the partitions/chunks
|
||||||
are merged by weighted GK sketching. So the number of partitions from dask may affect
|
of data are merged by weighted GK sketching. So the number of partitions from dask
|
||||||
training accuracy as GK generates bounded error for each merge. See doc string for
|
may affect training accuracy as GK generates bounded error for each merge. See doc
|
||||||
:py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for other
|
string for :py:obj:`xgboost.DeviceQuantileDMatrix` and :py:obj:`xgboost.DMatrix` for
|
||||||
parameters.
|
other parameters.
|
||||||
|
|
||||||
.. versionadded:: 1.2.0
|
.. versionadded:: 1.2.0
|
||||||
|
|
||||||
@ -681,7 +652,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
label: Optional[_DaskCollection] = None,
|
label: Optional[_DaskCollection] = None,
|
||||||
*,
|
*,
|
||||||
weight: Optional[_DaskCollection] = None,
|
weight: Optional[_DaskCollection] = None,
|
||||||
@ -845,10 +816,11 @@ async def _get_rabit_args(
|
|||||||
if k not in valid_config:
|
if k not in valid_config:
|
||||||
raise ValueError(f"Unknown configuration: {k}")
|
raise ValueError(f"Unknown configuration: {k}")
|
||||||
host_ip = dconfig.get("scheduler_address", None)
|
host_ip = dconfig.get("scheduler_address", None)
|
||||||
try:
|
if host_ip is not None:
|
||||||
host_ip, port = distributed.comm.get_address_host_port(host_ip)
|
try:
|
||||||
except ValueError:
|
host_ip, port = distributed.comm.get_address_host_port(host_ip)
|
||||||
pass
|
except ValueError:
|
||||||
|
pass
|
||||||
if host_ip is not None:
|
if host_ip is not None:
|
||||||
user_addr = (host_ip, port)
|
user_addr = (host_ip, port)
|
||||||
else:
|
else:
|
||||||
@ -932,16 +904,18 @@ async def _train_async(
|
|||||||
worker = distributed.get_worker()
|
worker = distributed.get_worker()
|
||||||
local_param = parameters.copy()
|
local_param = parameters.copy()
|
||||||
n_threads = 0
|
n_threads = 0
|
||||||
|
# dask worker nthreads, "state" is available in 2022.6.1
|
||||||
|
dwnt = worker.state.nthreads if hasattr(worker, "state") else worker.nthreads
|
||||||
for p in ["nthread", "n_jobs"]:
|
for p in ["nthread", "n_jobs"]:
|
||||||
if (
|
if (
|
||||||
local_param.get(p, None) is not None
|
local_param.get(p, None) is not None
|
||||||
and local_param.get(p, worker.nthreads) != worker.nthreads
|
and local_param.get(p, dwnt) != dwnt
|
||||||
):
|
):
|
||||||
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
LOGGER.info("Overriding `nthreads` defined in dask worker.")
|
||||||
n_threads = local_param[p]
|
n_threads = local_param[p]
|
||||||
break
|
break
|
||||||
if n_threads == 0 or n_threads is None:
|
if n_threads == 0 or n_threads is None:
|
||||||
n_threads = worker.nthreads
|
n_threads = dwnt
|
||||||
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
local_param.update({"nthread": n_threads, "n_jobs": n_threads})
|
||||||
local_history: TrainingCallback.EvalsLog = {}
|
local_history: TrainingCallback.EvalsLog = {}
|
||||||
with RabitContext(rabit_args), config.config_context(**global_config):
|
with RabitContext(rabit_args), config.config_context(**global_config):
|
||||||
@ -977,7 +951,7 @@ async def _train_async(
|
|||||||
ret = None
|
ret = None
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async with _multi_lock()(workers, client):
|
async with distributed.MultiLock(workers, client):
|
||||||
if evals is not None:
|
if evals is not None:
|
||||||
evals_data = [d for d, n in evals]
|
evals_data = [d for d, n in evals]
|
||||||
evals_name = [n for d, n in evals]
|
evals_name = [n for d, n in evals]
|
||||||
@ -1030,8 +1004,8 @@ def train( # pylint: disable=unused-argument
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
client :
|
client :
|
||||||
Specify the dask client used for training. Use default client returned from dask
|
Specify the dask client used for training. Use default client returned from
|
||||||
if it's set to None.
|
dask if it's set to None.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -1068,8 +1042,8 @@ def _maybe_dataframe(
|
|||||||
if _can_output_df(is_df, prediction.shape):
|
if _can_output_df(is_df, prediction.shape):
|
||||||
# Need to preserve the index for dataframe.
|
# Need to preserve the index for dataframe.
|
||||||
# See issue: https://github.com/dmlc/xgboost/issues/6939
|
# See issue: https://github.com/dmlc/xgboost/issues/6939
|
||||||
# In older versions of dask, the partition is actually a numpy array when input is
|
# In older versions of dask, the partition is actually a numpy array when input
|
||||||
# dataframe.
|
# is dataframe.
|
||||||
index = getattr(data, "index", None)
|
index = getattr(data, "index", None)
|
||||||
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
|
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
|
||||||
import cudf
|
import cudf
|
||||||
@ -1093,7 +1067,7 @@ def _maybe_dataframe(
|
|||||||
async def _direct_predict_impl( # pylint: disable=too-many-branches
|
async def _direct_predict_impl( # pylint: disable=too-many-branches
|
||||||
mapped_predict: Callable,
|
mapped_predict: Callable,
|
||||||
booster: "distributed.Future",
|
booster: "distributed.Future",
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
output_shape: Tuple[int, ...],
|
output_shape: Tuple[int, ...],
|
||||||
meta: Dict[int, str],
|
meta: Dict[int, str],
|
||||||
@ -1111,7 +1085,9 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
|
|||||||
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
|
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
|
||||||
if base_margin is not None and isinstance(base_margin, da.Array):
|
if base_margin is not None and isinstance(base_margin, da.Array):
|
||||||
# Easier for map_partitions
|
# Easier for map_partitions
|
||||||
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
|
base_margin_df: Optional[
|
||||||
|
Union[dd.DataFrame, dd.Series]
|
||||||
|
] = base_margin.to_dask_dataframe()
|
||||||
else:
|
else:
|
||||||
base_margin_df = base_margin
|
base_margin_df = base_margin
|
||||||
predictions = dd.map_partitions(
|
predictions = dd.map_partitions(
|
||||||
@ -1149,6 +1125,9 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
|
|||||||
# Somehow dask fails to infer output shape change for 2-dim prediction, and
|
# Somehow dask fails to infer output shape change for 2-dim prediction, and
|
||||||
# `chunks = (None, output_shape[1])` doesn't work because None is not
|
# `chunks = (None, output_shape[1])` doesn't work because None is not
|
||||||
# supported in map_blocks.
|
# supported in map_blocks.
|
||||||
|
|
||||||
|
# data must be an array here as dataframe + 2-dim output predict will return
|
||||||
|
# a dataframe instead.
|
||||||
chunks: Optional[List[Tuple]] = list(data.chunks)
|
chunks: Optional[List[Tuple]] = list(data.chunks)
|
||||||
assert isinstance(chunks, list)
|
assert isinstance(chunks, list)
|
||||||
chunks[1] = (output_shape[1],)
|
chunks[1] = (output_shape[1],)
|
||||||
@ -1200,9 +1179,10 @@ async def _get_model_future(
|
|||||||
booster = await client.scatter(model["booster"], broadcast=True)
|
booster = await client.scatter(model["booster"], broadcast=True)
|
||||||
elif isinstance(model, distributed.Future):
|
elif isinstance(model, distributed.Future):
|
||||||
booster = model
|
booster = model
|
||||||
if booster.type is not Booster:
|
t = booster.type
|
||||||
|
if t is not Booster:
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Underlying type of model future should be `Booster`, got {booster.type}"
|
f"Underlying type of model future should be `Booster`, got {t}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
|
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
|
||||||
@ -1214,7 +1194,7 @@ async def _predict_async(
|
|||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
model: Union[Booster, Dict, "distributed.Future"],
|
model: Union[Booster, Dict, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
output_margin: bool,
|
output_margin: bool,
|
||||||
missing: float,
|
missing: float,
|
||||||
pred_leaf: bool,
|
pred_leaf: bool,
|
||||||
@ -1236,7 +1216,7 @@ async def _predict_async(
|
|||||||
m = DMatrix(
|
m = DMatrix(
|
||||||
data=partition,
|
data=partition,
|
||||||
missing=missing,
|
missing=missing,
|
||||||
enable_categorical=_has_categorical(booster, partition)
|
enable_categorical=_has_categorical(booster, partition),
|
||||||
)
|
)
|
||||||
predt = booster.predict(
|
predt = booster.predict(
|
||||||
data=m,
|
data=m,
|
||||||
@ -1358,9 +1338,9 @@ async def _predict_async(
|
|||||||
|
|
||||||
|
|
||||||
def predict( # pylint: disable=unused-argument
|
def predict( # pylint: disable=unused-argument
|
||||||
client: "distributed.Client",
|
client: Optional["distributed.Client"],
|
||||||
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||||
data: Union[DaskDMatrix, _DaskCollection],
|
data: Union[DaskDMatrix, _DataT],
|
||||||
output_margin: bool = False,
|
output_margin: bool = False,
|
||||||
missing: float = numpy.nan,
|
missing: float = numpy.nan,
|
||||||
pred_leaf: bool = False,
|
pred_leaf: bool = False,
|
||||||
@ -1375,10 +1355,10 @@ def predict( # pylint: disable=unused-argument
|
|||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Using ``inplace_predict`` might be faster when some features are not needed. See
|
Using ``inplace_predict`` might be faster when some features are not needed.
|
||||||
:py:meth:`xgboost.Booster.predict` for details on various parameters. When output
|
See :py:meth:`xgboost.Booster.predict` for details on various parameters. When
|
||||||
has more than 2 dimensions (shap value, leaf with strict_shape), input should be
|
output has more than 2 dimensions (shap value, leaf with strict_shape), input
|
||||||
``da.Array`` or ``DaskDMatrix``.
|
should be ``da.Array`` or ``DaskDMatrix``.
|
||||||
|
|
||||||
.. versionadded:: 1.0.0
|
.. versionadded:: 1.0.0
|
||||||
|
|
||||||
@ -1400,8 +1380,8 @@ def predict( # pylint: disable=unused-argument
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
prediction: dask.array.Array/dask.dataframe.Series
|
prediction: dask.array.Array/dask.dataframe.Series
|
||||||
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
|
||||||
array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
an array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
||||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
|
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
|
||||||
shape.
|
shape.
|
||||||
|
|
||||||
@ -1415,7 +1395,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
|
|||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
model: Union[Booster, Dict, "distributed.Future"],
|
model: Union[Booster, Dict, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
iteration_range: Tuple[int, int],
|
iteration_range: Tuple[int, int],
|
||||||
predict_type: str,
|
predict_type: str,
|
||||||
missing: float,
|
missing: float,
|
||||||
@ -1471,9 +1451,9 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
|
|||||||
|
|
||||||
|
|
||||||
def inplace_predict( # pylint: disable=unused-argument
|
def inplace_predict( # pylint: disable=unused-argument
|
||||||
client: "distributed.Client",
|
client: Optional["distributed.Client"],
|
||||||
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
iteration_range: Tuple[int, int] = (0, 0),
|
iteration_range: Tuple[int, int] = (0, 0),
|
||||||
predict_type: str = "value",
|
predict_type: str = "value",
|
||||||
missing: float = numpy.nan,
|
missing: float = numpy.nan,
|
||||||
@ -1481,7 +1461,8 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
base_margin: Optional[_DaskCollection] = None,
|
base_margin: Optional[_DaskCollection] = None,
|
||||||
strict_shape: bool = False,
|
strict_shape: bool = False,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for details.
|
"""Inplace prediction. See doc in :py:meth:`xgboost.Booster.inplace_predict` for
|
||||||
|
details.
|
||||||
|
|
||||||
.. versionadded:: 1.1.0
|
.. versionadded:: 1.1.0
|
||||||
|
|
||||||
@ -1514,8 +1495,8 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
prediction :
|
prediction :
|
||||||
When input data is ``dask.array.Array``, the return value is an array, when input
|
When input data is ``dask.array.Array``, the return value is an array, when
|
||||||
data is ``dask.dataframe.DataFrame``, return value can be
|
input data is ``dask.dataframe.DataFrame``, return value can be
|
||||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
|
``dask.dataframe.Series``, ``dask.dataframe.DataFrame``, depending on the output
|
||||||
shape.
|
shape.
|
||||||
|
|
||||||
@ -1531,7 +1512,7 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
|
|
||||||
|
|
||||||
async def _async_wrap_evaluation_matrices(
|
async def _async_wrap_evaluation_matrices(
|
||||||
client: "distributed.Client", **kwargs: Any
|
client: Optional["distributed.Client"], **kwargs: Any
|
||||||
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
|
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
|
||||||
"""A switch function for async environment."""
|
"""A switch function for async environment."""
|
||||||
|
|
||||||
@ -1561,7 +1542,7 @@ def _set_worker_client(
|
|||||||
model.client = client
|
model.client = client
|
||||||
yield model
|
yield model
|
||||||
finally:
|
finally:
|
||||||
model.client = None
|
model.client = None # type:ignore
|
||||||
|
|
||||||
|
|
||||||
class DaskScikitLearnBase(XGBModel):
|
class DaskScikitLearnBase(XGBModel):
|
||||||
@ -1571,7 +1552,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
async def _predict_async(
|
async def _predict_async(
|
||||||
self,
|
self,
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
output_margin: bool,
|
output_margin: bool,
|
||||||
validate_features: bool,
|
validate_features: bool,
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
@ -1597,7 +1578,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
data=data,
|
data=data,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
missing=self.missing,
|
missing=self.missing,
|
||||||
feature_types=self.feature_types
|
feature_types=self.feature_types,
|
||||||
)
|
)
|
||||||
predts = await predict(
|
predts = await predict(
|
||||||
self.client,
|
self.client,
|
||||||
@ -1611,7 +1592,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
output_margin: bool = False,
|
output_margin: bool = False,
|
||||||
ntree_limit: Optional[int] = None,
|
ntree_limit: Optional[int] = None,
|
||||||
validate_features: bool = True,
|
validate_features: bool = True,
|
||||||
@ -1632,12 +1613,15 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
async def _apply_async(
|
async def _apply_async(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
iteration_range = self._get_iteration_range(iteration_range)
|
iteration_range = self._get_iteration_range(iteration_range)
|
||||||
test_dmatrix = await DaskDMatrix(
|
test_dmatrix = await DaskDMatrix(
|
||||||
self.client, data=X, missing=self.missing, feature_types=self.feature_types,
|
self.client,
|
||||||
|
data=X,
|
||||||
|
missing=self.missing,
|
||||||
|
feature_types=self.feature_types,
|
||||||
)
|
)
|
||||||
predts = await predict(
|
predts = await predict(
|
||||||
self.client,
|
self.client,
|
||||||
@ -1650,7 +1634,7 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
ntree_limit: Optional[int] = None,
|
ntree_limit: Optional[int] = None,
|
||||||
iteration_range: Optional[Tuple[int, int]] = None,
|
iteration_range: Optional[Tuple[int, int]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -1685,8 +1669,8 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
|
|
||||||
@client.setter
|
@client.setter
|
||||||
def client(self, clt: "distributed.Client") -> None:
|
def client(self, clt: "distributed.Client") -> None:
|
||||||
# calling `worker_client' doesn't return the correct `asynchronous` attribute, so
|
# calling `worker_client' doesn't return the correct `asynchronous` attribute,
|
||||||
# we have to pass it ourselves.
|
# so we have to pass it ourselves.
|
||||||
self._asynchronous = clt.asynchronous if clt is not None else False
|
self._asynchronous = clt.asynchronous if clt is not None else False
|
||||||
self._client = clt
|
self._client = clt
|
||||||
|
|
||||||
@ -1720,9 +1704,10 @@ class DaskScikitLearnBase(XGBModel):
|
|||||||
)
|
)
|
||||||
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||||
"""dummy doc string to workaround pylint, replaced by the decorator."""
|
"""dummy doc string to workaround pylint, replaced by the decorator."""
|
||||||
|
|
||||||
async def _fit_async(
|
async def _fit_async(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
sample_weight: Optional[_DaskCollection],
|
sample_weight: Optional[_DaskCollection],
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
@ -1789,7 +1774,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
|||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
*,
|
*,
|
||||||
sample_weight: Optional[_DaskCollection] = None,
|
sample_weight: Optional[_DaskCollection] = None,
|
||||||
@ -1817,7 +1802,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
# pylint: disable=missing-class-docstring
|
# pylint: disable=missing-class-docstring
|
||||||
async def _fit_async(
|
async def _fit_async(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
sample_weight: Optional[_DaskCollection],
|
sample_weight: Optional[_DaskCollection],
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
@ -1898,7 +1883,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
*,
|
*,
|
||||||
sample_weight: Optional[_DaskCollection] = None,
|
sample_weight: Optional[_DaskCollection] = None,
|
||||||
@ -1919,7 +1904,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
|
|
||||||
async def _predict_proba_async(
|
async def _predict_proba_async(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
validate_features: bool,
|
validate_features: bool,
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
iteration_range: Optional[Tuple[int, int]],
|
iteration_range: Optional[Tuple[int, int]],
|
||||||
@ -1965,7 +1950,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
|||||||
|
|
||||||
async def _predict_async(
|
async def _predict_async(
|
||||||
self,
|
self,
|
||||||
data: _DaskCollection,
|
data: _DataT,
|
||||||
output_margin: bool,
|
output_margin: bool,
|
||||||
validate_features: bool,
|
validate_features: bool,
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
@ -2014,7 +1999,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
|
|
||||||
async def _fit_async(
|
async def _fit_async(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
group: Optional[_DaskCollection],
|
group: Optional[_DaskCollection],
|
||||||
qid: Optional[_DaskCollection],
|
qid: Optional[_DaskCollection],
|
||||||
@ -2090,7 +2075,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
*,
|
*,
|
||||||
group: Optional[_DaskCollection] = None,
|
group: Optional[_DaskCollection] = None,
|
||||||
@ -2113,7 +2098,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
|||||||
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
args = {k: v for k, v in locals().items() if k not in ("self", "__class__")}
|
||||||
return self._client_sync(self._fit_async, **args)
|
return self._client_sync(self._fit_async, **args)
|
||||||
|
|
||||||
# FIXME(trivialfis): arguments differ due to additional parameters like group and qid.
|
# FIXME(trivialfis): arguments differ due to additional parameters like group and
|
||||||
|
# qid.
|
||||||
fit.__doc__ = XGBRanker.fit.__doc__
|
fit.__doc__ = XGBRanker.fit.__doc__
|
||||||
|
|
||||||
|
|
||||||
@ -2159,7 +2145,7 @@ class DaskXGBRFRegressor(DaskXGBRegressor):
|
|||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
*,
|
*,
|
||||||
sample_weight: Optional[_DaskCollection] = None,
|
sample_weight: Optional[_DaskCollection] = None,
|
||||||
@ -2223,7 +2209,7 @@ class DaskXGBRFClassifier(DaskXGBClassifier):
|
|||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def fit(
|
def fit(
|
||||||
self,
|
self,
|
||||||
X: _DaskCollection,
|
X: _DataT,
|
||||||
y: _DaskCollection,
|
y: _DaskCollection,
|
||||||
*,
|
*,
|
||||||
sample_weight: Optional[_DaskCollection] = None,
|
sample_weight: Optional[_DaskCollection] = None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user