[dask] Add a 1 line sample to infer output shape. (#6645)

* [dask] Use a 1 line sample to infer output shape.

This is for inferring shape with direct prediction (without DaskDMatrix).
There are a few things that requires known output shape before carrying out
actual prediction, including dask meta data, output dataframe columns.

* Infer output shape based on local prediction.
* Remove set param in predict function as it's not thread safe nor necessary as
we now let dask to decide the parallelism.
* Simplify prediction on `DaskDMatrix`.
This commit is contained in:
Jiaming Yuan 2021-01-30 18:55:50 +08:00 committed by GitHub
parent c3c8e66fc9
commit d8ec7aad5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 285 additions and 219 deletions

View File

@ -108,8 +108,9 @@ computation a bit faster when meta information like ``base_margin`` is not neede
prediction = xgb.dask.inplace_predict(client, output, X)
Here ``prediction`` is a dask ``Array`` object containing predictions from model if input
is a ``DaskDMatrix`` or ``da.Array``. For ``dd.DataFrame``, the return value is a
``dd.Series``.
is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly into the
``predict`` function or using ``inplace_predict``, the output type depends on input data.
See next section for details.
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
@ -143,9 +144,23 @@ Also for inplace prediction:
.. code-block:: python
booster.set_param({'predictor': 'gpu_predictor'})
# where X is a dask DataFrame or dask Array.
# where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
prediction = xgb.dask.inplace_predict(client, booster, X)
When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
type is ``dd.DataFrame``, output can be ``dd.Series``, ``dd.DataFrame`` or ``da.Array``,
depending on output shape. For example, when shap based prediction is used, the return
value can have 3 or 4 dimensions , in such cases an ``Array`` is always returned.
The performance of running prediction, either using ``predict`` or ``inplace_predict``, is
sensitive to number of blocks. Internally, it's implemented using ``da.map_blocks`` or
``dd.map_partitions``. When number of partitions is large and each of them have only
small amount of data, the overhead of calling predict becomes visible. On the other hand,
if not using GPU, the number of threads used for prediction on each block matters. Right
now, xgboost uses single thread for each partition. If the number of blocks on each
workers is smaller than number of cores, then the CPU workers might not be fully utilized.
***************************
Working with other clusters

View File

@ -112,14 +112,15 @@ def _start_tracker(n_workers: int) -> Dict[str, Any]:
def _assert_dask_support() -> None:
try:
import dask # pylint: disable=W0621,W0611
import dask # pylint: disable=W0621,W0611
except ImportError as e:
raise ImportError(
'Dask needs to be installed in order to use this module') from e
"Dask needs to be installed in order to use this module"
) from e
if platform.system() == 'Windows':
msg = 'Windows is not officially supported for dask/xgboost,'
msg += ' contribution are welcomed.'
if platform.system() == "Windows":
msg = "Windows is not officially supported for dask/xgboost,"
msg += " contribution are welcomed."
LOGGER.warning(msg)
@ -252,6 +253,7 @@ class DaskDMatrix:
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
self._n_cols = data.shape[1]
self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
self.is_quantile: bool = False
@ -403,6 +405,9 @@ class DaskDMatrix:
'parts': self.worker_map.get(worker_addr, None),
'is_quantile': self.is_quantile}
def num_col(self) -> int:
return self._n_cols
_DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
Optional[Any], Optional[Any]]]
@ -930,27 +935,90 @@ def train(
callbacks=callbacks)
def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
async def _direct_predict_impl(
client: "distributed.Client",
mapped_predict: Callable,
booster: Booster,
data: _DaskCollection,
predict_fn: Callable
base_margin: Optional[_DaskCollection],
output_shape: Tuple[int, ...],
meta: Dict[int, str],
) -> _DaskCollection:
if isinstance(data, da.Array):
predictions = await client.submit(
da.map_blocks,
predict_fn, data, False, drop_axis=1,
dtype=numpy.float32
).result()
return predictions
if isinstance(data, dd.DataFrame):
predictions = await client.submit(
dd.map_partitions,
predict_fn, data, True,
meta=dd.utils.make_meta({'prediction': 'f4'})
).result()
return predictions.iloc[:, 0]
raise TypeError('data of type: ' + str(type(data)) +
' is not supported by direct prediction')
columns = list(meta.keys())
booster_f = await client.scatter(data=booster, broadcast=True)
if _can_output_df(data, output_shape):
if base_margin is not None and isinstance(base_margin, da.Array):
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
else:
base_margin_df = base_margin
predictions = dd.map_partitions(
mapped_predict,
booster_f,
data,
True,
columns,
base_margin_df,
meta=dd.utils.make_meta(meta),
)
# classification can return a dataframe, drop 1 dim when it's reg/binary
if len(output_shape) == 1:
predictions = predictions.iloc[:, 0]
else:
if base_margin is not None and isinstance(
base_margin, (dd.Series, dd.DataFrame)
):
base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
else:
base_margin_array = base_margin
# Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
# contrib)/3(contrib)/4(interaction) dims.
if len(output_shape) == 1:
drop_axis: Union[int, List[int]] = [1] # drop from 2 to 1 dim.
new_axis: Union[int, List[int]] = []
else:
drop_axis = []
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
predictions = da.map_blocks(
mapped_predict,
booster_f,
data,
False,
columns,
base_margin_array,
drop_axis=drop_axis,
new_axis=new_axis,
dtype=numpy.float32,
)
return predictions
def _infer_predict_output(
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
"""Create a dummy test sample to infer output shape for prediction."""
if isinstance(data, DaskDMatrix):
features = data.num_col()
else:
features = data.shape[1]
rng = numpy.random.RandomState(1994)
test_sample = rng.randn(1, features)
if inplace:
# clear the state to avoid gpu_id, gpu_predictor
booster = Booster(model_file=booster.save_raw())
test_predt = booster.inplace_predict(test_sample, **kwargs)
else:
m = DMatrix(test_sample)
test_predt = booster.predict(m, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
meta: Dict[int, str] = {}
if _can_output_df(data, test_predt.shape):
for i in range(n_columns):
meta[i] = "f4"
return test_predt.shape, meta
# pylint: disable=too-many-statements
@ -968,19 +1036,19 @@ async def _predict_async(
validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
booster = model
_booster = model
elif isinstance(model, dict):
booster = model["booster"]
_booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
def mapped_predict(partition: Any, is_df: bool) -> Any:
worker = distributed.get_worker()
def mapped_predict(
booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
) -> Any:
with config.config_context(**global_config):
booster.set_param({"nthread": worker.nthreads})
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
m = DMatrix(data=partition, missing=missing)
predt = booster.predict(
data=m,
output_margin=output_margin,
@ -990,167 +1058,115 @@ async def _predict_async(
pred_interactions=pred_interactions,
validate_features=validate_features,
)
if is_df:
if is_df and len(predt.shape) <= 2:
if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
import cudf
predt = cudf.DataFrame(predt, columns=["prediction"])
predt = cudf.DataFrame(predt, columns=columns)
else:
predt = DataFrame(predt, columns=["prediction"])
predt = DataFrame(predt, columns=columns)
return predt
# Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
return await _direct_predict_impl(client, data, mapped_predict)
_output_shape, meta = _infer_predict_output(
_booster,
data,
inplace=False,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=False,
)
return await _direct_predict_impl(
client, mapped_predict, _booster, data, None, _output_shape, meta
)
output_shape, _ = _infer_predict_output(
booster=_booster,
data=data,
inplace=False,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=False,
)
# Prediction on dask DMatrix.
worker_map = data.worker_map
partition_order = data.partition_order
feature_names = data.feature_names
feature_types = data.feature_types
missing = data.missing
meta_names = data.meta_names
def dispatched_predict(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
"""Perform prediction on each worker."""
LOGGER.debug("Predicting on %d", worker_id)
def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
data = part[0]
assert isinstance(part, tuple), type(part)
base_margin = None
for i, blob in enumerate(part[1:]):
if meta_names[i] == "base_margin":
base_margin = blob
worker = distributed.get_worker()
with config.config_context(**global_config):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
booster.set_param({"nthread": worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _, _) = parts
order = list_of_orders[i]
local_part = DMatrix(
data,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
nthread=worker.nthreads,
)
predt = booster.predict(
data=local_part,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=validate_features,
)
if pred_contribs and predt.size != local_part.num_row():
assert len(predt.shape) in (2, 3)
if len(predt.shape) == 2:
groups = 1
columns = predt.shape[1]
else:
groups = predt.shape[1]
columns = predt.shape[2]
# pylint: disable=no-member
ret = (
[dask.delayed(predt), groups, columns],
order,
)
elif pred_interactions and predt.size != local_part.num_row():
assert len(predt.shape) in (3, 4)
if len(predt.shape) == 3:
groups = 1
columns = predt.shape[1]
else:
groups = predt.shape[1]
columns = predt.shape[2]
# pylint: disable=no-member
ret = (
[dask.delayed(predt), groups, columns],
order,
)
else:
assert len(predt.shape) == 1 or len(predt.shape) == 2
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
# pylint: disable=no-member
ret = (
[dask.delayed(predt), columns],
order,
)
predictions.append(ret)
return predictions
def dispatched_get_shape(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
) -> List[Tuple[int, int]]:
"""Get shape of data in each worker."""
LOGGER.debug("Get shape on %d", worker_id)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
shapes = []
for i, parts in enumerate(list_of_parts):
(data, _, _, _, _, _, _) = parts
shapes.append((data.shape, list_of_orders[i]))
return shapes
async def map_function(
func: Callable[[int, List[int], _DataParts], Any]
) -> List[Any]:
"""Run function for each part of the data."""
futures = []
workers_address = list(worker_map.keys())
for wid, worker_addr in enumerate(workers_address):
worker_addr = workers_address[wid]
list_of_parts = worker_map[worker_addr]
list_of_orders = [partition_order[part.key] for part in list_of_parts]
f = client.submit(
func,
worker_id=wid,
list_of_orders=list_of_orders,
list_of_parts=list_of_parts,
pure=True,
workers=[worker_addr],
m = DMatrix(
data,
nthread=worker.nthreads,
missing=missing,
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
)
assert isinstance(f, distributed.client.Future)
futures.append(f)
# Get delayed objects
results = await client.gather(futures)
# flatten into 1 dim list
results = [t for list_per_worker in results for t in list_per_worker]
# sort by order, l[0] is the delayed object, l[1] is its order
results = sorted(results, key=lambda l: l[1])
results = [predt for predt, order in results] # remove order
return results
predt = booster.predict(
m,
output_margin=output_margin,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=validate_features,
)
return predt
results = await map_function(dispatched_predict)
shapes = await map_function(dispatched_get_shape)
all_parts = []
all_orders = []
all_shapes = []
workers_address = list(data.worker_map.keys())
for worker_addr in workers_address:
list_of_parts = data.worker_map[worker_addr]
all_parts.extend(list_of_parts)
all_orders.extend([partition_order[part.key] for part in list_of_parts])
for part in all_parts:
s = client.submit(lambda part: part[0].shape[0], part)
all_shapes.append(s)
all_shapes = await client.gather(all_shapes)
parts_with_order = list(zip(all_parts, all_shapes, all_orders))
parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
all_parts = [part for part, shape, order in parts_with_order]
all_shapes = [shape for part, shape, order in parts_with_order]
futures = []
booster_f = await client.scatter(data=_booster, broadcast=True)
for part in all_parts:
f = client.submit(dispatched_predict, booster_f, part)
futures.append(f)
# Constructing a dask array from list of numpy arrays
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
if pred_contribs:
out_shape = (
(shape[0], results[i][2])
if results[i][1] == 1
else (shape[0], results[i][1], results[i][2])
)
elif pred_interactions:
out_shape = (
(shape[0], results[i][2], results[i][2])
if results[i][1] == 1
else (shape[0], results[i][1], results[i][2])
)
else:
out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
for i, rows in enumerate(all_shapes):
arrays.append(
da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
da.from_delayed(
futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
)
)
predictions = await da.concatenate(arrays, axis=0)
return predictions
def predict(
def predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
data: Union[DaskDMatrix, _DaskCollection],
@ -1190,22 +1206,15 @@ def predict(
-------
prediction: dask.array.Array/dask.dataframe.Series
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
array, when input data is ``dask.dataframe.DataFrame``, return value is
``dask.dataframe.Series``
array, when input data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape.
'''
_assert_dask_support()
client = _xgb_get_client(client)
global_config = config.get_config()
return client.sync(
_predict_async, client, global_config, model, data,
output_margin=output_margin,
missing=missing,
pred_leaf=pred_leaf,
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
validate_features=validate_features
_predict_async, global_config=config.get_config(), **locals()
)
@ -1228,30 +1237,38 @@ async def _inplace_predict_async(
if not isinstance(data, (da.Array, dd.DataFrame)):
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
def mapped_predict(data: Any, is_df: bool) -> Any:
worker = distributed.get_worker()
config.set_config(**global_config)
booster.set_param({'nthread': worker.nthreads})
prediction = booster.inplace_predict(
data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing)
if is_df:
def mapped_predict(
booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
) -> Any:
with config.config_context(**global_config):
prediction = booster.inplace_predict(
data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing
)
if is_df and len(prediction.shape) <= 2:
if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
import cudf
prediction = cudf.DataFrame({'prediction': prediction},
dtype=numpy.float32)
prediction = cudf.DataFrame(
prediction, columns=columns, dtype=numpy.float32
)
else:
# If it's from pandas, the partition is a numpy array
prediction = DataFrame(prediction, columns=['prediction'],
dtype=numpy.float32)
prediction = DataFrame(
prediction, columns=columns, dtype=numpy.float32
)
return prediction
return await _direct_predict_impl(client, data, mapped_predict)
shape, meta = _infer_predict_output(
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
)
return await _direct_predict_impl(
client, mapped_predict, booster, data, None, shape, meta
)
def inplace_predict(
def inplace_predict( # pylint: disable=unused-argument
client: "distributed.Client",
model: Union[TrainReturnT, Booster],
data: _DaskCollection,
@ -1281,16 +1298,17 @@ def inplace_predict(
Returns
-------
prediction
prediction :
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
array, when input data is ``dask.dataframe.DataFrame``, return value can be
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
depending on the output shape.
'''
_assert_dask_support()
client = _xgb_get_client(client)
global_config = config.get_config()
return client.sync(_inplace_predict_async, client, global_config, model=model,
data=data,
iteration_range=iteration_range,
predict_type=predict_type,
missing=missing)
return client.sync(
_inplace_predict_async, global_config=config.get_config(), **locals()
)
async def _async_wrap_evaluation_matrices(

View File

@ -24,7 +24,6 @@ if sys.platform.startswith("win"):
if tm.no_dask()['condition']:
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
import distributed
from distributed import LocalCluster, Client
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
@ -130,24 +129,34 @@ def test_from_dask_array() -> None:
assert np.all(single_node_predt == from_arr.compute())
def test_dask_predict_shape_infer() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client:
X, y = make_classification(n_samples=1000, n_informative=5,
n_classes=3)
X_ = dd.from_array(X, chunksize=100)
y_ = dd.from_array(y, chunksize=100)
dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
def test_dask_predict_shape_infer(client: "Client") -> None:
X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
X_ = dd.from_array(X, chunksize=100)
y_ = dd.from_array(y, chunksize=100)
dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
model = xgb.dask.train(
client,
{"objective": "multi:softprob", "num_class": 3},
dtrain=dtrain
)
model = xgb.dask.train(
client, {"objective": "multi:softprob", "num_class": 3}, dtrain=dtrain
)
preds = xgb.dask.predict(client, model, dtrain)
assert preds.shape[0] == preds.compute().shape[0]
assert preds.shape[1] == preds.compute().shape[1]
preds = xgb.dask.predict(client, model, dtrain)
assert preds.shape[0] == preds.compute().shape[0]
assert preds.shape[1] == preds.compute().shape[1]
prediction = xgb.dask.predict(client, model, X_, output_margin=True)
assert isinstance(prediction, dd.DataFrame)
prediction = prediction.compute()
assert prediction.ndim == 2
assert prediction.shape[0] == kRows
assert prediction.shape[1] == 3
prediction = xgb.dask.inplace_predict(client, model, X_, predict_type="margin")
assert isinstance(prediction, dd.DataFrame)
prediction = prediction.compute()
assert prediction.ndim == 2
assert prediction.shape[0] == kRows
assert prediction.shape[1] == 3
@pytest.mark.parametrize("tree_method", ["hist", "approx"])
@ -340,7 +349,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
classifier.fit(X_d, y_d)
assert classifier.n_classes_ == 10
prediction = classifier.predict(X_d)
prediction = classifier.predict(X_d).compute()
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
@ -541,6 +550,9 @@ async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
assert list(history['validation_0'].keys())[0] == 'rmse'
assert len(history['validation_0']['rmse']) == 2
awaited = await client.compute(prediction)
assert awaited.shape[0] == kRows
async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
async with Client(scheduler_address, asynchronous=True) as client:
@ -578,7 +590,7 @@ async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
await classifier.fit(X_d, y_d)
assert classifier.n_classes_ == 10
prediction = await classifier.predict(X_d)
prediction = await client.compute(await classifier.predict(X_d))
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
@ -1019,6 +1031,17 @@ class TestWithDask:
run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y)
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
rows = X.shape[0]
cols = X.shape[1]
def assert_shape(shape):
assert shape[0] == rows
if "num_class" in params.keys():
assert shape[1] == params["num_class"]
assert shape[2] == cols + 1
else:
assert shape[1] == cols + 1
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
Xy = xgb.dask.DaskDMatrix(client, X, y)
booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
@ -1027,15 +1050,17 @@ class TestWithDask:
shap = xgb.dask.predict(client, booster, test_Xy, pred_contribs=True).compute()
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
assert_shape(shap.shape)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
assert_shape(shap.shape)
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
cls = xgb.dask.DaskXGBClassifier()
cls = xgb.dask.DaskXGBClassifier(n_estimators=4)
cls.client = client
cls.fit(X, y)
booster = cls.get_booster()
@ -1072,6 +1097,8 @@ class TestWithDask:
params: Dict[str, Any],
client: "Client"
) -> None:
rows = X.shape[0]
cols = X.shape[1]
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
Xy = xgb.dask.DaskDMatrix(client, X, y)
@ -1082,6 +1109,12 @@ class TestWithDask:
shap = xgb.dask.predict(
client, booster, test_Xy, pred_interactions=True
).compute()
assert len(shap.shape) == 3
assert shap.shape[0] == rows
assert shap.shape[1] == cols + 1
assert shap.shape[2] == cols + 1
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,