[dask] Accept Future of model for prediction. (#6650)
This PR changes predict and inplace_predict to accept a Future of model, to avoid sending models to workers repeatably. * Document is updated to reflect functionality additions in recent changes.
This commit is contained in:
parent
a9ec0ea6da
commit
87ab1ad607
@ -112,8 +112,11 @@ is a ``DaskDMatrix`` or ``da.Array``. When putting dask collection directly int
|
|||||||
``predict`` function or using ``inplace_predict``, the output type depends on input data.
|
``predict`` function or using ``inplace_predict``, the output type depends on input data.
|
||||||
See next section for details.
|
See next section for details.
|
||||||
|
|
||||||
Alternatively, XGBoost also implements the Scikit-Learn interface with ``DaskXGBClassifier``
|
Alternatively, XGBoost also implements the Scikit-Learn interface with
|
||||||
and ``DaskXGBRegressor``. See ``xgboost/demo/dask`` for more examples.
|
``DaskXGBClassifier``, ``DaskXGBRegressor``, ``DaskXGBRanker`` and 2 random forest
|
||||||
|
variances. This wrapper is similar to the single node Scikit-Learn interface in xgboost,
|
||||||
|
with dask collection as inputs and has an additional ``client`` attribute. See
|
||||||
|
``xgboost/demo/dask`` for more examples.
|
||||||
|
|
||||||
|
|
||||||
******************
|
******************
|
||||||
@ -160,6 +163,32 @@ if not using GPU, the number of threads used for prediction on each block matter
|
|||||||
now, xgboost uses single thread for each partition. If the number of blocks on each
|
now, xgboost uses single thread for each partition. If the number of blocks on each
|
||||||
workers is smaller than number of cores, then the CPU workers might not be fully utilized.
|
workers is smaller than number of cores, then the CPU workers might not be fully utilized.
|
||||||
|
|
||||||
|
One simple optimization for running consecutive predictions is using
|
||||||
|
``distributed.Future``:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
dataset = [X_0, X_1, X_2]
|
||||||
|
booster_f = client.scatter(booster, broadcast=True)
|
||||||
|
futures = []
|
||||||
|
for X in dataset:
|
||||||
|
# Here we pass in a future instead of concrete booster
|
||||||
|
shap_f = xgb.dask.predict(client, booster_f, X, pred_contribs=True)
|
||||||
|
futures.append(shap_f)
|
||||||
|
|
||||||
|
results = client.gather(futures)
|
||||||
|
|
||||||
|
|
||||||
|
This is only available on functional interface, as the Scikit-Learn wrapper doesn't know
|
||||||
|
how to maintain a valid future for booster. To obtain the booster object from
|
||||||
|
Scikit-Learn wrapper object:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
cls = xgb.dask.DaskXGBClassifier()
|
||||||
|
cls.fit(X, y)
|
||||||
|
|
||||||
|
booster = cls.get_booster()
|
||||||
|
|
||||||
|
|
||||||
***************************
|
***************************
|
||||||
@ -231,17 +260,17 @@ will override the configuration in Dask. For example:
|
|||||||
with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
with dask.distributed.LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
|
||||||
|
|
||||||
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
|
There are 4 threads allocated for each dask worker. Then by default XGBoost will use 4
|
||||||
threads in each process for both training and prediction. But if ``nthread`` parameter is
|
threads in each process for training. But if ``nthread`` parameter is set:
|
||||||
set:
|
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
output = xgb.dask.train(client,
|
output = xgb.dask.train(
|
||||||
{'verbosity': 1,
|
client,
|
||||||
'nthread': 8,
|
{"verbosity": 1, "nthread": 8, "tree_method": "hist"},
|
||||||
'tree_method': 'hist'},
|
dtrain,
|
||||||
dtrain,
|
num_boost_round=4,
|
||||||
num_boost_round=4, evals=[(dtrain, 'train')])
|
evals=[(dtrain, "train")],
|
||||||
|
)
|
||||||
|
|
||||||
XGBoost will use 8 threads in each training process.
|
XGBoost will use 8 threads in each training process.
|
||||||
|
|
||||||
@ -274,12 +303,12 @@ Functional interface:
|
|||||||
with_X = await xgb.dask.predict(client, output, X)
|
with_X = await xgb.dask.predict(client, output, X)
|
||||||
inplace = await xgb.dask.inplace_predict(client, output, X)
|
inplace = await xgb.dask.inplace_predict(client, output, X)
|
||||||
|
|
||||||
# Use `client.compute` instead of the `compute` method from dask collection
|
# Use ``client.compute`` instead of the ``compute`` method from dask collection
|
||||||
print(await client.compute(with_m))
|
print(await client.compute(with_m))
|
||||||
|
|
||||||
|
|
||||||
While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
|
While for the Scikit-Learn interface, trivial methods like ``set_params`` and accessing class
|
||||||
attributes like ``evals_result_`` do not require ``await``. Other methods involving
|
attributes like ``evals_result()`` do not require ``await``. Other methods involving
|
||||||
actual computation will return a coroutine and hence require awaiting:
|
actual computation will return a coroutine and hence require awaiting:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -373,6 +402,46 @@ If early stopping is enabled by also passing ``early_stopping_rounds``, you can
|
|||||||
print(booster.best_iteration)
|
print(booster.best_iteration)
|
||||||
best_model = booster[: booster.best_iteration]
|
best_model = booster[: booster.best_iteration]
|
||||||
|
|
||||||
|
|
||||||
|
*******************
|
||||||
|
Other customization
|
||||||
|
*******************
|
||||||
|
|
||||||
|
XGBoost dask interface accepts other advanced features found in single node Python
|
||||||
|
interface, including callback functions, custom evaluation metric and objective:
|
||||||
|
|
||||||
|
def eval_error_metric(predt, dtrain: xgb.DMatrix):
|
||||||
|
label = dtrain.get_label()
|
||||||
|
r = np.zeros(predt.shape)
|
||||||
|
gt = predt > 0.5
|
||||||
|
r[gt] = 1 - label[gt]
|
||||||
|
le = predt <= 0.5
|
||||||
|
r[le] = label[le]
|
||||||
|
return 'CustomErr', np.sum(r)
|
||||||
|
|
||||||
|
# custom callback
|
||||||
|
early_stop = xgb.callback.EarlyStopping(
|
||||||
|
rounds=early_stopping_rounds,
|
||||||
|
metric_name="CustomErr",
|
||||||
|
data_name="Train",
|
||||||
|
save_best=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
booster = xgb.dask.train(
|
||||||
|
client,
|
||||||
|
params={
|
||||||
|
"objective": "binary:logistic",
|
||||||
|
"eval_metric": ["error", "rmse"],
|
||||||
|
"tree_method": "hist",
|
||||||
|
},
|
||||||
|
dtrain=D_train,
|
||||||
|
evals=[(D_train, "Train"), (D_valid, "Valid")],
|
||||||
|
feval=eval_error_metric, # custom evaluation metric
|
||||||
|
num_boost_round=100,
|
||||||
|
callbacks=[early_stop],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
*****************************************************************************
|
*****************************************************************************
|
||||||
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
Why is the initialization of ``DaskDMatrix`` so slow and throws weird errors
|
||||||
*****************************************************************************
|
*****************************************************************************
|
||||||
@ -414,15 +483,3 @@ References:
|
|||||||
|
|
||||||
#. https://github.com/dask/dask/issues/6833
|
#. https://github.com/dask/dask/issues/6833
|
||||||
#. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array
|
#. https://stackoverflow.com/questions/45941528/how-to-efficiently-send-a-large-numpy-array-to-the-cluster-with-dask-array
|
||||||
|
|
||||||
***********
|
|
||||||
Limitations
|
|
||||||
***********
|
|
||||||
|
|
||||||
Basic functionality including model training and generating classification and regression predictions
|
|
||||||
have been implemented. However, there are still some other limitations we haven't
|
|
||||||
addressed yet:
|
|
||||||
|
|
||||||
- Label encoding for the ``DaskXGBClassifier`` classifier may not be supported. So users need
|
|
||||||
to encode their training labels into discrete values first.
|
|
||||||
- Ranking is not yet supported.
|
|
||||||
|
|||||||
@ -940,16 +940,14 @@ def _can_output_df(data: _DaskCollection, output_shape: Tuple) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
async def _direct_predict_impl(
|
async def _direct_predict_impl(
|
||||||
client: "distributed.Client",
|
|
||||||
mapped_predict: Callable,
|
mapped_predict: Callable,
|
||||||
booster: Booster,
|
booster: "distributed.Future",
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
base_margin: Optional[_DaskCollection],
|
base_margin: Optional[_DaskCollection],
|
||||||
output_shape: Tuple[int, ...],
|
output_shape: Tuple[int, ...],
|
||||||
meta: Dict[int, str],
|
meta: Dict[int, str],
|
||||||
) -> _DaskCollection:
|
) -> _DaskCollection:
|
||||||
columns = list(meta.keys())
|
columns = list(meta.keys())
|
||||||
booster_f = await client.scatter(data=booster, broadcast=True)
|
|
||||||
if _can_output_df(data, output_shape):
|
if _can_output_df(data, output_shape):
|
||||||
if base_margin is not None and isinstance(base_margin, da.Array):
|
if base_margin is not None and isinstance(base_margin, da.Array):
|
||||||
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
|
base_margin_df: Optional[dd.DataFrame] = base_margin.to_dask_dataframe()
|
||||||
@ -957,7 +955,7 @@ async def _direct_predict_impl(
|
|||||||
base_margin_df = base_margin
|
base_margin_df = base_margin
|
||||||
predictions = dd.map_partitions(
|
predictions = dd.map_partitions(
|
||||||
mapped_predict,
|
mapped_predict,
|
||||||
booster_f,
|
booster,
|
||||||
data,
|
data,
|
||||||
True,
|
True,
|
||||||
columns,
|
columns,
|
||||||
@ -984,7 +982,7 @@ async def _direct_predict_impl(
|
|||||||
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
|
new_axis = [i + 2 for i in range(len(output_shape) - 2)]
|
||||||
predictions = da.map_blocks(
|
predictions = da.map_blocks(
|
||||||
mapped_predict,
|
mapped_predict,
|
||||||
booster_f,
|
booster,
|
||||||
data,
|
data,
|
||||||
False,
|
False,
|
||||||
columns,
|
columns,
|
||||||
@ -997,7 +995,10 @@ async def _direct_predict_impl(
|
|||||||
|
|
||||||
|
|
||||||
def _infer_predict_output(
|
def _infer_predict_output(
|
||||||
booster: Booster, data: _DaskCollection, inplace: bool, **kwargs: Any
|
booster: Booster,
|
||||||
|
data: Union[DaskDMatrix, _DaskCollection],
|
||||||
|
inplace: bool,
|
||||||
|
**kwargs: Any
|
||||||
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
|
) -> Tuple[Tuple[int, ...], Dict[int, str]]:
|
||||||
"""Create a dummy test sample to infer output shape for prediction."""
|
"""Create a dummy test sample to infer output shape for prediction."""
|
||||||
if isinstance(data, DaskDMatrix):
|
if isinstance(data, DaskDMatrix):
|
||||||
@ -1021,11 +1022,29 @@ def _infer_predict_output(
|
|||||||
return test_predt.shape, meta
|
return test_predt.shape, meta
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_model_future(
|
||||||
|
client: "distributed.Client", model: Union[Booster, Dict, "distributed.Future"]
|
||||||
|
) -> "distributed.Future":
|
||||||
|
if isinstance(model, Booster):
|
||||||
|
booster = await client.scatter(model, broadcast=True)
|
||||||
|
elif isinstance(model, dict):
|
||||||
|
booster = await client.scatter(model["booster"])
|
||||||
|
elif isinstance(model, distributed.Future):
|
||||||
|
booster = model
|
||||||
|
if booster.type is not Booster:
|
||||||
|
raise TypeError(
|
||||||
|
f"Underlying type of model future should be `Booster`, got {booster.type}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise TypeError(_expect([Booster, dict, distributed.Future], type(model)))
|
||||||
|
return booster
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-statements
|
# pylint: disable=too-many-statements
|
||||||
async def _predict_async(
|
async def _predict_async(
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
model: Union[Booster, Dict],
|
model: Union[Booster, Dict, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
output_margin: bool,
|
output_margin: bool,
|
||||||
missing: float,
|
missing: float,
|
||||||
@ -1035,12 +1054,7 @@ async def _predict_async(
|
|||||||
pred_interactions: bool,
|
pred_interactions: bool,
|
||||||
validate_features: bool,
|
validate_features: bool,
|
||||||
) -> _DaskCollection:
|
) -> _DaskCollection:
|
||||||
if isinstance(model, Booster):
|
_booster = await _get_model_future(client, model)
|
||||||
_booster = model
|
|
||||||
elif isinstance(model, dict):
|
|
||||||
_booster = model["booster"]
|
|
||||||
else:
|
|
||||||
raise TypeError(_expect([Booster, dict], type(model)))
|
|
||||||
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
if not isinstance(data, (DaskDMatrix, da.Array, dd.DataFrame)):
|
||||||
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
|
raise TypeError(_expect([DaskDMatrix, da.Array, dd.DataFrame], type(data)))
|
||||||
|
|
||||||
@ -1070,7 +1084,7 @@ async def _predict_async(
|
|||||||
# Predict on dask collection directly.
|
# Predict on dask collection directly.
|
||||||
if isinstance(data, (da.Array, dd.DataFrame)):
|
if isinstance(data, (da.Array, dd.DataFrame)):
|
||||||
_output_shape, meta = _infer_predict_output(
|
_output_shape, meta = _infer_predict_output(
|
||||||
_booster,
|
await _booster.result(),
|
||||||
data,
|
data,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
output_margin=output_margin,
|
output_margin=output_margin,
|
||||||
@ -1081,10 +1095,11 @@ async def _predict_async(
|
|||||||
validate_features=False,
|
validate_features=False,
|
||||||
)
|
)
|
||||||
return await _direct_predict_impl(
|
return await _direct_predict_impl(
|
||||||
client, mapped_predict, _booster, data, None, _output_shape, meta
|
mapped_predict, _booster, data, None, _output_shape, meta
|
||||||
)
|
)
|
||||||
|
|
||||||
output_shape, _ = _infer_predict_output(
|
output_shape, _ = _infer_predict_output(
|
||||||
booster=_booster,
|
booster=await _booster.result(),
|
||||||
data=data,
|
data=data,
|
||||||
inplace=False,
|
inplace=False,
|
||||||
output_margin=output_margin,
|
output_margin=output_margin,
|
||||||
@ -1108,11 +1123,9 @@ async def _predict_async(
|
|||||||
for i, blob in enumerate(part[1:]):
|
for i, blob in enumerate(part[1:]):
|
||||||
if meta_names[i] == "base_margin":
|
if meta_names[i] == "base_margin":
|
||||||
base_margin = blob
|
base_margin = blob
|
||||||
worker = distributed.get_worker()
|
|
||||||
with config.config_context(**global_config):
|
with config.config_context(**global_config):
|
||||||
m = DMatrix(
|
m = DMatrix(
|
||||||
data,
|
data,
|
||||||
nthread=worker.nthreads,
|
|
||||||
missing=missing,
|
missing=missing,
|
||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
@ -1148,9 +1161,8 @@ async def _predict_async(
|
|||||||
all_shapes = [shape for part, shape, order in parts_with_order]
|
all_shapes = [shape for part, shape, order in parts_with_order]
|
||||||
|
|
||||||
futures = []
|
futures = []
|
||||||
booster_f = await client.scatter(data=_booster, broadcast=True)
|
|
||||||
for part in all_parts:
|
for part in all_parts:
|
||||||
f = client.submit(dispatched_predict, booster_f, part)
|
f = client.submit(dispatched_predict, _booster, part)
|
||||||
futures.append(f)
|
futures.append(f)
|
||||||
|
|
||||||
# Constructing a dask array from list of numpy arrays
|
# Constructing a dask array from list of numpy arrays
|
||||||
@ -1168,7 +1180,7 @@ async def _predict_async(
|
|||||||
|
|
||||||
def predict( # pylint: disable=unused-argument
|
def predict( # pylint: disable=unused-argument
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
model: Union[TrainReturnT, Booster],
|
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||||
data: Union[DaskDMatrix, _DaskCollection],
|
data: Union[DaskDMatrix, _DaskCollection],
|
||||||
output_margin: bool = False,
|
output_margin: bool = False,
|
||||||
missing: float = numpy.nan,
|
missing: float = numpy.nan,
|
||||||
@ -1194,7 +1206,8 @@ def predict( # pylint: disable=unused-argument
|
|||||||
Specify the dask client used for training. Use default client
|
Specify the dask client used for training. Use default client
|
||||||
returned from dask if it's set to None.
|
returned from dask if it's set to None.
|
||||||
model:
|
model:
|
||||||
The trained model.
|
The trained model. It can be a distributed.Future so user can
|
||||||
|
pre-scatter it onto all workers.
|
||||||
data:
|
data:
|
||||||
Input data used for prediction. When input is a dataframe object,
|
Input data used for prediction. When input is a dataframe object,
|
||||||
prediction output is a series.
|
prediction output is a series.
|
||||||
@ -1221,19 +1234,14 @@ def predict( # pylint: disable=unused-argument
|
|||||||
async def _inplace_predict_async(
|
async def _inplace_predict_async(
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
global_config: Dict[str, Any],
|
global_config: Dict[str, Any],
|
||||||
model: Union[Booster, Dict],
|
model: Union[Booster, Dict, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
iteration_range: Tuple[int, int] = (0, 0),
|
iteration_range: Tuple[int, int] = (0, 0),
|
||||||
predict_type: str = 'value',
|
predict_type: str = 'value',
|
||||||
missing: float = numpy.nan
|
missing: float = numpy.nan
|
||||||
) -> _DaskCollection:
|
) -> _DaskCollection:
|
||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
if isinstance(model, Booster):
|
booster = await _get_model_future(client, model)
|
||||||
booster = model
|
|
||||||
elif isinstance(model, dict):
|
|
||||||
booster = model['booster']
|
|
||||||
else:
|
|
||||||
raise TypeError(_expect([Booster, dict], type(model)))
|
|
||||||
if not isinstance(data, (da.Array, dd.DataFrame)):
|
if not isinstance(data, (da.Array, dd.DataFrame)):
|
||||||
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
||||||
|
|
||||||
@ -1261,16 +1269,20 @@ async def _inplace_predict_async(
|
|||||||
return prediction
|
return prediction
|
||||||
|
|
||||||
shape, meta = _infer_predict_output(
|
shape, meta = _infer_predict_output(
|
||||||
booster, data, True, predict_type=predict_type, iteration_range=iteration_range
|
await booster.result(),
|
||||||
|
data,
|
||||||
|
True,
|
||||||
|
predict_type=predict_type,
|
||||||
|
iteration_range=iteration_range
|
||||||
)
|
)
|
||||||
return await _direct_predict_impl(
|
return await _direct_predict_impl(
|
||||||
client, mapped_predict, booster, data, None, shape, meta
|
mapped_predict, booster, data, None, shape, meta
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def inplace_predict( # pylint: disable=unused-argument
|
def inplace_predict( # pylint: disable=unused-argument
|
||||||
client: "distributed.Client",
|
client: "distributed.Client",
|
||||||
model: Union[TrainReturnT, Booster],
|
model: Union[TrainReturnT, Booster, "distributed.Future"],
|
||||||
data: _DaskCollection,
|
data: _DaskCollection,
|
||||||
iteration_range: Tuple[int, int] = (0, 0),
|
iteration_range: Tuple[int, int] = (0, 0),
|
||||||
predict_type: str = 'value',
|
predict_type: str = 'value',
|
||||||
@ -1286,7 +1298,8 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
Specify the dask client used for training. Use default client
|
Specify the dask client used for training. Use default client
|
||||||
returned from dask if it's set to None.
|
returned from dask if it's set to None.
|
||||||
model:
|
model:
|
||||||
The trained model.
|
The trained model. It can be a distributed.Future so user can
|
||||||
|
pre-scatter it onto all workers.
|
||||||
iteration_range:
|
iteration_range:
|
||||||
Specify the range of trees used for prediction.
|
Specify the range of trees used for prediction.
|
||||||
predict_type:
|
predict_type:
|
||||||
|
|||||||
@ -535,7 +535,7 @@ class XGBModel(XGBModelBase):
|
|||||||
json.dumps({k: v})
|
json.dumps({k: v})
|
||||||
meta[k] = v
|
meta[k] = v
|
||||||
except TypeError:
|
except TypeError:
|
||||||
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.')
|
warnings.warn(str(k) + ' is not saved in Scikit-Learn meta.', UserWarning)
|
||||||
meta['_estimator_type'] = self._get_type()
|
meta['_estimator_type'] = self._get_type()
|
||||||
meta_str = json.dumps(meta)
|
meta_str = json.dumps(meta)
|
||||||
self.get_booster().set_attr(scikit_learn=meta_str)
|
self.get_booster().set_attr(scikit_learn=meta_str)
|
||||||
|
|||||||
@ -608,28 +608,30 @@ def test_with_asyncio() -> None:
|
|||||||
asyncio.run(run_dask_classifier_asyncio(address))
|
asyncio.run(run_dask_classifier_asyncio(address))
|
||||||
|
|
||||||
|
|
||||||
def test_predict() -> None:
|
def test_predict(client: "Client") -> None:
|
||||||
with LocalCluster(n_workers=kWorkers) as cluster:
|
X, y, _ = generate_array()
|
||||||
with Client(cluster) as client:
|
dtrain = DaskDMatrix(client, X, y)
|
||||||
X, y, _ = generate_array()
|
booster = xgb.dask.train(client, {}, dtrain, num_boost_round=2)["booster"]
|
||||||
dtrain = DaskDMatrix(client, X, y)
|
|
||||||
booster = xgb.dask.train(
|
|
||||||
client, {}, dtrain, num_boost_round=2)['booster']
|
|
||||||
|
|
||||||
pred = xgb.dask.predict(client, model=booster, data=dtrain)
|
predt_0 = xgb.dask.predict(client, model=booster, data=dtrain)
|
||||||
assert pred.ndim == 1
|
assert predt_0.ndim == 1
|
||||||
assert pred.shape[0] == kRows
|
assert predt_0.shape[0] == kRows
|
||||||
|
|
||||||
margin = xgb.dask.predict(client, model=booster, data=dtrain,
|
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
|
||||||
output_margin=True)
|
assert margin.ndim == 1
|
||||||
assert margin.ndim == 1
|
assert margin.shape[0] == kRows
|
||||||
assert margin.shape[0] == kRows
|
|
||||||
|
|
||||||
shap = xgb.dask.predict(client, model=booster, data=dtrain,
|
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
|
||||||
pred_contribs=True)
|
assert shap.ndim == 2
|
||||||
assert shap.ndim == 2
|
assert shap.shape[0] == kRows
|
||||||
assert shap.shape[0] == kRows
|
assert shap.shape[1] == kCols + 1
|
||||||
assert shap.shape[1] == kCols + 1
|
|
||||||
|
booster_f = client.scatter(booster, broadcast=True)
|
||||||
|
|
||||||
|
predt_1 = xgb.dask.predict(client, booster_f, X).compute()
|
||||||
|
predt_2 = xgb.dask.inplace_predict(client, booster_f, X).compute()
|
||||||
|
np.testing.assert_allclose(predt_0, predt_1)
|
||||||
|
np.testing.assert_allclose(predt_0, predt_2)
|
||||||
|
|
||||||
|
|
||||||
def test_predict_with_meta(client: "Client") -> None:
|
def test_predict_with_meta(client: "Client") -> None:
|
||||||
@ -1034,7 +1036,7 @@ class TestWithDask:
|
|||||||
rows = X.shape[0]
|
rows = X.shape[0]
|
||||||
cols = X.shape[1]
|
cols = X.shape[1]
|
||||||
|
|
||||||
def assert_shape(shape):
|
def assert_shape(shape: Tuple[int, ...]) -> None:
|
||||||
assert shape[0] == rows
|
assert shape[0] == rows
|
||||||
if "num_class" in params.keys():
|
if "num_class" in params.keys():
|
||||||
assert shape[1] == params["num_class"]
|
assert shape[1] == params["num_class"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user