[dask] Add a 1 line sample to infer output shape. (#6645)

* [dask] Use a 1 line sample to infer output shape. This is for inferring the shape with direct prediction (without DaskDMatrix). A few things require a known output shape before the actual prediction is carried out, including dask meta data and output dataframe columns.
* Infer output shape based on local prediction.
* Remove set_param in the predict function: it's not thread safe, and it's no longer necessary now that we let dask decide the parallelism.
* Simplify prediction on `DaskDMatrix`.
parent c3c8e66fc9
commit d8ec7aad5a
doc/tutorials/dask.rst

@@ -108,8 +108,9 @@ computation a bit faster when meta information like ``base_margin`` is not needed
   prediction = xgb.dask.inplace_predict(client, output, X)
 
 Here ``prediction`` is a dask ``Array`` object containing predictions from model if input
-is a ``DaskDMatrix`` or ``da.Array``. For ``dd.DataFrame``, the return value is a
-``dd.Series``.
+is a ``DaskDMatrix`` or ``da.Array``. When putting a dask collection directly into the
+``predict`` function or using ``inplace_predict``, the output type depends on input data.
+See the next section for details.
 
 Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
 and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
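As a minimal sketch of the behavior described above (assuming an existing ``client``, a trained ``output`` returned by ``xgb.dask.train``, and dask collections ``X_array``/``X_frame``; all names here are illustrative, not part of this commit):

.. code-block:: python

    import dask.array as da
    import dask.dataframe as dd
    import xgboost as xgb
    import xgboost.dask  # makes xgb.dask available

    booster = output["booster"]

    # Array input -> da.Array output.
    pred_arr = xgb.dask.predict(client, booster, X_array)
    assert isinstance(pred_arr, da.Array)

    # DataFrame input -> dd.Series (1-dim output) or dd.DataFrame (2-dim,
    # e.g. multi-class); higher-dim outputs such as SHAP fall back to da.Array.
    pred_df = xgb.dask.predict(client, booster, X_frame)
    assert isinstance(pred_df, (dd.Series, dd.DataFrame))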
@@ -143,9 +144,23 @@ Also for inplace prediction:
 .. code-block:: python
 
   booster.set_param({'predictor': 'gpu_predictor'})
-  # where X is a dask DataFrame or dask Array.
+  # where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
   prediction = xgb.dask.inplace_predict(client, booster, X)
 
+When input is a ``da.Array`` object, output is always ``da.Array``. However, if the input
+type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``,
+depending on output shape. For example, when SHAP-based prediction is used, the return
+value can have 3 or 4 dimensions; in such cases an ``Array`` is always returned.
+
+The performance of running prediction, either using ``predict`` or ``inplace_predict``, is
+sensitive to the number of blocks. Internally, it's implemented using ``da.map_blocks`` or
+``dd.map_partitions``. When the number of partitions is large and each of them has only a
+small amount of data, the overhead of calling predict becomes visible. On the other hand,
+if not using GPU, the number of threads used for prediction on each block matters. Right
+now, xgboost uses a single thread for each partition. If the number of blocks on each
+worker is smaller than the number of cores, then the CPU workers might not be fully
+utilized.
+
 
 ***************************
 Working with other clusters
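Since the overhead scales with the number of blocks, one practical knob (a hedged suggestion, not part of this commit) is to re-chunk the input before predicting; the sizes below are illustrative:

.. code-block:: python

    # Fewer, larger partitions amortize the per-block call overhead; keep at
    # least one block per core so CPU workers stay busy.
    X_arr = X_arr.rechunk({0: 10_000})              # dask array
    X_df = X_df.repartition(npartitions=n_blocks)   # dask dataframe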
python-package/xgboost/dask.py

@@ -112,14 +112,15 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
 
 def _assert_dask_support() -> None:
     try:
         import dask  # pylint: disable=W0621,W0611
     except ImportError as e:
         raise ImportError(
-            'Dask needs to be installed in order to use this module') from e
+            "Dask needs to be installed in order to use this module"
+        ) from e
 
-    if platform.system() == 'Windows':
-        msg = 'Windows is not officially supported for dask/xgboost,'
-        msg += ' contribution are welcomed.'
+    if platform.system() == "Windows":
+        msg = "Windows is not officially supported for dask/xgboost,"
+        msg += " contributions are welcomed."
         LOGGER.warning(msg)
 
 
@@ -252,6 +253,7 @@ class DaskDMatrix:
         if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
             raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
 
+        self._n_cols = data.shape[1]
         self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
         self.is_quantile: bool = False
 
@@ -403,6 +405,9 @@ class DaskDMatrix:
                 'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}
 
+    def num_col(self) -> int:
+        return self._n_cols
+
 
 _DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
                         Optional[Any], Optional[Any]]]
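Caching the column count at construction time means shape inference never has to touch the distributed partitions. Roughly how the inferring code consumes it (a sketch, where ``data`` is either a ``DaskDMatrix`` or a dask collection):

.. code-block:: python

    def n_features(data) -> int:
        # DaskDMatrix caches its width when built; dask collections already
        # know their second dimension without computing anything.
        if isinstance(data, xgb.dask.DaskDMatrix):
            return data.num_col()
        return data.shape[1]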
@@ -930,27 +935,90 @@ def train(
         callbacks=callbacks)
 
 
+def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
+    return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
+
+
 async def _direct_predict_impl(
     client: "distributed.Client",
+    mapped_predict: Callable,
+    booster: Booster,
     data: _DaskCollection,
-    predict_fn: Callable
+    base_margin: Optional[_DaskCollection],
+    output_shape: Tuple[int, ...],
+    meta: Dict[int, str],
 ) -> _DaskCollection:
-    if isinstance(data, da.Array):
-        predictions = await client.submit(
-            da.map_blocks,
-            predict_fn, data, False, drop_axis=1,
-            dtype=numpy.float32
-        ).result()
-        return predictions
-    if isinstance(data, dd.DataFrame):
-        predictions = await client.submit(
-            dd.map_partitions,
-            predict_fn, data, True,
-            meta=dd.utils.make_meta({'prediction': 'f4'})
-        ).result()
-        return predictions.iloc[:, 0]
-    raise TypeError('data of type: ' + str(type(data)) +
-                    ' is not supported by direct prediction')
+    columns = list(meta.keys())
+    booster_f = await client.scatter(data=booster, broadcast=True)
+    if _can_output_df(data, output_shape):
+        if base_margin is not None and isinstance(base_margin, da.Array):
+            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
+        else:
+            base_margin_df = base_margin
+        predictions = dd.map_partitions(
+            mapped_predict,
+            booster_f,
+            data,
+            True,
+            columns,
+            base_margin_df,
+            meta=dd.utils.make_meta(meta),
+        )
+        # classification can return a dataframe, drop 1 dim when it's reg/binary
+        if len(output_shape) == 1:
+            predictions = predictions.iloc[:, 0]
+    else:
+        if base_margin is not None and isinstance(
+            base_margin, (dd.Series, dd.DataFrame)
+        ):
+            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
+        else:
+            base_margin_array = base_margin
+        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
+        # contrib)/3(contrib)/4(interaction) dims.
+        if len(output_shape) == 1:
+            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
+            new_axis: Union[int, List[int]] = []
+        else:
+            drop_axis = []
+            new_axis = [i + 2 for i in range(len(output_shape) - 2)]
+        predictions = da.map_blocks(
+            mapped_predict,
+            booster_f,
+            data,
+            False,
+            columns,
+            base_margin_array,
+            drop_axis=drop_axis,
+            new_axis=new_axis,
+            dtype=numpy.float32,
+        )
+    return predictions
+
+
+def _infer_predict_output(
+    booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
+) -> Tuple[Tuple[int, ...], Dict[int, str]]:
+    """Create a dummy test sample to infer output shape for prediction."""
+    if isinstance(data, DaskDMatrix):
+        features = data.num_col()
+    else:
+        features = data.shape[1]
+    rng = numpy.random.RandomState(1994)
+    test_sample = rng.randn(1, features)
+    if inplace:
+        # clear the state to avoid gpu_id, gpu_predictor
+        booster = Booster(model_file=booster.save_raw())
+        test_predt = booster.inplace_predict(test_sample, **kwargs)
+    else:
+        m = DMatrix(test_sample)
+        test_predt = booster.predict(m, **kwargs)
+    n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
+    meta: Dict[int, str] = {}
+    if _can_output_df(data, test_predt.shape):
+        for i in range(n_columns):
+            meta[i] = "f4"
+    return test_predt.shape, meta
 
 
 # pylint: disable=too-many-statements
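The core idea of the commit is here: run the booster over a single random row locally to learn the output shape before scheduling any distributed work. A stand-alone sketch of the same probe (``probe_output_shape`` is an illustrative name, not part of the library):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    def probe_output_shape(booster: xgb.Booster, n_features: int) -> tuple:
        # One cheap local prediction reveals the output dimensionality.
        rng = np.random.RandomState(1994)
        sample = rng.randn(1, n_features)
        predt = booster.predict(xgb.DMatrix(sample))
        # (1,) for regression/binary, (1, n_classes) for multi:softprob, etc.
        return predt.shape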
@@ -968,19 +1036,19 @@ async def _predict_async(
     validate_features: bool,
 ) -> _DaskCollection:
     if isinstance(model, Booster):
-        booster = model
+        _booster = model
     elif isinstance(model, dict):
-        booster = model["booster"]
+        _booster = model["booster"]
     else:
         raise TypeError(_expect([Booster, dict], type(model)))
     if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
         raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
 
-    def mapped_predict(partition: Any, is_df: bool) -> Any:
-        worker = distributed.get_worker()
+    def mapped_predict(
+        booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
+    ) -> Any:
         with config.config_context(**global_config):
-            booster.set_param({"nthread": worker.nthreads})
-            m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
+            m = DMatrix(data=partition, missing=missing)
             predt = booster.predict(
                 data=m,
                 output_margin=output_margin,
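``mapped_predict`` now takes the booster as an explicit argument: the caller scatters the model once and passes the resulting future into ``map_partitions``/``map_blocks``, and dask resolves it to the worker-local copy inside each task, so the model ships per worker rather than per partition. A self-contained sketch of the pattern (toy ``model`` dict, illustrative sizes):

.. code-block:: python

    import dask.array as da
    from distributed import Client

    def scale(model, block):
        # `model` arrives as a plain object: dask resolved the future.
        return block * model["factor"]

    client = Client()
    model_f = client.scatter({"factor": 2.0}, broadcast=True)  # one copy per worker
    x = da.ones((8, 2), chunks=(4, 2))
    y = da.map_blocks(scale, model_f, x, dtype=x.dtype)
    print(y.compute())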
@@ -990,167 +1058,115 @@ async def _predict_async(
                 pred_interactions=pred_interactions,
                 validate_features=validate_features,
             )
-            if is_df:
+            if is_df and len(predt.shape) <= 2:
                 if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
                     import cudf
-                    predt = cudf.DataFrame(predt, columns=["prediction"])
+                    predt = cudf.DataFrame(predt, columns=columns)
                 else:
-                    predt = DataFrame(predt, columns=["prediction"])
+                    predt = DataFrame(predt, columns=columns)
             return predt
 
     # Predict on dask collection directly.
     if isinstance(data, (da.Array, dd.DataFrame)):
-        return await _direct_predict_impl(client, data, mapped_predict)
+        _output_shape, meta = _infer_predict_output(
+            _booster,
+            data,
+            inplace=False,
+            output_margin=output_margin,
+            pred_leaf=pred_leaf,
+            pred_contribs=pred_contribs,
+            approx_contribs=approx_contribs,
+            pred_interactions=pred_interactions,
+            validate_features=False,
+        )
+        return await _direct_predict_impl(
+            client, mapped_predict, _booster, data, None, _output_shape, meta
+        )
+
+    output_shape, _ = _infer_predict_output(
+        booster=_booster,
+        data=data,
+        inplace=False,
+        output_margin=output_margin,
+        pred_leaf=pred_leaf,
+        pred_contribs=pred_contribs,
+        approx_contribs=approx_contribs,
+        pred_interactions=pred_interactions,
+        validate_features=False,
+    )
     # Prediction on dask DMatrix.
-    worker_map = data.worker_map
     partition_order = data.partition_order
     feature_names = data.feature_names
     feature_types = data.feature_types
     missing = data.missing
     meta_names = data.meta_names
 
-    def dispatched_predict(
-        worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
-    ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
-        """Perform prediction on each worker."""
-        LOGGER.debug("Predicting on %d", worker_id)
+    def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
+        data = part[0]
+        assert isinstance(part, tuple), type(part)
+        base_margin = None
+        for i, blob in enumerate(part[1:]):
+            if meta_names[i] == "base_margin":
+                base_margin = blob
+        worker = distributed.get_worker()
         with config.config_context(**global_config):
-            worker = distributed.get_worker()
-            list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
-            predictions = []
-
-            booster.set_param({"nthread": worker.nthreads})
-            for i, parts in enumerate(list_of_parts):
-                (data, _, _, base_margin, _, _, _) = parts
-                order = list_of_orders[i]
-                local_part = DMatrix(
-                    data,
-                    base_margin=base_margin,
-                    feature_names=feature_names,
-                    feature_types=feature_types,
-                    missing=missing,
-                    nthread=worker.nthreads,
-                )
-                predt = booster.predict(
-                    data=local_part,
-                    output_margin=output_margin,
-                    pred_leaf=pred_leaf,
-                    pred_contribs=pred_contribs,
-                    approx_contribs=approx_contribs,
-                    pred_interactions=pred_interactions,
-                    validate_features=validate_features,
-                )
-                if pred_contribs and predt.size != local_part.num_row():
-                    assert len(predt.shape) in (2, 3)
-                    if len(predt.shape) == 2:
-                        groups = 1
-                        columns = predt.shape[1]
-                    else:
-                        groups = predt.shape[1]
-                        columns = predt.shape[2]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), groups, columns],
-                        order,
-                    )
-                elif pred_interactions and predt.size != local_part.num_row():
-                    assert len(predt.shape) in (3, 4)
-                    if len(predt.shape) == 3:
-                        groups = 1
-                        columns = predt.shape[1]
-                    else:
-                        groups = predt.shape[1]
-                        columns = predt.shape[2]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), groups, columns],
-                        order,
-                    )
-                else:
-                    assert len(predt.shape) == 1 or len(predt.shape) == 2
-                    columns = 1 if len(predt.shape) == 1 else predt.shape[1]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), columns],
-                        order,
-                    )
-                predictions.append(ret)
-
-            return predictions
-
-    def dispatched_get_shape(
-        worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
-    ) -> List[Tuple[int, int]]:
-        """Get shape of data in each worker."""
-        LOGGER.debug("Get shape on %d", worker_id)
-        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
-        shapes = []
-        for i, parts in enumerate(list_of_parts):
-            (data, _, _, _, _, _, _) = parts
-            shapes.append((data.shape, list_of_orders[i]))
-        return shapes
-
-    async def map_function(
-        func: Callable[[int, List[int], _DataParts], Any]
-    ) -> List[Any]:
-        """Run function for each part of the data."""
-        futures = []
-        workers_address = list(worker_map.keys())
-        for wid, worker_addr in enumerate(workers_address):
-            worker_addr = workers_address[wid]
-            list_of_parts = worker_map[worker_addr]
-            list_of_orders = [partition_order[part.key] for part in list_of_parts]
-
-            f = client.submit(
-                func,
-                worker_id=wid,
-                list_of_orders=list_of_orders,
-                list_of_parts=list_of_parts,
-                pure=True,
-                workers=[worker_addr],
+            m = DMatrix(
+                data,
+                nthread=worker.nthreads,
+                missing=missing,
+                base_margin=base_margin,
+                feature_names=feature_names,
+                feature_types=feature_types,
             )
-            assert isinstance(f, distributed.client.Future)
-            futures.append(f)
-        # Get delayed objects
-        results = await client.gather(futures)
-        # flatten into 1 dim list
-        results = [t for list_per_worker in results for t in list_per_worker]
-        # sort by order, l[0] is the delayed object, l[1] is its order
-        results = sorted(results, key=lambda l: l[1])
-        results = [predt for predt, order in results]  # remove order
-        return results
+            predt = booster.predict(
+                m,
+                output_margin=output_margin,
+                pred_leaf=pred_leaf,
+                pred_contribs=pred_contribs,
+                approx_contribs=approx_contribs,
+                pred_interactions=pred_interactions,
+                validate_features=validate_features,
+            )
+            return predt
 
-    results = await map_function(dispatched_predict)
-    shapes = await map_function(dispatched_get_shape)
+    all_parts = []
+    all_orders = []
+    all_shapes = []
+    workers_address = list(data.worker_map.keys())
+    for worker_addr in workers_address:
+        list_of_parts = data.worker_map[worker_addr]
+        all_parts.extend(list_of_parts)
+        all_orders.extend([partition_order[part.key] for part in list_of_parts])
+    for part in all_parts:
+        s = client.submit(lambda part: part[0].shape[0], part)
+        all_shapes.append(s)
+    all_shapes = await client.gather(all_shapes)
+
+    parts_with_order = list(zip(all_parts, all_shapes, all_orders))
+    parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
+    all_parts = [part for part, shape, order in parts_with_order]
+    all_shapes = [shape for part, shape, order in parts_with_order]
+
+    futures = []
+    booster_f = await client.scatter(data=_booster, broadcast=True)
+    for part in all_parts:
+        f = client.submit(dispatched_predict, booster_f, part)
+        futures.append(f)
 
     # Constructing a dask array from list of numpy arrays
     # See https://docs.dask.org/en/latest/array-creation.html
     arrays = []
-    for i, shape in enumerate(shapes):
-        if pred_contribs:
-            out_shape = (
-                (shape[0], results[i][2])
-                if results[i][1] == 1
-                else (shape[0], results[i][1], results[i][2])
-            )
-        elif pred_interactions:
-            out_shape = (
-                (shape[0], results[i][2], results[i][2])
-                if results[i][1] == 1
-                else (shape[0], results[i][1], results[i][2])
-            )
-        else:
-            out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
+    for i, rows in enumerate(all_shapes):
         arrays.append(
-            da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
+            da.from_delayed(
+                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
+            )
         )
 
     predictions = await da.concatenate(arrays, axis=0)
     return predictions
 
 
-def predict(
+def predict(  # pylint: disable=unused-argument
     client: "distributed.Client",
     model: Union[TrainReturnT, Booster],
     data: Union[DaskDMatrix, _DaskCollection],
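With the output shape known up front, the per-partition futures can be stitched into one lazy array; only the row count varies per block. A self-contained sketch of the assembly (``fake_predict`` stands in for the real per-part prediction):

.. code-block:: python

    import numpy as np
    import dask.array as da
    from distributed import Client

    client = Client()

    def fake_predict(rows: int) -> np.ndarray:
        # Stand-in for booster.predict on one partition.
        return np.zeros((rows, 3), dtype=np.float32)

    row_counts = [4, 2, 5]
    futures = [client.submit(fake_predict, r) for r in row_counts]
    arrays = [
        da.from_delayed(f, shape=(r, 3), dtype=np.float32)
        for f, r in zip(futures, row_counts)
    ]
    predictions = da.concatenate(arrays, axis=0)  # lazy, shape (11, 3)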
@@ -1190,22 +1206,15 @@ def predict(
     -------
     prediction: dask.array.Array/dask.dataframe.Series
         When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
-        an array; when input data is ``dask.dataframe.DataFrame``, the return value is
-        ``dask.dataframe.Series``
+        an array; when input data is ``dask.dataframe.DataFrame``, the return value can
+        be ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or
+        ``dask.array.Array``, depending on the output shape.
 
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    global_config = config.get_config()
     return client.sync(
-        _predict_async, client, global_config, model, data,
-        output_margin=output_margin,
-        missing=missing,
-        pred_leaf=pred_leaf,
-        pred_contribs=pred_contribs,
-        approx_contribs=approx_contribs,
-        pred_interactions=pred_interactions,
-        validate_features=validate_features
+        _predict_async, global_config=config.get_config(), **locals()
     )
 
 
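Forwarding ``**locals()`` keeps the wrapper's signature and the async implementation in lockstep; it's also why the seemingly unused parameters need a pylint suppression. A minimal illustration of the pattern (all names here are illustrative):

.. code-block:: python

    def _run_async(*, client, model, data, missing):
        return client, model, data, missing

    def run(client, model, data, missing=None):  # pylint: disable=unused-argument
        # Every parameter is captured by name; a new keyword added to `run`
        # reaches `_run_async` with no extra plumbing.
        return _run_async(**locals())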
@@ -1228,30 +1237,38 @@ async def _inplace_predict_async(
     if not isinstance(data, (da.Array, dd.DataFrame)):
         raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
 
-    def mapped_predict(data: Any, is_df: bool) -> Any:
-        worker = distributed.get_worker()
-        config.set_config(**global_config)
-        booster.set_param({'nthread': worker.nthreads})
-        prediction = booster.inplace_predict(
-            data,
-            iteration_range=iteration_range,
-            predict_type=predict_type,
-            missing=missing)
-        if is_df:
-            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
-                import cudf
-                prediction = cudf.DataFrame({'prediction': prediction},
-                                            dtype=numpy.float32)
-            else:
-                # If it's from pandas, the partition is a numpy array
-                prediction = DataFrame(prediction, columns=['prediction'],
-                                       dtype=numpy.float32)
-        return prediction
+    def mapped_predict(
+        booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
+    ) -> Any:
+        with config.config_context(**global_config):
+            prediction = booster.inplace_predict(
+                data,
+                iteration_range=iteration_range,
+                predict_type=predict_type,
+                missing=missing
+            )
+            if is_df and len(prediction.shape) <= 2:
+                if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
+                    import cudf
+                    prediction = cudf.DataFrame(
+                        prediction, columns=columns, dtype=numpy.float32
+                    )
+                else:
+                    # If it's from pandas, the partition is a numpy array
+                    prediction = DataFrame(
+                        prediction, columns=columns, dtype=numpy.float32
+                    )
+            return prediction
 
-    return await _direct_predict_impl(client, data, mapped_predict)
+    shape, meta = _infer_predict_output(
+        booster, data, True, predict_type=predict_type, iteration_range=iteration_range
+    )
+    return await _direct_predict_impl(
+        client, mapped_predict, booster, data, None, shape, meta
+    )
 
 
-def inplace_predict(
+def inplace_predict(  # pylint: disable=unused-argument
     client: "distributed.Client",
     model: Union[TrainReturnT, Booster],
     data: _DaskCollection,
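The move from ``config.set_config`` to ``config.config_context`` matters because tasks share a worker process: a scoped context restores the previous global state on exit instead of leaking it into later tasks. A small demonstration:

.. code-block:: python

    import xgboost as xgb

    before = xgb.get_config()["verbosity"]
    with xgb.config_context(verbosity=0):
        # Applies only inside this block, even when tasks share the process.
        assert xgb.get_config()["verbosity"] == 0
    assert xgb.get_config()["verbosity"] == before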
@@ -1281,16 +1298,17 @@ def inplace_predict(
 
     Returns
     -------
-    prediction
+    prediction :
+        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
+        an array; when input data is ``dask.dataframe.DataFrame``, the return value can
+        be ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or
+        ``dask.array.Array``, depending on the output shape.
     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    global_config = config.get_config()
-    return client.sync(_inplace_predict_async, client, global_config, model=model,
-                       data=data,
-                       iteration_range=iteration_range,
-                       predict_type=predict_type,
-                       missing=missing)
+    return client.sync(
+        _inplace_predict_async, global_config=config.get_config(), **locals()
+    )
 
 
 async def _async_wrap_evaluation_matrices(
tests/python/test_with_dask.py

@@ -24,7 +24,6 @@ if sys.platform.startswith("win"):
 if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
 
-import distributed
 from distributed import LocalCluster, Client
 from distributed.utils_test import client, loop, cluster_fixture
 import dask.dataframe as dd
@@ -130,24 +129,34 @@ def test_from_dask_array() -> None:
     assert np.all(single_node_predt == from_arr.compute())
 
 
-def test_dask_predict_shape_infer() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y = make_classification(n_samples=1000, n_informative=5,
-                                       n_classes=3)
-            X_ = dd.from_array(X, chunksize=100)
-            y_ = dd.from_array(y, chunksize=100)
-            dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
-
-            model = xgb.dask.train(
-                client,
-                {"objective": "multi:softprob", "num_class": 3},
-                dtrain=dtrain
-            )
-
-            preds = xgb.dask.predict(client, model, dtrain)
-            assert preds.shape[0] == preds.compute().shape[0]
-            assert preds.shape[1] == preds.compute().shape[1]
+def test_dask_predict_shape_infer(client: "Client") -> None:
+    X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
+    X_ = dd.from_array(X, chunksize=100)
+    y_ = dd.from_array(y, chunksize=100)
+    dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
+
+    model = xgb.dask.train(
+        client, {"objective": "multi:softprob", "num_class": 3}, dtrain=dtrain
+    )
+
+    preds = xgb.dask.predict(client, model, dtrain)
+    assert preds.shape[0] == preds.compute().shape[0]
+    assert preds.shape[1] == preds.compute().shape[1]
+
+    prediction = xgb.dask.predict(client, model, X_, output_margin=True)
+    assert isinstance(prediction, dd.DataFrame)
+
+    prediction = prediction.compute()
+    assert prediction.ndim == 2
+    assert prediction.shape[0] == kRows
+    assert prediction.shape[1] == 3
+
+    prediction = xgb.dask.inplace_predict(client, model, X_, predict_type="margin")
+    assert isinstance(prediction, dd.DataFrame)
+    prediction = prediction.compute()
+    assert prediction.ndim == 2
+    assert prediction.shape[0] == kRows
+    assert prediction.shape[1] == 3
 
 
 @pytest.mark.parametrize("tree_method", ["hist", "approx"])
@@ -340,7 +349,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
     classifier.fit(X_d, y_d)
 
     assert classifier.n_classes_ == 10
-    prediction = classifier.predict(X_d)
+    prediction = classifier.predict(X_d).compute()
 
     assert prediction.ndim == 1
     assert prediction.shape[0] == kRows
@@ -541,6 +550,9 @@ async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
         assert list(history['validation_0'].keys())[0] == 'rmse'
         assert len(history['validation_0']['rmse']) == 2
 
+        awaited = await client.compute(prediction)
+        assert awaited.shape[0] == kRows
+
 
 async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
     async with Client(scheduler_address, asynchronous=True) as client:
@@ -578,7 +590,7 @@ async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
         await classifier.fit(X_d, y_d)
 
         assert classifier.n_classes_ == 10
-        prediction = await classifier.predict(X_d)
+        prediction = await client.compute(await classifier.predict(X_d))
 
         assert prediction.ndim == 1
         assert prediction.shape[0] == kRows
@@ -1019,6 +1031,17 @@ class TestWithDask:
         run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y)
 
     def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
+        rows = X.shape[0]
+        cols = X.shape[1]
+
+        def assert_shape(shape):
+            assert shape[0] == rows
+            if "num_class" in params.keys():
+                assert shape[1] == params["num_class"]
+                assert shape[2] == cols + 1
+            else:
+                assert shape[1] == cols + 1
+
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
         Xy = xgb.dask.DaskDMatrix(client, X, y)
         booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
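For reference, the contribution shapes this helper checks can be reproduced on a single node (a small, self-contained sketch; sizes are arbitrary):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 4)
    y = np.random.randint(0, 3, size=64)  # 3 classes
    booster = xgb.train(
        {"objective": "multi:softprob", "num_class": 3},
        xgb.DMatrix(X, label=y),
        num_boost_round=2,
    )
    contribs = booster.predict(xgb.DMatrix(X), pred_contribs=True)
    print(contribs.shape)  # (64, 3, 5): rows x classes x (features + bias)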
@@ -1027,15 +1050,17 @@ class TestWithDask:
 
         shap = xgb.dask.predict(client, booster, test_Xy, pred_contribs=True).compute()
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
+        assert_shape(shap.shape)
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
 
         shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
         margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+        assert_shape(shap.shape)
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
 
     def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
-        cls = xgb.dask.DaskXGBClassifier()
+        cls = xgb.dask.DaskXGBClassifier(n_estimators=4)
         cls.client = client
         cls.fit(X, y)
         booster = cls.get_booster()
@@ -1072,6 +1097,8 @@ class TestWithDask:
         params: Dict[str, Any],
         client: "Client"
     ) -> None:
+        rows = X.shape[0]
+        cols = X.shape[1]
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
 
         Xy = xgb.dask.DaskDMatrix(client, X, y)
@@ -1082,6 +1109,12 @@ class TestWithDask:
         shap = xgb.dask.predict(
             client, booster, test_Xy, pred_interactions=True
         ).compute()
 
+        assert len(shap.shape) == 3
+        assert shap.shape[0] == rows
+        assert shap.shape[1] == cols + 1
+        assert shap.shape[2] == cols + 1
+
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
         assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
                            margin,
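The analogous single-node check for interaction values (a hedged sketch; regression keeps the output 3-dim, multi-class adds a class axis):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 4)
    y = np.random.randn(64)
    booster = xgb.train(
        {"objective": "reg:squarederror"}, xgb.DMatrix(X, label=y), num_boost_round=2
    )
    inter = booster.predict(xgb.DMatrix(X), pred_interactions=True)
    print(inter.shape)  # (64, 5, 5): rows x (features + bias) x (features + bias)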