[dask] Add a 1 line sample to infer output shape. (#6645)

* [dask] Use a 1 line sample to infer output shape.

This is for inferring the shape with direct prediction (without DaskDMatrix).
A few things require a known output shape before carrying out the actual
prediction, including dask metadata and output dataframe columns.

* Infer output shape based on local prediction.
* Remove set_param in the predict function, as it's neither thread safe nor
necessary now that we let dask decide the parallelism.
* Simplify prediction on `DaskDMatrix`.
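
The idea behind the change, as a minimal sketch (the helper name and variables here are illustrative, not the committed API): predict on a single random row locally and reuse the resulting shape to build dask metadata before any distributed prediction runs.

    import numpy as np
    import xgboost as xgb

    def infer_output_shape(booster: xgb.Booster, n_features: int) -> tuple:
        # Run a real prediction on one random row; only the shape is used.
        rng = np.random.RandomState(1994)
        test_sample = rng.randn(1, n_features)
        test_predt = booster.predict(xgb.DMatrix(test_sample))
        # e.g. (1,) for reg/binary, (1, n_classes) for multi:softprob.
        return test_predt.shape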
Jiaming Yuan 2021-01-30 18:55:50 +08:00 committed by GitHub
parent c3c8e66fc9
commit d8ec7aad5a
3 changed files with 285 additions and 219 deletions


@@ -108,8 +108,9 @@ computation a bit faster when meta information like ``base_margin`` is not needed.

     prediction = xgb.dask.inplace_predict(client, output, X)

 Here ``prediction`` is a dask ``Array`` object containing predictions from model if input
-is a ``DaskDMatrix`` or ``da.Array``.  For ``dd.DataFrame``, the return value is a
-``dd.Series``.
+is a ``DaskDMatrix`` or ``da.Array``.  When putting a dask collection directly into the
+``predict`` function or using ``inplace_predict``, the output type depends on the input
+data.  See the next section for details.

 Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
 and ``DaskXGBRegressor``.  See ``xgboost/demo/dask`` for more examples.
@@ -143,9 +144,23 @@ Also for inplace prediction:

 .. code-block:: python

     booster.set_param({'predictor': 'gpu_predictor'})
-    # where X is a dask DataFrame or dask Array.
+    # where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
     prediction = xgb.dask.inplace_predict(client, booster, X)

+When the input is a ``da.Array`` object, the output is always ``da.Array``.  However, if
+the input type is ``dd.DataFrame``, the output can be ``dd.Series``, ``dd.DataFrame`` or
+``da.Array``, depending on the output shape.  For example, when SHAP-based prediction is
+used, the return value can have 3 or 4 dimensions; in such cases an ``Array`` is always
+returned.
+
+The performance of running prediction, either using ``predict`` or ``inplace_predict``,
+is sensitive to the number of blocks.  Internally, it's implemented using
+``da.map_blocks`` or ``dd.map_partitions``.  When the number of partitions is large and
+each of them has only a small amount of data, the overhead of calling predict becomes
+visible.  On the other hand, if not using GPU, the number of threads used for prediction
+on each block matters.  Right now, xgboost uses a single thread for each partition.  If
+the number of blocks on each worker is smaller than the number of cores, then the CPU
+workers might not be fully utilized.
+
 ***************************
 Working with other clusters
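
A minimal sketch of the output-type rules added above, assuming a connected `client` and a trained `booster` (both hypothetical here):

    import dask.array as da
    import dask.dataframe as dd
    import xgboost as xgb

    X_arr = da.random.random((1000, 10), chunks=(100, 10))
    pred = xgb.dask.inplace_predict(client, booster, X_arr)
    # da.Array in -> da.Array out, regardless of objective.

    X_df = dd.from_dask_array(X_arr)
    pred = xgb.dask.inplace_predict(client, booster, X_df)
    # dd.DataFrame in -> dd.Series (regression/binary) or dd.DataFrame
    # (multi-class); da.Array when the result has 3+ dims (SHAP outputs).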


@@ -115,11 +115,12 @@ def _assert_dask_support() -> None:
         import dask  # pylint: disable=W0621,W0611
     except ImportError as e:
         raise ImportError(
-            'Dask needs to be installed in order to use this module') from e
+            "Dask needs to be installed in order to use this module"
+        ) from e

-    if platform.system() == 'Windows':
-        msg = 'Windows is not officially supported for dask/xgboost,'
-        msg += ' contribution are welcomed.'
+    if platform.system() == "Windows":
+        msg = "Windows is not officially supported for dask/xgboost,"
+        msg += " contribution are welcomed."
         LOGGER.warning(msg)
@@ -252,6 +253,7 @@ class DaskDMatrix:
         if not isinstance(label, (dd.DataFrame, da.Array, dd.Series, type(None))):
             raise TypeError(_expect((dd.DataFrame, da.Array, dd.Series), type(label)))

+        self._n_cols = data.shape[1]
         self.worker_map: Dict[str, "distributed.Future"] = defaultdict(list)
         self.is_quantile: bool = False
@@ -403,6 +405,9 @@ class DaskDMatrix:
                 'parts': self.worker_map.get(worker_addr, None),
                 'is_quantile': self.is_quantile}

+    def num_col(self) -> int:
+        return self._n_cols
+

 _DataParts = List[Tuple[Any, Optional[Any], Optional[Any], Optional[Any], Optional[Any],
                         Optional[Any], Optional[Any]]]
@@ -930,27 +935,90 @@ def train(
         callbacks=callbacks)


+def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
+    return isinstance(data, dd.DataFrame) and len(output_shape) <= 2
+
+
 async def _direct_predict_impl(
     client: "distributed.Client",
+    mapped_predict: Callable,
+    booster: Booster,
     data: _DaskCollection,
-    predict_fn: Callable
+    base_margin: Optional[_DaskCollection],
+    output_shape: Tuple[int, ...],
+    meta: Dict[int, str],
 ) -> _DaskCollection:
-    if isinstance(data, da.Array):
-        predictions = await client.submit(
-            da.map_blocks,
-            predict_fn, data, False, drop_axis=1,
-            dtype=numpy.float32
-        ).result()
+    columns = list(meta.keys())
+    booster_f = await client.scatter(data=booster, broadcast=True)
+    if _can_output_df(data, output_shape):
+        if base_margin is not None and isinstance(base_margin, da.Array):
+            base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
+        else:
+            base_margin_df = base_margin
+        predictions = dd.map_partitions(
+            mapped_predict,
+            booster_f,
+            data,
+            True,
+            columns,
+            base_margin_df,
+            meta=dd.utils.make_meta(meta),
+        )
+        # classification can return a dataframe, drop 1 dim when it's reg/binary
+        if len(output_shape) == 1:
+            predictions = predictions.iloc[:, 0]
+    else:
+        if base_margin is not None and isinstance(
+            base_margin, (dd.Series, dd.DataFrame)
+        ):
+            base_margin_array: Optional[da.Array] = base_margin.to_dask_array()
+        else:
+            base_margin_array = base_margin
+        # Input data is 2-dim array, output can be 1(reg, binary)/2(multi-class,
+        # contrib)/3(contrib)/4(interaction) dims.
+        if len(output_shape) == 1:
+            drop_axis: Union[int, List[int]] = [1]  # drop from 2 to 1 dim.
+            new_axis: Union[int, List[int]] = []
+        else:
+            drop_axis = []
+            new_axis = [i + 2 for i in range(len(output_shape) - 2)]
+        predictions = da.map_blocks(
+            mapped_predict,
+            booster_f,
+            data,
+            False,
+            columns,
+            base_margin_array,
+            drop_axis=drop_axis,
+            new_axis=new_axis,
+            dtype=numpy.float32,
+        )
     return predictions
-    if isinstance(data, dd.DataFrame):
-        predictions = await client.submit(
-            dd.map_partitions,
-            predict_fn, data, True,
-            meta=dd.utils.make_meta({'prediction': 'f4'})
-        ).result()
-        return predictions.iloc[:, 0]
-    raise TypeError('data of type: ' + str(type(data)) +
-                    ' is not supported by direct prediction')
+
+
+def _infer_predict_output(
+    booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
+) -> Tuple[Tuple[int, ...], Dict[int, str]]:
+    """Create a dummy test sample to infer output shape for prediction."""
+    if isinstance(data, DaskDMatrix):
+        features = data.num_col()
+    else:
+        features = data.shape[1]
+    rng = numpy.random.RandomState(1994)
+    test_sample = rng.randn(1, features)
+    if inplace:
+        # clear the state to avoid gpu_id, gpu_predictor
+        booster = Booster(model_file=booster.save_raw())
+        test_predt = booster.inplace_predict(test_sample, **kwargs)
+    else:
+        m = DMatrix(test_sample)
+        test_predt = booster.predict(m, **kwargs)
+    n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
+    meta: Dict[int, str] = {}
+    if _can_output_df(data, test_predt.shape):
+        for i in range(n_columns):
+            meta[i] = "f4"
+    return test_predt.shape, meta


 # pylint: disable=too-many-statements
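
The `drop_axis`/`new_axis` bookkeeping in `_direct_predict_impl` is plain `da.map_blocks` semantics. A self-contained sketch with stand-in block functions (not xgboost code; the explicit `chunks` argument is an assumption added here so the declared output shape stays consistent):

    import dask.array as da
    import numpy as np

    x = da.random.random((100, 10), chunks=(25, 10))

    # 2-dim in -> 1-dim out (reg/binary): drop the feature axis.
    out1 = da.map_blocks(lambda b: b.sum(axis=1), x, drop_axis=1, dtype=np.float64)
    assert out1.compute().shape == (100,)

    # 2-dim in -> 3-dim out (e.g. per-class contributions): append a new axis.
    out3 = da.map_blocks(
        lambda b: np.repeat(b[:, :, None], 3, axis=2),
        x,
        chunks=(25, 10, 3),
        new_axis=[2],
        dtype=np.float64,
    )
    assert out3.compute().shape == (100, 10, 3)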
@@ -968,19 +1036,19 @@ async def _predict_async(
     validate_features: bool,
 ) -> _DaskCollection:
     if isinstance(model, Booster):
-        booster = model
+        _booster = model
     elif isinstance(model, dict):
-        booster = model["booster"]
+        _booster = model["booster"]
     else:
         raise TypeError(_expect([Booster, dict], type(model)))
     if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
         raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

-    def mapped_predict(partition: Any, is_df: bool) -> Any:
-        worker = distributed.get_worker()
+    def mapped_predict(
+        booster: Booster, partition: Any, is_df: bool, columns: List[int], _: Any
+    ) -> Any:
         with config.config_context(**global_config):
-            booster.set_param({"nthread": worker.nthreads})
-            m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
+            m = DMatrix(data=partition, missing=missing)
             predt = booster.predict(
                 data=m,
                 output_margin=output_margin,
@@ -990,50 +1058,68 @@ async def _predict_async(
                 pred_interactions=pred_interactions,
                 validate_features=validate_features,
             )
-            if is_df:
+            if is_df and len(predt.shape) <= 2:
                 if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
                     import cudf
-                    predt = cudf.DataFrame(predt, columns=["prediction"])
+
+                    predt = cudf.DataFrame(predt, columns=columns)
                 else:
-                    predt = DataFrame(predt, columns=["prediction"])
+                    predt = DataFrame(predt, columns=columns)
             return predt

     # Predict on dask collection directly.
     if isinstance(data, (da.Array, dd.DataFrame)):
-        return await _direct_predict_impl(client, data, mapped_predict)
+        _output_shape, meta = _infer_predict_output(
+            _booster,
+            data,
+            inplace=False,
+            output_margin=output_margin,
+            pred_leaf=pred_leaf,
+            pred_contribs=pred_contribs,
+            approx_contribs=approx_contribs,
+            pred_interactions=pred_interactions,
+            validate_features=False,
+        )
+        return await _direct_predict_impl(
+            client, mapped_predict, _booster, data, None, _output_shape, meta
+        )
+
+    output_shape, _ = _infer_predict_output(
+        booster=_booster,
+        data=data,
+        inplace=False,
+        output_margin=output_margin,
+        pred_leaf=pred_leaf,
+        pred_contribs=pred_contribs,
+        approx_contribs=approx_contribs,
+        pred_interactions=pred_interactions,
+        validate_features=False,
+    )
     # Prediction on dask DMatrix.
-    worker_map = data.worker_map
     partition_order = data.partition_order
     feature_names = data.feature_names
     feature_types = data.feature_types
     missing = data.missing
     meta_names = data.meta_names

-    def dispatched_predict(
-        worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
-    ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
-        """Perform prediction on each worker."""
-        LOGGER.debug("Predicting on %d", worker_id)
-        with config.config_context(**global_config):
-            worker = distributed.get_worker()
-            list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
-            predictions = []
-            booster.set_param({"nthread": worker.nthreads})
-            for i, parts in enumerate(list_of_parts):
-                (data, _, _, base_margin, _, _, _) = parts
-                order = list_of_orders[i]
-                local_part = DMatrix(
-                    data,
-                    nthread=worker.nthreads,
-                    missing=missing,
-                    base_margin=base_margin,
-                    feature_names=feature_names,
-                    feature_types=feature_types,
-                )
-                predt = booster.predict(
-                    data=local_part,
-                    output_margin=output_margin,
-                    pred_leaf=pred_leaf,
-                    pred_contribs=pred_contribs,
+    def dispatched_predict(booster: Booster, part: Any) -> numpy.ndarray:
+        data = part[0]
+        assert isinstance(part, tuple), type(part)
+        base_margin = None
+        for i, blob in enumerate(part[1:]):
+            if meta_names[i] == "base_margin":
+                base_margin = blob
+        worker = distributed.get_worker()
+        with config.config_context(**global_config):
+            m = DMatrix(
+                data,
+                base_margin=base_margin,
+                feature_names=feature_names,
+                feature_types=feature_types,
+                missing=missing,
+                nthread=worker.nthreads,
+            )
+            predt = booster.predict(
+                m,
+                output_margin=output_margin,
+                pred_leaf=pred_leaf,
+                pred_contribs=pred_contribs,
@@ -1041,116 +1127,46 @@ async def _predict_async(
                 pred_interactions=pred_interactions,
                 validate_features=validate_features,
             )
-                if pred_contribs and predt.size != local_part.num_row():
-                    assert len(predt.shape) in (2, 3)
-                    if len(predt.shape) == 2:
-                        groups = 1
-                        columns = predt.shape[1]
-                    else:
-                        groups = predt.shape[1]
-                        columns = predt.shape[2]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), groups, columns],
-                        order,
-                    )
-                elif pred_interactions and predt.size != local_part.num_row():
-                    assert len(predt.shape) in (3, 4)
-                    if len(predt.shape) == 3:
-                        groups = 1
-                        columns = predt.shape[1]
-                    else:
-                        groups = predt.shape[1]
-                        columns = predt.shape[2]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), groups, columns],
-                        order,
-                    )
-                else:
-                    assert len(predt.shape) == 1 or len(predt.shape) == 2
-                    columns = 1 if len(predt.shape) == 1 else predt.shape[1]
-                    # pylint: disable=no-member
-                    ret = (
-                        [dask.delayed(predt), columns],
-                        order,
-                    )
-                predictions.append(ret)
-        return predictions
+            return predt
+
+    all_parts = []
+    all_orders = []
+    all_shapes = []
+    workers_address = list(data.worker_map.keys())
+    for worker_addr in workers_address:
+        list_of_parts = data.worker_map[worker_addr]
+        all_parts.extend(list_of_parts)
+        all_orders.extend([partition_order[part.key] for part in list_of_parts])
+    for part in all_parts:
+        s = client.submit(lambda part: part[0].shape[0], part)
+        all_shapes.append(s)
+    all_shapes = await client.gather(all_shapes)

-    def dispatched_get_shape(
-        worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
-    ) -> List[Tuple[int, int]]:
-        """Get shape of data in each worker."""
-        LOGGER.debug("Get shape on %d", worker_id)
-        list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
-        shapes = []
-        for i, parts in enumerate(list_of_parts):
-            (data, _, _, _, _, _, _) = parts
-            shapes.append((data.shape, list_of_orders[i]))
-        return shapes
+    parts_with_order = list(zip(all_parts, all_shapes, all_orders))
+    parts_with_order = sorted(parts_with_order, key=lambda p: p[2])
+    all_parts = [part for part, shape, order in parts_with_order]
+    all_shapes = [shape for part, shape, order in parts_with_order]

-    async def map_function(
-        func: Callable[[int, List[int], _DataParts], Any]
-    ) -> List[Any]:
-        """Run function for each part of the data."""
-        futures = []
-        workers_address = list(worker_map.keys())
-        for wid, worker_addr in enumerate(workers_address):
-            worker_addr = workers_address[wid]
-            list_of_parts = worker_map[worker_addr]
-            list_of_orders = [partition_order[part.key] for part in list_of_parts]
-            f = client.submit(
-                func,
-                worker_id=wid,
-                list_of_orders=list_of_orders,
-                list_of_parts=list_of_parts,
-                pure=True,
-                workers=[worker_addr],
-            )
-            assert isinstance(f, distributed.client.Future)
-            futures.append(f)
-        # Get delayed objects
-        results = await client.gather(futures)
-        # flatten into 1 dim list
-        results = [t for list_per_worker in results for t in list_per_worker]
-        # sort by order, l[0] is the delayed object, l[1] is its order
-        results = sorted(results, key=lambda l: l[1])
-        results = [predt for predt, order in results]  # remove order
-        return results
-
-    results = await map_function(dispatched_predict)
-    shapes = await map_function(dispatched_get_shape)
+    futures = []
+    booster_f = await client.scatter(data=_booster, broadcast=True)
+    for part in all_parts:
+        f = client.submit(dispatched_predict, booster_f, part)
+        futures.append(f)

     # Constructing a dask array from list of numpy arrays
     # See https://docs.dask.org/en/latest/array-creation.html
     arrays = []
-    for i, shape in enumerate(shapes):
-        if pred_contribs:
-            out_shape = (
-                (shape[0], results[i][2])
-                if results[i][1] == 1
-                else (shape[0], results[i][1], results[i][2])
-            )
-        elif pred_interactions:
-            out_shape = (
-                (shape[0], results[i][2], results[i][2])
-                if results[i][1] == 1
-                else (shape[0], results[i][1], results[i][2])
-            )
-        else:
-            out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
+    for i, rows in enumerate(all_shapes):
         arrays.append(
-            da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
+            da.from_delayed(
+                futures[i], shape=(rows,) + output_shape[1:], dtype=numpy.float32
+            )
         )
     predictions = await da.concatenate(arrays, axis=0)
     return predictions
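
The assembly above is standard dask array construction: one future per partition, each with a known row count, and all trailing dimensions taken from the inferred `output_shape`. A standalone sketch of the same pattern using delayed objects (names illustrative):

    import dask
    import dask.array as da
    import numpy as np

    output_shape = (1, 3)     # e.g. inferred from a 1-row local prediction
    part_rows = [25, 25, 50]  # known row count of each partition

    parts = [
        dask.delayed(np.full)((rows,) + output_shape[1:], 0.5, dtype=np.float32)
        for rows in part_rows
    ]
    arrays = [
        da.from_delayed(part, shape=(rows,) + output_shape[1:], dtype=np.float32)
        for part, rows in zip(parts, part_rows)
    ]
    predictions = da.concatenate(arrays, axis=0)
    assert predictions.shape == (100, 3)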
-def predict(
+def predict(  # pylint: disable=unused-argument
     client: "distributed.Client",
     model: Union[TrainReturnT, Booster],
     data: Union[DaskDMatrix, _DaskCollection],
@@ -1190,22 +1206,15 @@ def predict(
     -------
     prediction: dask.array.Array/dask.dataframe.Series
         When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
-        array, when input data is ``dask.dataframe.DataFrame``, return value is
-        ``dask.dataframe.Series``
+        array; when input data is ``dask.dataframe.DataFrame``, the return value can be
+        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
+        depending on the output shape.

     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    global_config = config.get_config()
-    return client.sync(
-        _predict_async, client, global_config, model, data,
-        output_margin=output_margin,
-        missing=missing,
-        pred_leaf=pred_leaf,
-        pred_contribs=pred_contribs,
-        approx_contribs=approx_contribs,
-        pred_interactions=pred_interactions,
-        validate_features=validate_features
-    )
+    return client.sync(
+        _predict_async, global_config=config.get_config(), **locals()
+    )
@@ -1228,30 +1237,38 @@ async def _inplace_predict_async(
     if not isinstance(data, (da.Array, dd.DataFrame)):
         raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

-    def mapped_predict(data: Any, is_df: bool) -> Any:
-        worker = distributed.get_worker()
-        config.set_config(**global_config)
-        booster.set_param({'nthread': worker.nthreads})
-        prediction = booster.inplace_predict(
-            data,
-            iteration_range=iteration_range,
-            predict_type=predict_type,
-            missing=missing)
-        if is_df:
-            if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
-                import cudf
-                prediction = cudf.DataFrame({'prediction': prediction},
-                                            dtype=numpy.float32)
-            else:
-                # If it's from pandas, the partition is a numpy array
-                prediction = DataFrame(prediction, columns=['prediction'],
-                                       dtype=numpy.float32)
-        return prediction
-
-    return await _direct_predict_impl(client, data, mapped_predict)
+    def mapped_predict(
+        booster: Booster, data: Any, is_df: bool, columns: List[int], _: Any
+    ) -> Any:
+        with config.config_context(**global_config):
+            prediction = booster.inplace_predict(
+                data,
+                iteration_range=iteration_range,
+                predict_type=predict_type,
+                missing=missing
+            )
+            if is_df and len(prediction.shape) <= 2:
+                if lazy_isinstance(data, 'cudf.core.dataframe', 'DataFrame'):
+                    import cudf
+                    prediction = cudf.DataFrame(
+                        prediction, columns=columns, dtype=numpy.float32
+                    )
+                else:
+                    # If it's from pandas, the partition is a numpy array
+                    prediction = DataFrame(
+                        prediction, columns=columns, dtype=numpy.float32
+                    )
+        return prediction
+
+    shape, meta = _infer_predict_output(
+        booster, data, True, predict_type=predict_type, iteration_range=iteration_range
+    )
+    return await _direct_predict_impl(
+        client, mapped_predict, booster, data, None, shape, meta
+    )
-def inplace_predict(
+def inplace_predict(  # pylint: disable=unused-argument
     client: "distributed.Client",
     model: Union[TrainReturnT, Booster],
     data: _DaskCollection,
@@ -1281,16 +1298,17 @@ def inplace_predict(
     Returns
     -------
-    prediction
+    prediction :
+        When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
+        array; when input data is ``dask.dataframe.DataFrame``, the return value can be
+        ``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
+        depending on the output shape.

     '''
     _assert_dask_support()
     client = _xgb_get_client(client)
-    global_config = config.get_config()
-    return client.sync(_inplace_predict_async, client, global_config, model=model,
-                       data=data,
-                       iteration_range=iteration_range,
-                       predict_type=predict_type,
-                       missing=missing)
+    return client.sync(
+        _inplace_predict_async, global_config=config.get_config(), **locals()
+    )


 async def _async_wrap_evaluation_matrices(


@@ -24,7 +24,6 @@ if sys.platform.startswith("win"):
 if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)

-import distributed
 from distributed import LocalCluster, Client
 from distributed.utils_test import client, loop, cluster_fixture
 import dask.dataframe as dd
@@ -130,25 +129,35 @@ def test_from_dask_array() -> None:
     assert np.all(single_node_predt == from_arr.compute())


-def test_dask_predict_shape_infer() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
-            X, y = make_classification(n_samples=1000, n_informative=5,
-                                       n_classes=3)
-            X_ = dd.from_array(X, chunksize=100)
-            y_ = dd.from_array(y, chunksize=100)
-            dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
-            model = xgb.dask.train(
-                client,
-                {"objective": "multi:softprob", "num_class": 3},
-                dtrain=dtrain
-            )
-            preds = xgb.dask.predict(client, model, dtrain)
-            assert preds.shape[0] == preds.compute().shape[0]
-            assert preds.shape[1] == preds.compute().shape[1]
+def test_dask_predict_shape_infer(client: "Client") -> None:
+    X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
+    X_ = dd.from_array(X, chunksize=100)
+    y_ = dd.from_array(y, chunksize=100)
+    dtrain = xgb.dask.DaskDMatrix(client, data=X_, label=y_)
+    model = xgb.dask.train(
+        client, {"objective": "multi:softprob", "num_class": 3}, dtrain=dtrain
+    )
+    preds = xgb.dask.predict(client, model, dtrain)
+    assert preds.shape[0] == preds.compute().shape[0]
+    assert preds.shape[1] == preds.compute().shape[1]
+
+    prediction = xgb.dask.predict(client, model, X_, output_margin=True)
+    assert isinstance(prediction, dd.DataFrame)
+
+    prediction = prediction.compute()
+    assert prediction.ndim == 2
+    assert prediction.shape[0] == kRows
+    assert prediction.shape[1] == 3
+
+    prediction = xgb.dask.inplace_predict(client, model, X_, predict_type="margin")
+    assert isinstance(prediction, dd.DataFrame)
+    prediction = prediction.compute()
+    assert prediction.ndim == 2
+    assert prediction.shape[0] == kRows
+    assert prediction.shape[1] == 3


 @pytest.mark.parametrize("tree_method", ["hist", "approx"])
 def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
@@ -340,7 +349,7 @@ def test_dask_classifier(model: str, client: "Client") -> None:
     classifier.fit(X_d, y_d)
     assert classifier.n_classes_ == 10
-    prediction = classifier.predict(X_d)
+    prediction = classifier.predict(X_d).compute()

     assert prediction.ndim == 1
     assert prediction.shape[0] == kRows
@@ -541,6 +550,9 @@ async def run_dask_regressor_asyncio(scheduler_address: str) -> None:
         assert list(history['validation_0'].keys())[0] == 'rmse'
         assert len(history['validation_0']['rmse']) == 2

+        awaited = await client.compute(prediction)
+        assert awaited.shape[0] == kRows
+

 async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
     async with Client(scheduler_address, asynchronous=True) as client:
@@ -578,7 +590,7 @@ async def run_dask_classifier_asyncio(scheduler_address: str) -> None:
         await classifier.fit(X_d, y_d)
         assert classifier.n_classes_ == 10
-        prediction = await classifier.predict(X_d)
+        prediction = await client.compute(await classifier.predict(X_d))

         assert prediction.ndim == 1
         assert prediction.shape[0] == kRows
@@ -1019,6 +1031,17 @@ class TestWithDask:
         run_data_initialization(xgb.dask.DaskDMatrix, xgb.dask.DaskXGBClassifier, X, y)

     def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
+        rows = X.shape[0]
+        cols = X.shape[1]
+
+        def assert_shape(shape):
+            assert shape[0] == rows
+            if "num_class" in params.keys():
+                assert shape[1] == params["num_class"]
+                assert shape[2] == cols + 1
+            else:
+                assert shape[1] == cols + 1
+
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
         Xy = xgb.dask.DaskDMatrix(client, X, y)
         booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
@@ -1027,15 +1050,17 @@ class TestWithDask:
         shap = xgb.dask.predict(client, booster, test_Xy, pred_contribs=True).compute()
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
+        assert_shape(shap.shape)
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

         shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
         margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+        assert_shape(shap.shape)
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

     def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
-        cls = xgb.dask.DaskXGBClassifier()
+        cls = xgb.dask.DaskXGBClassifier(n_estimators=4)
         cls.client = client
         cls.fit(X, y)
         booster = cls.get_booster()
@@ -1072,6 +1097,8 @@ class TestWithDask:
         params: Dict[str, Any],
         client: "Client"
     ) -> None:
+        rows = X.shape[0]
+        cols = X.shape[1]
         X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
         Xy = xgb.dask.DaskDMatrix(client, X, y)
@@ -1082,6 +1109,12 @@ class TestWithDask:
         shap = xgb.dask.predict(
             client, booster, test_Xy, pred_interactions=True
         ).compute()
+
+        assert len(shap.shape) == 3
+        assert shap.shape[0] == rows
+        assert shap.shape[1] == cols + 1
+        assert shap.shape[2] == cols + 1
+
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
         assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
                            margin,
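
For reference, the invariant these new assertions encode also holds in single-node XGBoost: interaction values have shape (rows, cols + 1, cols + 1), where the extra slot is the bias term, and summing over the last two axes recovers the margin. A minimal sketch:

    import numpy as np
    import xgboost as xgb

    rows, cols = 100, 5
    rng = np.random.RandomState(1994)
    Xy = xgb.DMatrix(rng.randn(rows, cols), rng.randn(rows))
    booster = xgb.train({}, Xy, num_boost_round=4)

    shap = booster.predict(Xy, pred_interactions=True)
    assert shap.shape == (rows, cols + 1, cols + 1)  # +1 for the bias column

    margin = booster.predict(Xy, output_margin=True)
    assert np.allclose(shap.sum(axis=(1, 2)), margin, 1e-5, 1e-5)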