Specify shape in prediction contrib and interaction. (#6614)

Jiaming Yuan 2021-01-26 02:08:22 +08:00 committed by GitHub
parent 8942c98054
commit 4bf23c2391
4 changed files with 155 additions and 87 deletions


@ -95,14 +95,21 @@ For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predict``
.. code-block:: python
prediction = xgb.dask.predict(client, output, dtrain)
+ # Or equivalently, pass ``output['booster']``:
+ prediction = xgb.dask.predict(client, output['booster'], dtrain)
- Or equivalently, pass ``output['booster']``:
+ Eliminating the construction of DaskDMatrix is also possible; this can make the
+ computation a bit faster when meta information like ``base_margin`` is not needed:
.. code-block:: python
- prediction = xgb.dask.predict(client, output['booster'], dtrain)
+ prediction = xgb.dask.predict(client, output, X)
+ # Use the inplace version.
+ prediction = xgb.dask.inplace_predict(client, output, X)
- Here ``prediction`` is a dask ``Array`` object containing predictions from model.
+ Here ``prediction`` is a dask ``Array`` object containing predictions from the model if the
+ input is a ``DaskDMatrix`` or ``da.Array``. For a ``dd.DataFrame``, the return value is a
+ ``dd.Series``.
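
For reference, a minimal sketch of that distinction (variable names and the synthetic
data are illustrative, not from this commit; a local cluster is assumed):

.. code-block:: python

    import dask.array as da
    import dask.dataframe as dd
    import xgboost as xgb
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((100, 4), chunks=(25, -1))
        y = da.random.random(100, chunks=25)
        output = xgb.dask.train(client, {}, xgb.dask.DaskDMatrix(client, X, y))
        # Array input yields a dask Array:
        print(type(xgb.dask.predict(client, output, X)))
        # DataFrame input yields a dask Series:
        X_df = dd.from_dask_array(X)
        print(type(xgb.dask.predict(client, output, X_df)))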
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
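
And a minimal sketch of the Scikit-Learn interface mentioned above (again a local
cluster and synthetic data; ``n_estimators`` is illustrative):

.. code-block:: python

    import dask.array as da
    import xgboost as xgb
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((100, 4), chunks=(25, -1))
        y = da.random.random(100, chunks=25)
        reg = xgb.dask.DaskXGBRegressor(n_estimators=8)
        reg.client = client
        reg.fit(X, y)
        print(reg.predict(X).compute()[:5])  # lazy dask array until computed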


@ -190,7 +190,7 @@ def _check_call(ret):
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
- def ctypes2numpy(cptr, length, dtype):
+ def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
"""Convert a ctypes pointer array to a numpy array."""
NUMPY_TO_CTYPES_MAPPING = {
np.float32: ctypes.c_float,
@ -1553,7 +1553,7 @@ class Booster(object):
ctypes.byref(preds)))
preds = ctypes2numpy(preds, length.value, np.float32)
if pred_leaf:
- preds = preds.astype(np.int32)
+ preds = preds.astype(np.int32, copy=False)
nrow = data.num_row()
if preds.size != nrow and preds.size % nrow == 0:
chunk_size = int(preds.size / nrow)
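
For context on the ``copy=False`` change above: ``ndarray.astype`` still copies when the
dtype actually differs (float32 to int32 here), so the flag only saves an allocation when
the buffer is already of the target dtype. A small illustration (not from this commit):

import numpy as np

preds = np.arange(4, dtype=np.float32)
leaf = preds.astype(np.int32, copy=False)  # dtype differs, so a copy is still made
assert leaf is not preds
same = leaf.astype(np.int32, copy=False)   # dtype already matches: no copy, same object
assert same is leaf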


@ -964,22 +964,21 @@ async def _predict_async(
pred_contribs: bool,
approx_contribs: bool,
pred_interactions: bool,
- validate_features: bool
+ validate_features: bool,
) -> _DaskCollection:
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
- booster = model['booster']
+ booster = model["booster"]
else:
raise TypeError(_expect([Booster, dict], type(model)))
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
- raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
-                         type(data)))
+ raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
def mapped_predict(partition: Any, is_df: bool) -> Any:
worker = distributed.get_worker()
with config.config_context(**global_config):
- booster.set_param({'nthread': worker.nthreads})
+ booster.set_param({"nthread": worker.nthreads})
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(
data=m,
@ -988,15 +987,16 @@ async def _predict_async(
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
- validate_features=validate_features
+ validate_features=validate_features,
)
if is_df:
- if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
+ if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
import cudf
- predt = cudf.DataFrame(predt, columns=['prediction'])
+ predt = cudf.DataFrame(predt, columns=["prediction"])
else:
- predt = DataFrame(predt, columns=['prediction'])
+ predt = DataFrame(predt, columns=["prediction"])
return predt
+ # Predict on dask collection directly.
if isinstance(data, (da.Array, dd.DataFrame)):
return await _direct_predict_impl(client, data, mapped_predict)
@ -1011,15 +1011,15 @@ async def _predict_async(
def dispatched_predict(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
- ) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]:
- '''Perform prediction on each worker.'''
- LOGGER.debug('Predicting on %d', worker_id)
+ ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
+ """Perform prediction on each worker."""
+ LOGGER.debug("Predicting on %d", worker_id)
with config.config_context(**global_config):
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
predictions = []
- booster.set_param({'nthread': worker.nthreads})
+ booster.set_param({"nthread": worker.nthreads})
for i, parts in enumerate(list_of_parts):
(data, _, _, base_margin, _, _, _) = parts
order = list_of_orders[i]
@ -1029,7 +1029,7 @@ async def _predict_async(
feature_names=feature_names,
feature_types=feature_types,
missing=missing,
- nthread=worker.nthreads
+ nthread=worker.nthreads,
)
predt = booster.predict(
data=local_part,
@ -1038,10 +1038,42 @@ async def _predict_async(
pred_contribs=pred_contribs,
approx_contribs=approx_contribs,
pred_interactions=pred_interactions,
- validate_features=validate_features
+ validate_features=validate_features,
)
+ if pred_contribs and predt.size != local_part.num_row():
+     assert len(predt.shape) in (2, 3)
+     if len(predt.shape) == 2:
+         groups = 1
+         columns = predt.shape[1]
+     else:
+         groups = predt.shape[1]
+         columns = predt.shape[2]
+     # pylint: disable=no-member
+     ret = (
+         [dask.delayed(predt), groups, columns],
+         order,
+     )
+ elif pred_interactions and predt.size != local_part.num_row():
+     assert len(predt.shape) in (3, 4)
+     if len(predt.shape) == 3:
+         groups = 1
+         columns = predt.shape[1]
+     else:
+         groups = predt.shape[1]
+         columns = predt.shape[2]
+     # pylint: disable=no-member
+     ret = (
+         [dask.delayed(predt), groups, columns],
+         order,
+     )
+ else:
+     assert len(predt.shape) == 1 or len(predt.shape) == 2
+     columns = 1 if len(predt.shape) == 1 else predt.shape[1]
- ret = ((dask.delayed(predt), columns), order)  # pylint: disable=no-member
+     # pylint: disable=no-member
+     ret = (
+         [dask.delayed(predt), columns],
+         order,
+     )
predictions.append(ret)
return predictions
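
The branches above mirror the shape convention of single-node ``Booster.predict``: SHAP
contributions append one bias column, and multi-class models add a group dimension. A
single-node sketch with synthetic data (only the shapes matter here):

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 5)
y = np.random.randint(0, 3, size=100)
Xy = xgb.DMatrix(X, label=y)
booster = xgb.train({"objective": "multi:softprob", "num_class": 3}, Xy, num_boost_round=4)

contribs = booster.predict(Xy, pred_contribs=True)
print(contribs.shape)  # (100, 3, 6): (rows, groups, features + bias)
inter = booster.predict(Xy, pred_interactions=True)
print(inter.shape)     # (100, 3, 6, 6): pairwise interactions per group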
@ -1049,8 +1081,8 @@ async def _predict_async(
def dispatched_get_shape(
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
) -> List[Tuple[int, int]]:
- '''Get shape of data in each worker.'''
- LOGGER.debug('Get shape on %d', worker_id)
+ """Get shape of data in each worker."""
+ LOGGER.debug("Get shape on %d", worker_id)
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
shapes = []
for i, parts in enumerate(list_of_parts):
@ -1061,7 +1093,7 @@ async def _predict_async(
async def map_function(
func: Callable[[int, List[int], _DataParts], Any]
) -> List[Any]:
- '''Run function for each part of the data.'''
+ """Run function for each part of the data."""
futures = []
workers_address = list(worker_map.keys())
for wid, worker_addr in enumerate(workers_address):
@ -1069,10 +1101,14 @@ async def _predict_async(
list_of_parts = worker_map[worker_addr]
list_of_orders = [partition_order[part.key] for part in list_of_parts]
- f = client.submit(func, worker_id=wid,
+ f = client.submit(
+     func,
+     worker_id=wid,
      list_of_orders=list_of_orders,
      list_of_parts=list_of_parts,
- pure=True, workers=[worker_addr])
+     pure=True,
+     workers=[worker_addr],
+ )
assert isinstance(f, distributed.client.Future)
futures.append(f)
# Get delayed objects
@ -1091,10 +1127,24 @@ async def _predict_async(
# See https://docs.dask.org/en/latest/array-creation.html
arrays = []
for i, shape in enumerate(shapes):
- arrays.append(da.from_delayed(
-     results[i][0], shape=(shape[0],)
-     if results[i][1] == 1 else (shape[0], results[i][1]),
-     dtype=numpy.float32))
+ if pred_contribs:
+     out_shape = (
+         (shape[0], results[i][2])
+         if results[i][1] == 1
+         else (shape[0], results[i][1], results[i][2])
+     )
+ elif pred_interactions:
+     out_shape = (
+         (shape[0], results[i][2], results[i][2])
+         if results[i][1] == 1
+         else (shape[0], results[i][1], results[i][2])
+     )
+ else:
+     out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
+ arrays.append(
+     da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
+ )
predictions = await da.concatenate(arrays, axis=0)
return predictions
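
The shapes have to be threaded through explicitly because ``da.from_delayed`` cannot
infer a chunk shape from a delayed object: each partition's result is declared up front,
then concatenated. A self-contained sketch of the same pattern:

import dask
import dask.array as da
import numpy as np

parts = [np.full((10, 4), i, dtype=np.float32) for i in range(3)]
arrays = [
    da.from_delayed(dask.delayed(p), shape=(10, 4), dtype=np.float32)
    for p in parts
]
stacked = da.concatenate(arrays, axis=0)
print(stacked.shape)            # (30, 4)
print(stacked.compute()[::10])  # one row from each original part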
@ -1115,7 +1165,9 @@ def predict(
.. note::
- Only default prediction mode is supported right now.
+ Using ``inplace_predict`` might be faster when meta information like
+ ``base_margin`` is not needed.  For other parameters, please see
+ ``Booster.predict``.
.. versionadded:: 1.0.0
@ -1136,6 +1188,9 @@ def predict(
Returns
-------
prediction: dask.array.Array/dask.dataframe.Series
+ When the input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is
+ an array; when the input data is ``dask.dataframe.DataFrame``, the return value is a
+ ``dask.dataframe.Series``.
'''
_assert_dask_support()


@ -24,6 +24,7 @@ if sys.platform.startswith("win"):
if tm.no_dask()['condition']:
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
import distributed
from distributed import LocalCluster, Client
from distributed.utils_test import client, loop, cluster_fixture
import dask.dataframe as dd
@ -51,11 +52,12 @@ def generate_array(
with_weights: bool = False
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
Optional[xgb.dask._DaskCollection]]:
- partition_size = 20
- X = da.random.random((kRows, kCols), partition_size)
- y = da.random.random(kRows, partition_size)
+ chunk_size = 20
+ rng = da.random.RandomState(1994)
+ X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
+ y = rng.random_sample(kRows, chunks=chunk_size)
if with_weights:
- w = da.random.random(kRows, partition_size)
+ w = rng.random_sample(kRows, chunks=chunk_size)
return X, y, w
return X, y, None
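
A note on the chunk arguments introduced above: seeding ``da.random.RandomState`` makes
the test data reproducible, and ``-1`` in ``chunks`` keeps the whole axis in one block,
so every partition holds complete rows. For example:

import dask.array as da

rng = da.random.RandomState(1994)
X = rng.random_sample((100, 10), chunks=(20, -1))
print(X.chunks)  # ((20, 20, 20, 20, 20), (10,)): five row blocks, columns unsplit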
@ -175,9 +177,7 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
assert np.all(predictions_1.compute() == predictions_2.compute())
- def test_dask_missing_value_reg() -> None:
-     with LocalCluster(n_workers=kWorkers) as cluster:
-         with Client(cluster) as client:
+ def test_dask_missing_value_reg(client: "Client") -> None:
X_0 = np.ones((20 // 2, kCols))
X_1 = np.zeros((20 // 2, kCols))
X = np.concatenate([X_0, X_1], axis=0)
@ -199,9 +199,7 @@ def test_dask_missing_value_reg() -> None:
np.testing.assert_allclose(np_predt, dd_predt)
- def test_dask_missing_value_cls() -> None:
-     with LocalCluster() as cluster:
-         with Client(cluster) as client:
+ def test_dask_missing_value_cls(client: "Client") -> None:
X_0 = np.ones((kRows // 2, kCols))
X_1 = np.zeros((kRows // 2, kCols))
X = np.concatenate([X_0, X_1], axis=0)
@ -998,8 +996,7 @@ class TestWithDask:
assert cnt - n_workers == n_partitions
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
- X, y = da.from_array(X), da.from_array(y)
+ X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
Xy = xgb.dask.DaskDMatrix(client, X, y)
booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
@ -1009,8 +1006,12 @@ class TestWithDask:
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
+ shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
+ margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+ assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
- X, y = da.from_array(X), da.from_array(y)
+ X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
cls = xgb.dask.DaskXGBClassifier()
cls.client = client
cls.fit(X, y)
@ -1022,6 +1023,10 @@ class TestWithDask:
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
+ shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
+ margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+ assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
def test_shap(self, client: "Client") -> None:
from sklearn.datasets import load_boston, load_digits
X, y = load_boston(return_X_y=True)
@ -1031,6 +1036,7 @@ class TestWithDask:
X, y = load_digits(return_X_y=True)
params = {'objective': 'multi:softmax', 'num_class': 10}
self.run_shap(X, y, params, client)
+ params = {'objective': 'multi:softprob', 'num_class': 10}
+ self.run_shap(X, y, params, client)
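
The property these tests assert is SHAP additivity: per-row contributions, including the
bias column, sum to the raw margin. A single-node sketch of the same check on synthetic
data:

import numpy as np
import xgboost as xgb

X = np.random.rand(64, 8)
y = np.random.rand(64)
Xy = xgb.DMatrix(X, label=y)
booster = xgb.train({}, Xy, num_boost_round=4)

shap = booster.predict(Xy, pred_contribs=True)    # shape (64, 9): 8 features + bias
margin = booster.predict(Xy, output_margin=True)  # shape (64,)
np.testing.assert_allclose(shap.sum(axis=-1), margin, rtol=1e-5, atol=1e-5)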
@ -1043,7 +1049,7 @@ class TestWithDask:
params: Dict[str, Any],
client: "Client"
) -> None:
- X, y = da.from_array(X), da.from_array(y)
+ X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
Xy = xgb.dask.DaskDMatrix(client, X, y)
booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']