Specify shape in prediction contrib and interaction. (#6614)
This commit is contained in:
parent
8942c98054
commit
4bf23c2391
@ -95,14 +95,21 @@ For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predic
|
||||
.. code-block:: python
|
||||
|
||||
prediction = xgb.dask.predict(client, output, dtrain)
|
||||
# Or equivalently, pass ``output['booster']``:
|
||||
prediction = xgb.dask.predict(client, output['booster'], dtrain)
|
||||
|
||||
Or equivalently, pass ``output['booster']``:
|
||||
Eliminating the construction of ``DaskDMatrix`` is also possible; this can make the
|
||||
computation a bit faster when meta information like ``base_margin`` is not needed:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
prediction = xgb.dask.predict(client, output['booster'], dtrain)
|
||||
prediction = xgb.dask.predict(client, output, X)
|
||||
# Use inplace version.
|
||||
prediction = xgb.dask.inplace_predict(client, output, X)
|
||||
|
||||
Here ``prediction`` is a dask ``Array`` object containing predictions from model.
|
||||
Here ``prediction`` is a dask ``Array`` object containing predictions from the model if the input
|
||||
is a ``DaskDMatrix`` or ``da.Array``. For ``dd.DataFrame``, the return value is a
|
||||
``dd.Series``.
|
||||
|
||||
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
|
||||
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
|
||||
|
||||
@ -190,7 +190,7 @@ def _check_call(ret):
|
||||
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
|
||||
|
||||
|
||||
def ctypes2numpy(cptr, length, dtype):
|
||||
def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
|
||||
"""Convert a ctypes pointer array to a numpy array."""
|
||||
NUMPY_TO_CTYPES_MAPPING = {
|
||||
np.float32: ctypes.c_float,
|
||||
@ -1553,7 +1553,7 @@ class Booster(object):
|
||||
ctypes.byref(preds)))
|
||||
preds = ctypes2numpy(preds, length.value, np.float32)
|
||||
if pred_leaf:
|
||||
preds = preds.astype(np.int32)
|
||||
preds = preds.astype(np.int32, copy=False)
|
||||
nrow = data.num_row()
|
||||
if preds.size != nrow and preds.size % nrow == 0:
|
||||
chunk_size = int(preds.size / nrow)
|
||||
|
||||
@ -964,22 +964,21 @@ async def _predict_async(
|
||||
pred_contribs: bool,
|
||||
approx_contribs: bool,
|
||||
pred_interactions: bool,
|
||||
validate_features: bool
|
||||
validate_features: bool,
|
||||
) -> _DaskCollection:
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
booster = model['booster']
|
||||
booster = model["booster"]
|
||||
else:
|
||||
raise TypeError(_expect([Booster, dict], type(model)))
|
||||
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
|
||||
type(data)))
|
||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
def mapped_predict(partition: Any, is_df: bool) -> Any:
|
||||
worker = distributed.get_worker()
|
||||
with config.config_context(**global_config):
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
booster.set_param({"nthread": worker.nthreads})
|
||||
m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(
|
||||
data=m,
|
||||
@ -988,15 +987,16 @@ async def _predict_async(
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features
|
||||
validate_features=validate_features,
|
||||
)
|
||||
if is_df:
|
||||
if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
|
||||
if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
|
||||
import cudf
|
||||
predt = cudf.DataFrame(predt, columns=['prediction'])
|
||||
predt = cudf.DataFrame(predt, columns=["prediction"])
|
||||
else:
|
||||
predt = DataFrame(predt, columns=['prediction'])
|
||||
predt = DataFrame(predt, columns=["prediction"])
|
||||
return predt
|
||||
|
||||
# Predict on dask collection directly.
|
||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||
return await _direct_predict_impl(client, data, mapped_predict)
|
||||
@ -1011,15 +1011,15 @@ async def _predict_async(
|
||||
|
||||
def dispatched_predict(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]:
|
||||
'''Perform prediction on each worker.'''
|
||||
LOGGER.debug('Predicting on %d', worker_id)
|
||||
) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
|
||||
"""Perform prediction on each worker."""
|
||||
LOGGER.debug("Predicting on %d", worker_id)
|
||||
with config.config_context(**global_config):
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
predictions = []
|
||||
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
booster.set_param({"nthread": worker.nthreads})
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
(data, _, _, base_margin, _, _, _) = parts
|
||||
order = list_of_orders[i]
|
||||
@ -1029,7 +1029,7 @@ async def _predict_async(
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
missing=missing,
|
||||
nthread=worker.nthreads
|
||||
nthread=worker.nthreads,
|
||||
)
|
||||
predt = booster.predict(
|
||||
data=local_part,
|
||||
@ -1038,10 +1038,42 @@ async def _predict_async(
|
||||
pred_contribs=pred_contribs,
|
||||
approx_contribs=approx_contribs,
|
||||
pred_interactions=pred_interactions,
|
||||
validate_features=validate_features
|
||||
validate_features=validate_features,
|
||||
)
|
||||
if pred_contribs and predt.size != local_part.num_row():
|
||||
assert len(predt.shape) in (2, 3)
|
||||
if len(predt.shape) == 2:
|
||||
groups = 1
|
||||
columns = predt.shape[1]
|
||||
else:
|
||||
groups = predt.shape[1]
|
||||
columns = predt.shape[2]
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), groups, columns],
|
||||
order,
|
||||
)
|
||||
elif pred_interactions and predt.size != local_part.num_row():
|
||||
assert len(predt.shape) in (3, 4)
|
||||
if len(predt.shape) == 3:
|
||||
groups = 1
|
||||
columns = predt.shape[1]
|
||||
else:
|
||||
groups = predt.shape[1]
|
||||
columns = predt.shape[2]
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), groups, columns],
|
||||
order,
|
||||
)
|
||||
else:
|
||||
assert len(predt.shape) == 1 or len(predt.shape) == 2
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((dask.delayed(predt), columns), order) # pylint: disable=no-member
|
||||
# pylint: disable=no-member
|
||||
ret = (
|
||||
[dask.delayed(predt), columns],
|
||||
order,
|
||||
)
|
||||
predictions.append(ret)
|
||||
|
||||
return predictions
|
||||
@ -1049,8 +1081,8 @@ async def _predict_async(
|
||||
def dispatched_get_shape(
|
||||
worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
|
||||
) -> List[Tuple[int, int]]:
|
||||
'''Get shape of data in each worker.'''
|
||||
LOGGER.debug('Get shape on %d', worker_id)
|
||||
"""Get shape of data in each worker."""
|
||||
LOGGER.debug("Get shape on %d", worker_id)
|
||||
list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
|
||||
shapes = []
|
||||
for i, parts in enumerate(list_of_parts):
|
||||
@ -1061,7 +1093,7 @@ async def _predict_async(
|
||||
async def map_function(
|
||||
func: Callable[[int, List[int], _DataParts], Any]
|
||||
) -> List[Any]:
|
||||
'''Run function for each part of the data.'''
|
||||
"""Run function for each part of the data."""
|
||||
futures = []
|
||||
workers_address = list(worker_map.keys())
|
||||
for wid, worker_addr in enumerate(workers_address):
|
||||
@ -1069,10 +1101,14 @@ async def _predict_async(
|
||||
list_of_parts = worker_map[worker_addr]
|
||||
list_of_orders = [partition_order[part.key] for part in list_of_parts]
|
||||
|
||||
f = client.submit(func, worker_id=wid,
|
||||
f = client.submit(
|
||||
func,
|
||||
worker_id=wid,
|
||||
list_of_orders=list_of_orders,
|
||||
list_of_parts=list_of_parts,
|
||||
pure=True, workers=[worker_addr])
|
||||
pure=True,
|
||||
workers=[worker_addr],
|
||||
)
|
||||
assert isinstance(f, distributed.client.Future)
|
||||
futures.append(f)
|
||||
# Get delayed objects
|
||||
@ -1091,10 +1127,24 @@ async def _predict_async(
|
||||
# See https://docs.dask.org/en/latest/array-creation.html
|
||||
arrays = []
|
||||
for i, shape in enumerate(shapes):
|
||||
arrays.append(da.from_delayed(
|
||||
results[i][0], shape=(shape[0],)
|
||||
if results[i][1] == 1 else (shape[0], results[i][1]),
|
||||
dtype=numpy.float32))
|
||||
if pred_contribs:
|
||||
out_shape = (
|
||||
(shape[0], results[i][2])
|
||||
if results[i][1] == 1
|
||||
else (shape[0], results[i][1], results[i][2])
|
||||
)
|
||||
elif pred_interactions:
|
||||
out_shape = (
|
||||
(shape[0], results[i][2], results[i][2])
|
||||
if results[i][1] == 1
|
||||
else (shape[0], results[i][1], results[i][2])
|
||||
)
|
||||
else:
|
||||
out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
|
||||
arrays.append(
|
||||
da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
|
||||
)
|
||||
|
||||
predictions = await da.concatenate(arrays, axis=0)
|
||||
return predictions
|
||||
|
||||
@ -1115,7 +1165,9 @@ def predict(
|
||||
|
||||
.. note::
|
||||
|
||||
Only default prediction mode is supported right now.
|
||||
Using ``inplace_predict`` might be faster when meta information like
|
||||
``base_margin`` is not needed. For other parameters, please see
|
||||
``Booster.predict``.
|
||||
|
||||
.. versionadded:: 1.0.0
|
||||
|
||||
@ -1136,6 +1188,9 @@ def predict(
|
||||
Returns
|
||||
-------
|
||||
prediction: dask.array.Array/dask.dataframe.Series
|
||||
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
||||
array; when input data is ``dask.dataframe.DataFrame``, the return value is
|
||||
``dask.dataframe.Series``
|
||||
|
||||
'''
|
||||
_assert_dask_support()
|
||||
|
||||
@ -24,6 +24,7 @@ if sys.platform.startswith("win"):
|
||||
if tm.no_dask()['condition']:
|
||||
pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)
|
||||
|
||||
import distributed
|
||||
from distributed import LocalCluster, Client
|
||||
from distributed.utils_test import client, loop, cluster_fixture
|
||||
import dask.dataframe as dd
|
||||
@ -51,11 +52,12 @@ def generate_array(
|
||||
with_weights: bool = False
|
||||
) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
|
||||
Optional[xgb.dask._DaskCollection]]:
|
||||
partition_size = 20
|
||||
X = da.random.random((kRows, kCols), partition_size)
|
||||
y = da.random.random(kRows, partition_size)
|
||||
chunk_size = 20
|
||||
rng = da.random.RandomState(1994)
|
||||
X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
|
||||
y = rng.random_sample(kRows, chunks=chunk_size)
|
||||
if with_weights:
|
||||
w = da.random.random(kRows, partition_size)
|
||||
w = rng.random_sample(kRows, chunks=chunk_size)
|
||||
return X, y, w
|
||||
return X, y, None
|
||||
|
||||
@ -175,9 +177,7 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
|
||||
assert np.all(predictions_1.compute() == predictions_2.compute())
|
||||
|
||||
|
||||
def test_dask_missing_value_reg() -> None:
|
||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_dask_missing_value_reg(client: "Client") -> None:
|
||||
X_0 = np.ones((20 // 2, kCols))
|
||||
X_1 = np.zeros((20 // 2, kCols))
|
||||
X = np.concatenate([X_0, X_1], axis=0)
|
||||
@ -199,9 +199,7 @@ def test_dask_missing_value_reg() -> None:
|
||||
np.testing.assert_allclose(np_predt, dd_predt)
|
||||
|
||||
|
||||
def test_dask_missing_value_cls() -> None:
|
||||
with LocalCluster() as cluster:
|
||||
with Client(cluster) as client:
|
||||
def test_dask_missing_value_cls(client: "Client") -> None:
|
||||
X_0 = np.ones((kRows // 2, kCols))
|
||||
X_1 = np.zeros((kRows // 2, kCols))
|
||||
X = np.concatenate([X_0, X_1], axis=0)
|
||||
@ -998,8 +996,7 @@ class TestWithDask:
|
||||
assert cnt - n_workers == n_partitions
|
||||
|
||||
def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
|
||||
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
||||
Xy = xgb.dask.DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
|
||||
|
||||
@ -1009,8 +1006,12 @@ class TestWithDask:
|
||||
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
|
||||
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
||||
|
||||
shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
|
||||
margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
|
||||
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
||||
|
||||
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
||||
cls = xgb.dask.DaskXGBClassifier()
|
||||
cls.client = client
|
||||
cls.fit(X, y)
|
||||
@ -1022,6 +1023,10 @@ class TestWithDask:
|
||||
margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
|
||||
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
||||
|
||||
shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
|
||||
margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
|
||||
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
||||
|
||||
def test_shap(self, client: "Client") -> None:
|
||||
from sklearn.datasets import load_boston, load_digits
|
||||
X, y = load_boston(return_X_y=True)
|
||||
@ -1031,6 +1036,7 @@ class TestWithDask:
|
||||
X, y = load_digits(return_X_y=True)
|
||||
params = {'objective': 'multi:softmax', 'num_class': 10}
|
||||
self.run_shap(X, y, params, client)
|
||||
|
||||
params = {'objective': 'multi:softprob', 'num_class': 10}
|
||||
self.run_shap(X, y, params, client)
|
||||
|
||||
@ -1043,7 +1049,7 @@ class TestWithDask:
|
||||
params: Dict[str, Any],
|
||||
client: "Client"
|
||||
) -> None:
|
||||
X, y = da.from_array(X), da.from_array(y)
|
||||
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
||||
|
||||
Xy = xgb.dask.DaskDMatrix(client, X, y)
|
||||
booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user