Specify shape in prediction contrib and interaction. (#6614)
parent 8942c98054
commit 4bf23c2391

@@ -95,14 +95,21 @@ For prediction, pass the ``output`` returned by ``train`` into ``xgb.dask.predict``:

 .. code-block:: python

     prediction = xgb.dask.predict(client, output, dtrain)

-Or equivalently, pass ``output['booster']``:
-
-.. code-block:: python
-
-    prediction = xgb.dask.predict(client, output['booster'], dtrain)
+    # Or equivalently, pass ``output['booster']``:
+    prediction = xgb.dask.predict(client, output['booster'], dtrain)
+
+Eliminating the construction of DaskDMatrix is also possible; this can make the
+computation a bit faster when meta information like ``base_margin`` is not needed:
+
+.. code-block:: python
+
+    prediction = xgb.dask.predict(client, output, X)
+    # Use inplace version.
+    prediction = xgb.dask.inplace_predict(client, output, X)

-Here ``prediction`` is a dask ``Array`` object containing predictions from model.
+Here ``prediction`` is a dask ``Array`` object containing predictions from the model if
+the input is a ``DaskDMatrix`` or ``da.Array``.  For a ``dd.DataFrame`` input, the
+return value is a ``dd.Series``.

 Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
 and ``DaskXGBRegressor``.  See ``xgboost/demo/dask`` for more examples.
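
As context for the tutorial text above, a minimal end-to-end sketch of the workflow it documents; the cluster setup, data sizes, and parameters here are illustrative, not from the commit:

.. code-block:: python

    import xgboost as xgb
    import dask.array as da
    from distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((1000, 10), chunks=(100, -1))
        y = da.random.random(1000, chunks=100)
        dtrain = xgb.dask.DaskDMatrix(client, X, y)
        output = xgb.dask.train(client, {"tree_method": "hist"}, dtrain,
                                num_boost_round=4)
        # The result is a lazy dask collection; materialize with .compute().
        prediction = xgb.dask.inplace_predict(client, output, X)
        print(prediction.compute()[:3])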

@@ -190,7 +190,7 @@ def _check_call(ret):
         raise XGBoostError(py_str(_LIB.XGBGetLastError()))


-def ctypes2numpy(cptr, length, dtype):
+def ctypes2numpy(cptr, length, dtype) -> np.ndarray:
     """Convert a ctypes pointer array to a numpy array."""
     NUMPY_TO_CTYPES_MAPPING = {
         np.float32: ctypes.c_float,
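
For reference, a self-contained sketch of the kind of conversion ``ctypes2numpy`` performs: copying a C float buffer of known length into a fresh numpy array. The ``memmove`` copy mirrors how the helper works, but treat the details as illustrative:

.. code-block:: python

    import ctypes
    import numpy as np

    buf = (ctypes.c_float * 4)(0.5, 1.5, 2.5, 3.5)           # pretend C output
    cptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_float))  # pointer handed back by the C API
    res = np.zeros(4, dtype=np.float32)
    # Copy length * itemsize bytes from the C buffer into the numpy array.
    ctypes.memmove(res.ctypes.data, cptr, res.size * res.dtype.itemsize)
    print(res)  # [0.5 1.5 2.5 3.5]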

@@ -1553,7 +1553,7 @@ class Booster(object):
                                               ctypes.byref(preds)))
         preds = ctypes2numpy(preds, length.value, np.float32)
         if pred_leaf:
-            preds = preds.astype(np.int32)
+            preds = preds.astype(np.int32, copy=False)
         nrow = data.num_row()
         if preds.size != nrow and preds.size % nrow == 0:
             chunk_size = int(preds.size / nrow)
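
Two details worth noting in this hunk: ``astype(np.int32, copy=False)`` only skips the copy when the array already has the target dtype, and the flat prediction buffer is folded into rows when its size is a multiple of the row count. A small illustration with toy numbers, not the XGBoost API:

.. code-block:: python

    import numpy as np

    preds = np.arange(12, dtype=np.float32)    # flat buffer from the C API
    leaf = preds.astype(np.int32, copy=False)  # still copies here: dtype differs
    print(leaf.dtype)  # int32
    nrow = 4
    if preds.size != nrow and preds.size % nrow == 0:
        chunk_size = int(preds.size / nrow)    # e.g. one column per class
        preds = preds.reshape((nrow, chunk_size))
    print(preds.shape)  # (4, 3)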

@@ -964,22 +964,21 @@ async def _predict_async(
     pred_contribs: bool,
     approx_contribs: bool,
     pred_interactions: bool,
-    validate_features: bool
+    validate_features: bool,
 ) -> _DaskCollection:
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
-        booster = model['booster']
+        booster = model["booster"]
     else:
         raise TypeError(_expect([Booster, dict], type(model)))
     if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
-        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame],
-                                type(data)))
+        raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))

     def mapped_predict(partition: Any, is_df: bool) -> Any:
         worker = distributed.get_worker()
         with config.config_context(**global_config):
-            booster.set_param({'nthread': worker.nthreads})
+            booster.set_param({"nthread": worker.nthreads})
             m = DMatrix(data=partition, missing=missing, nthread=worker.nthreads)
             predt = booster.predict(
                 data=m,

@@ -988,15 +987,16 @@ async def _predict_async(
                 pred_contribs=pred_contribs,
                 approx_contribs=approx_contribs,
                 pred_interactions=pred_interactions,
-                validate_features=validate_features
+                validate_features=validate_features,
             )
             if is_df:
-                if lazy_isinstance(partition, 'cudf', 'core.dataframe.DataFrame'):
+                if lazy_isinstance(partition, "cudf", "core.dataframe.DataFrame"):
                     import cudf
-                    predt = cudf.DataFrame(predt, columns=['prediction'])
+                    predt = cudf.DataFrame(predt, columns=["prediction"])
                 else:
-                    predt = DataFrame(predt, columns=['prediction'])
+                    predt = DataFrame(predt, columns=["prediction"])
             return predt

     # Predict on dask collection directly.
     if isinstance(data, (da.Array, dd.DataFrame)):
         return await _direct_predict_impl(client, data, mapped_predict)
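
``mapped_predict`` above is applied once per partition by ``_direct_predict_impl``; conceptually it is a block-wise map over the dask collection. A toy equivalent with plain dask, where ``fake_predict`` is a stand-in for building a ``DMatrix`` and calling ``booster.predict``:

.. code-block:: python

    import numpy as np
    import dask.array as da

    def fake_predict(block: np.ndarray) -> np.ndarray:
        # Stand-in for per-partition prediction: one value per row.
        return block.sum(axis=1)

    X = da.random.random((100, 5), chunks=(25, 5))
    # drop_axis=1: each 2-D block maps to a 1-D block of predictions.
    pred = X.map_blocks(fake_predict, drop_axis=1, dtype=np.float64)
    print(pred.compute().shape)  # (100,)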

@@ -1011,15 +1011,15 @@ async def _predict_async(

     def dispatched_predict(
         worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
-    ) -> List[Tuple[Tuple["dask.delayed.Delayed", int], int]]:
-        '''Perform prediction on each worker.'''
-        LOGGER.debug('Predicting on %d', worker_id)
+    ) -> List[Tuple[List[Union["dask.delayed.Delayed", int]], int]]:
+        """Perform prediction on each worker."""
+        LOGGER.debug("Predicting on %d", worker_id)
         with config.config_context(**global_config):
             worker = distributed.get_worker()
             list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
             predictions = []

-            booster.set_param({'nthread': worker.nthreads})
+            booster.set_param({"nthread": worker.nthreads})
             for i, parts in enumerate(list_of_parts):
                 (data, _, _, base_margin, _, _, _) = parts
                 order = list_of_orders[i]

@@ -1029,7 +1029,7 @@ async def _predict_async(
                     feature_names=feature_names,
                     feature_types=feature_types,
                     missing=missing,
-                    nthread=worker.nthreads
+                    nthread=worker.nthreads,
                 )
                 predt = booster.predict(
                     data=local_part,

@@ -1038,10 +1038,42 @@ async def _predict_async(
                     pred_contribs=pred_contribs,
                     approx_contribs=approx_contribs,
                     pred_interactions=pred_interactions,
-                    validate_features=validate_features
+                    validate_features=validate_features,
                 )
-                columns = 1 if len(predt.shape) == 1 else predt.shape[1]
-                ret = ((dask.delayed(predt), columns), order)  # pylint: disable=no-member
+                if pred_contribs and predt.size != local_part.num_row():
+                    assert len(predt.shape) in (2, 3)
+                    if len(predt.shape) == 2:
+                        groups = 1
+                        columns = predt.shape[1]
+                    else:
+                        groups = predt.shape[1]
+                        columns = predt.shape[2]
+                    # pylint: disable=no-member
+                    ret = (
+                        [dask.delayed(predt), groups, columns],
+                        order,
+                    )
+                elif pred_interactions and predt.size != local_part.num_row():
+                    assert len(predt.shape) in (3, 4)
+                    if len(predt.shape) == 3:
+                        groups = 1
+                        columns = predt.shape[1]
+                    else:
+                        groups = predt.shape[1]
+                        columns = predt.shape[2]
+                    # pylint: disable=no-member
+                    ret = (
+                        [dask.delayed(predt), groups, columns],
+                        order,
+                    )
+                else:
+                    assert len(predt.shape) == 1 or len(predt.shape) == 2
+                    columns = 1 if len(predt.shape) == 1 else predt.shape[1]
+                    # pylint: disable=no-member
+                    ret = (
+                        [dask.delayed(predt), columns],
+                        order,
+                    )
                 predictions.append(ret)

        return predictions
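
The new branches encode how SHAP outputs are shaped: contributions come back as ``(rows, columns)`` or ``(rows, groups, columns)``, and interactions gain one more ``columns`` axis. A toy version of the ``(groups, columns)`` extraction for the contributions case; ``contrib_meta`` is illustrative only, not part of the library:

.. code-block:: python

    import numpy as np

    def contrib_meta(predt: np.ndarray):
        # Mirrors the pred_contribs branch: 2-D means a single output group.
        assert predt.ndim in (2, 3)
        if predt.ndim == 2:
            return 1, predt.shape[1]
        return predt.shape[1], predt.shape[2]

    print(contrib_meta(np.zeros((8, 5))))     # (1, 5)  single group
    print(contrib_meta(np.zeros((8, 3, 5))))  # (3, 5)  multi-class, 3 groups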

@@ -1049,8 +1081,8 @@ async def _predict_async(
     def dispatched_get_shape(
         worker_id: int, list_of_orders: List[int], list_of_parts: _DataParts
     ) -> List[Tuple[int, int]]:
-        '''Get shape of data in each worker.'''
-        LOGGER.debug('Get shape on %d', worker_id)
+        """Get shape of data in each worker."""
+        LOGGER.debug("Get shape on %d", worker_id)
         list_of_parts = _get_worker_parts_ordered(meta_names, list_of_parts)
         shapes = []
         for i, parts in enumerate(list_of_parts):

@@ -1061,7 +1093,7 @@ async def _predict_async(
     async def map_function(
         func: Callable[[int, List[int], _DataParts], Any]
     ) -> List[Any]:
-        '''Run function for each part of the data.'''
+        """Run function for each part of the data."""
         futures = []
         workers_address = list(worker_map.keys())
         for wid, worker_addr in enumerate(workers_address):

@@ -1069,10 +1101,14 @@ async def _predict_async(
             list_of_parts = worker_map[worker_addr]
             list_of_orders = [partition_order[part.key] for part in list_of_parts]

-            f = client.submit(func, worker_id=wid,
-                              list_of_orders=list_of_orders,
-                              list_of_parts=list_of_parts,
-                              pure=True, workers=[worker_addr])
+            f = client.submit(
+                func,
+                worker_id=wid,
+                list_of_orders=list_of_orders,
+                list_of_parts=list_of_parts,
+                pure=True,
+                workers=[worker_addr],
+            )
             assert isinstance(f, distributed.client.Future)
             futures.append(f)
         # Get delayed objects
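
The reformatted ``client.submit`` call pins each task to the worker that already holds its data parts. A minimal sketch of that pinning pattern; the doubling task and cluster sizes are arbitrary:

.. code-block:: python

    from distributed import Client, LocalCluster

    def double(x: int) -> int:
        return x * 2

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        worker_addr = list(client.scheduler_info()["workers"])[0]
        # workers=[...] restricts scheduling to that worker;
        # pure=True lets the scheduler cache/deduplicate the result.
        fut = client.submit(double, 21, pure=True, workers=[worker_addr])
        print(fut.result())  # 42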

@@ -1091,10 +1127,24 @@ async def _predict_async(
         # See https://docs.dask.org/en/latest/array-creation.html
         arrays = []
         for i, shape in enumerate(shapes):
-            arrays.append(da.from_delayed(
-                results[i][0], shape=(shape[0],)
-                if results[i][1] == 1 else (shape[0], results[i][1]),
-                dtype=numpy.float32))
+            if pred_contribs:
+                out_shape = (
+                    (shape[0], results[i][2])
+                    if results[i][1] == 1
+                    else (shape[0], results[i][1], results[i][2])
+                )
+            elif pred_interactions:
+                out_shape = (
+                    (shape[0], results[i][2], results[i][2])
+                    if results[i][1] == 1
+                    else (shape[0], results[i][1], results[i][2])
+                )
+            else:
+                out_shape = (shape[0],) if results[i][1] == 1 else (shape[0], results[i][1])
+            arrays.append(
+                da.from_delayed(results[i][0], shape=out_shape, dtype=numpy.float32)
+            )

         predictions = await da.concatenate(arrays, axis=0)
         return predictions
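
This is the heart of the change: each worker now reports enough shape information that the delayed per-partition results can be wrapped with their true shapes before concatenation. A self-contained sketch of the same assembly pattern with made-up partition data:

.. code-block:: python

    import dask
    import dask.array as da
    import numpy as np

    # Pretend per-partition prediction results with known shapes.
    parts = [np.ones((2, 3), dtype=np.float32), np.zeros((4, 3), dtype=np.float32)]
    arrays = [
        da.from_delayed(dask.delayed(p), shape=p.shape, dtype=np.float32)
        for p in parts
    ]
    predictions = da.concatenate(arrays, axis=0)
    print(predictions.shape)            # (6, 3) -- known without computing
    print(predictions.compute().sum())  # 6.0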

@@ -1115,7 +1165,9 @@ def predict(

     .. note::

-        Only default prediction mode is supported right now.
+        Using ``inplace_predict`` might be faster when meta information like
+        ``base_margin`` is not needed.  For other parameters, please see
+        ``Booster.predict``.

     .. versionadded:: 1.0.0


@@ -1136,6 +1188,9 @@ def predict(
     Returns
     -------
     prediction: dask.array.Array/dask.dataframe.Series
+        When the input data is ``dask.array.Array`` or ``DaskDMatrix``, the return
+        value is an array; when the input data is ``dask.dataframe.DataFrame``, the
+        return value is ``dask.dataframe.Series``.

     '''
     _assert_dask_support()

@@ -24,6 +24,7 @@ if sys.platform.startswith("win"):
 if tm.no_dask()['condition']:
     pytest.skip(msg=tm.no_dask()['reason'], allow_module_level=True)

+import distributed
 from distributed import LocalCluster, Client
 from distributed.utils_test import client, loop, cluster_fixture
 import dask.dataframe as dd

@@ -51,11 +52,12 @@ def generate_array(
     with_weights: bool = False
 ) -> Tuple[xgb.dask._DaskCollection, xgb.dask._DaskCollection,
            Optional[xgb.dask._DaskCollection]]:
-    partition_size = 20
-    X = da.random.random((kRows, kCols), partition_size)
-    y = da.random.random(kRows, partition_size)
+    chunk_size = 20
+    rng = da.random.RandomState(1994)
+    X = rng.random_sample((kRows, kCols), chunks=(chunk_size, -1))
+    y = rng.random_sample(kRows, chunks=chunk_size)
     if with_weights:
-        w = da.random.random(kRows, partition_size)
+        w = rng.random_sample(kRows, chunks=chunk_size)
         return X, y, w
     return X, y, None

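
The test helper now draws from a seeded ``da.random.RandomState`` with explicit chunking; ``chunks=(chunk_size, -1)`` splits rows into blocks while keeping all columns together. A sketch with stand-in sizes:

.. code-block:: python

    import dask.array as da

    rng = da.random.RandomState(1994)                   # reproducible across runs
    X = rng.random_sample((100, 10), chunks=(20, -1))   # -1: don't split columns
    y = rng.random_sample(100, chunks=20)
    print(X.chunks)  # ((20, 20, 20, 20, 20), (10,))
    print(y.chunks)  # ((20, 20, 20, 20, 20),)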

@@ -175,9 +177,7 @@ def test_boost_from_prediction(tree_method: str, client: "Client") -> None:
     assert np.all(predictions_1.compute() == predictions_2.compute())


-def test_dask_missing_value_reg() -> None:
-    with LocalCluster(n_workers=kWorkers) as cluster:
-        with Client(cluster) as client:
+def test_dask_missing_value_reg(client: "Client") -> None:
     X_0 = np.ones((20 // 2, kCols))
     X_1 = np.zeros((20 // 2, kCols))
     X = np.concatenate([X_0, X_1], axis=0)
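
These tests now take the ``client`` fixture from ``distributed.utils_test`` instead of building a ``LocalCluster`` by hand. A minimal pytest sketch of that pattern; the test body is illustrative:

.. code-block:: python

    # Importing the fixtures makes pytest inject a ready-to-use client.
    from distributed.utils_test import client, loop, cluster_fixture  # noqa: F401

    def test_submit_roundtrip(client):
        fut = client.submit(sum, [1, 2, 3])
        assert fut.result() == 6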

@@ -199,9 +199,7 @@ def test_dask_missing_value_reg() -> None:
     np.testing.assert_allclose(np_predt, dd_predt)


-def test_dask_missing_value_cls() -> None:
-    with LocalCluster() as cluster:
-        with Client(cluster) as client:
+def test_dask_missing_value_cls(client: "Client") -> None:
     X_0 = np.ones((kRows // 2, kCols))
     X_1 = np.zeros((kRows // 2, kCols))
     X = np.concatenate([X_0, X_1], axis=0)

@@ -998,8 +996,7 @@ class TestWithDask:
         assert cnt - n_workers == n_partitions

     def run_shap(self, X: Any, y: Any, params: Dict[str, Any], client: "Client") -> None:
-        X, y = da.from_array(X), da.from_array(y)
+        X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)

         Xy = xgb.dask.DaskDMatrix(client, X, y)
         booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']


@@ -1009,8 +1006,12 @@ class TestWithDask:
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

+        shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
+        margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
+
     def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
-        X, y = da.from_array(X), da.from_array(y)
+        X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
         cls = xgb.dask.DaskXGBClassifier()
         cls.client = client
         cls.fit(X, y)
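
The added assertions re-check SHAP consistency when predicting straight from the dask array rather than a ``DaskDMatrix``. The invariant under test is that per-row contributions, including the trailing bias column, sum to the raw margin. A single-machine sketch of the same property with synthetic data:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    rng = np.random.RandomState(0)
    X, y = rng.rand(64, 4), rng.rand(64)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=4)

    shap = booster.predict(dtrain, pred_contribs=True)    # (64, 5): 4 features + bias
    margin = booster.predict(dtrain, output_margin=True)  # (64,)
    assert np.allclose(np.sum(shap, axis=-1), margin, 1e-5, 1e-5)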

@@ -1022,6 +1023,10 @@ class TestWithDask:
         margin = xgb.dask.predict(client, booster, test_Xy, output_margin=True).compute()
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)

+        shap = xgb.dask.predict(client, booster, X, pred_contribs=True).compute()
+        margin = xgb.dask.predict(client, booster, X, output_margin=True).compute()
+        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
+
     def test_shap(self, client: "Client") -> None:
         from sklearn.datasets import load_boston, load_digits
         X, y = load_boston(return_X_y=True)

@@ -1031,6 +1036,7 @@ class TestWithDask:
         X, y = load_digits(return_X_y=True)
         params = {'objective': 'multi:softmax', 'num_class': 10}
         self.run_shap(X, y, params, client)
+
         params = {'objective': 'multi:softprob', 'num_class': 10}
         self.run_shap(X, y, params, client)


@@ -1043,7 +1049,7 @@ class TestWithDask:
         params: Dict[str, Any],
         client: "Client"
     ) -> None:
-        X, y = da.from_array(X), da.from_array(y)
+        X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)

         Xy = xgb.dask.DaskDMatrix(client, X, y)
         booster = xgb.dask.train(client, params, Xy, num_boost_round=10)['booster']