[breaking] Remove the predictor param, allow fallback to prediction using DMatrix. (#9129)

- A `DeviceOrd` struct is implemented to indicate the device. It will eventually replace the `gpu_id` parameter.
- The `predictor` parameter is removed.
- Fallback to `DMatrix` when `inplace_predict` is not available.
- The heuristic for choosing a predictor is only used during training.
This commit is contained in:
Jiaming Yuan 2023-07-03 19:23:54 +08:00 committed by GitHub
parent 3a0f787703
commit 39390cc2ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
54 changed files with 1049 additions and 778 deletions

View File

@ -45,7 +45,7 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as
.. code-block:: python
model.set_param({"predictor": "gpu_predictor"})
model.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
shap_values = model.predict(dtrain, pred_contribs=True)
shap_interaction_values = model.predict(dtrain, pred_interactions=True)

View File

@ -199,18 +199,6 @@ Parameters for Tree Booster
- Maximum number of discrete bins to bucket continuous features.
- Increasing this number improves the optimality of splits at the cost of higher computation time.
* ``predictor``, [default= ``auto``]
- The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
- ``auto``: Configure predictor based on heuristics.
- ``cpu_predictor``: Multicore CPU prediction algorithm.
- ``gpu_predictor``: Prediction using GPU. Used when ``tree_method`` is ``gpu_hist``.
When ``predictor`` is set to default value ``auto``, the ``gpu_hist`` tree method is
able to provide GPU based prediction without copying training data to GPU memory.
If ``gpu_predictor`` is explicitly specified, then all data is copied into GPU, only
recommended for performing prediction tasks.
* ``num_parallel_tree``, [default=1]
- Number of parallel trees constructed during each iteration. This option is used to support boosted random forest.

View File

@ -87,15 +87,6 @@ with the native Python interface :py:meth:`xgboost.Booster.predict` and
behavior. Also the ``save_best`` parameter from :py:obj:`xgboost.callback.EarlyStopping`
might be useful.
*********
Predictor
*********
There are 2 predictors in XGBoost (3 if you have the one-api plugin enabled), namely
``cpu_predictor`` and ``gpu_predictor``. The default option is ``auto`` so that XGBoost
can employ some heuristics for saving GPU memory during training. They might have slight
different outputs due to floating point errors.
***********
Base Margin
@ -134,15 +125,6 @@ it. Be aware that the output of in-place prediction depends on input data type,
input is on GPU data output is :py:obj:`cupy.ndarray`, otherwise a :py:obj:`numpy.ndarray`
is returned.
****************
Categorical Data
****************
Other than users performing encoding, XGBoost has experimental support for categorical
data using ``gpu_hist`` and ``gpu_predictor``. No special operation needs to be done on
input test data since the information about categories is encoded into the model during
training.
*************
Thread Safety
*************
@ -159,7 +141,6 @@ instance we might accidentally call ``clf.set_params()`` inside a predict functi
def predict_fn(clf: xgb.XGBClassifier, X):
X = preprocess(X)
clf.set_params(predictor="gpu_predictor") # NOT safe!
clf.set_params(n_jobs=1) # NOT safe!
return clf.predict_proba(X, iteration_range=(0, 10))

View File

@ -148,8 +148,8 @@ Also for inplace prediction:
.. code-block:: python
booster.set_param({'predictor': 'gpu_predictor'})
# where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
# where X is a dask DataFrame or dask Array backed by cupy or cuDF.
booster.set_param({"gpu_id": "0"})
prediction = xgb.dask.inplace_predict(client, booster, X)
When input is ``da.Array`` object, output is always ``da.Array``. However, if the input

View File

@ -173,7 +173,6 @@ Will print out something similar to (not actual output as it's too long for demo
"gradient_booster": {
"gbtree_train_param": {
"num_parallel_tree": "1",
"predictor": "gpu_predictor",
"process_type": "default",
"tree_method": "gpu_hist",
"updater": "grow_gpu_hist",

View File

@ -10,6 +10,7 @@
#include <dmlc/omp.h>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
@ -112,7 +113,7 @@ using bst_row_t = std::size_t; // NOLINT
/*! \brief Type for tree node index. */
using bst_node_t = std::int32_t; // NOLINT
/*! \brief Type for ranking group index. */
using bst_group_t = std::uint32_t; // NOLINT
using bst_group_t = std::uint32_t; // NOLINT
/**
* \brief Type for indexing into output targets.
*/
@ -125,6 +126,10 @@ using bst_layer_t = std::int32_t; // NOLINT
* \brief Type for indexing trees.
*/
using bst_tree_t = std::int32_t; // NOLINT
/**
* @brief Ordinal of a CUDA device.
*/
using bst_d_ordinal_t = std::int16_t; // NOLINT
namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation

View File

@ -1067,6 +1067,9 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat
/**
* \brief Inplace prediction from CPU dense matrix.
*
* \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle.
* \param values JSON encoded __array_interface__ to values.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.
@ -1091,6 +1094,9 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,
/**
* \brief Inplace prediction from CPU CSR matrix.
*
* \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle.
* \param indptr JSON encoded __array_interface__ to row pointer in CSR.
* \param indices JSON encoded __array_interface__ to column indices in CSR.
@ -1116,6 +1122,9 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, ch
/**
* \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
*
* \note If the booster is configured to run on a CPU, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle
* \param values JSON encoded __cuda_array_interface__ to values.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.
@ -1137,6 +1146,9 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *valu
/**
* \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
*
* \note If the booster is configured to run on a CPU, XGBoost falls back to run
* prediction with DMatrix with a performance warning.
*
* \param handle Booster handle
* \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
* \param config See \ref XGBoosterPredictFromDMatrix for more info.

View File

@ -1,20 +1,79 @@
/*!
* Copyright 2014-2022 by Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file context.h
*/
#ifndef XGBOOST_CONTEXT_H_
#define XGBOOST_CONTEXT_H_
#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <xgboost/base.h> // for bst_d_ordinal_t
#include <xgboost/logging.h> // for CHECK_GE
#include <xgboost/parameter.h> // for XGBoostParameter
#include <memory> // std::shared_ptr
#include <string>
#include <cstdint> // for int16_t, int32_t, int64_t
#include <memory> // for shared_ptr
#include <string> // for string, to_string
namespace xgboost {
struct CUDAContext;
/**
* @brief A type for device ordinal. The type is packed into 32-bit for efficient use in
* viewing types like `linalg::TensorView`.
*/
struct DeviceOrd {
enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
// CUDA device ordinal.
bst_d_ordinal_t ordinal{-1};
[[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
[[nodiscard]] bool IsCPU() const { return device == kCPU; }
DeviceOrd() = default;
constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}
DeviceOrd(DeviceOrd const& that) = default;
DeviceOrd& operator=(DeviceOrd const& that) = default;
DeviceOrd(DeviceOrd&& that) = default;
DeviceOrd& operator=(DeviceOrd&& that) = default;
/**
* @brief Constructor for CPU.
*/
[[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; }
/**
* @brief Constructor for CUDA device.
*
* @param ordinal CUDA device ordinal.
*/
[[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; }
[[nodiscard]] bool operator==(DeviceOrd const& that) const {
return device == that.device && ordinal == that.ordinal;
}
[[nodiscard]] bool operator!=(DeviceOrd const& that) const { return !(*this == that); }
/**
* @brief Get a string representation of the device and the ordinal.
*/
[[nodiscard]] std::string Name() const {
switch (device) {
case DeviceOrd::kCPU:
return "CPU";
case DeviceOrd::kCUDA:
return "CUDA:" + std::to_string(ordinal);
default: {
LOG(FATAL) << "Unknown device.";
return "";
}
}
}
};
static_assert(sizeof(DeviceOrd) == sizeof(std::int32_t));
/**
* @brief Runtime context for XGBoost. Contains information like threads and device.
*/
struct Context : public XGBoostParameter<Context> {
public:
// Constant representing the device ID of CPU.
@ -36,29 +95,59 @@ struct Context : public XGBoostParameter<Context> {
// fail when gpu_id is invalid
bool fail_on_invalid_gpu_id{false};
bool validate_parameters{false};
/*!
* \brief Configure the parameter `gpu_id'.
/**
* @brief Configure the parameter `gpu_id'.
*
* \param require_gpu Whether GPU is explicitly required from user.
* @param require_gpu Whether GPU is explicitly required by the user through other
* configurations.
*/
void ConfigureGpuId(bool require_gpu);
/*!
* Return automatically chosen threads.
/**
* @brief Returns the automatically chosen number of threads based on the `nthread`
* parameter and the system settting.
*/
std::int32_t Threads() const;
bool IsCPU() const { return gpu_id == kCpuId; }
bool IsCUDA() const { return !IsCPU(); }
CUDAContext const* CUDACtx() const;
// Make a CUDA context based on the current context.
Context MakeCUDA(std::int32_t device = 0) const {
[[nodiscard]] std::int32_t Threads() const;
/**
* @brief Is XGBoost running on CPU?
*/
[[nodiscard]] bool IsCPU() const { return gpu_id == kCpuId; }
/**
* @brief Is XGBoost running on a CUDA device?
*/
[[nodiscard]] bool IsCUDA() const { return !IsCPU(); }
/**
* @brief Get the current device and ordinal.
*/
[[nodiscard]] DeviceOrd Device() const {
return IsCPU() ? DeviceOrd::CPU() : DeviceOrd::CUDA(static_cast<bst_d_ordinal_t>(gpu_id));
}
/**
* @brief Get the CUDA device ordinal. -1 if XGBoost is running on CPU.
*/
[[nodiscard]] bst_d_ordinal_t Ordinal() const { return this->gpu_id; }
/**
* @brief Name of the current device.
*/
[[nodiscard]] std::string DeviceName() const { return Device().Name(); }
/**
* @brief Get a CUDA device context for allocator and stream.
*/
[[nodiscard]] CUDAContext const* CUDACtx() const;
/**
* @brief Make a CUDA context based on the current context.
*
* @param ordinal The CUDA device ordinal.
*/
[[nodiscard]] Context MakeCUDA(std::int32_t ordinal = 0) const {
Context ctx = *this;
ctx.gpu_id = device;
CHECK_GE(ordinal, 0);
ctx.gpu_id = ordinal;
return ctx;
}
Context MakeCPU() const {
/**
* @brief Make a CPU context based on the current context.
*/
[[nodiscard]] Context MakeCPU() const {
Context ctx = *this;
ctx.gpu_id = kCpuId;
return ctx;
@ -87,9 +176,9 @@ struct Context : public XGBoostParameter<Context> {
}
private:
// mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
// shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
// while trying to hide CUDA code from host compiler.
// mutable for lazy cuda context initialization. This avoids initializing CUDA at load.
// shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define
// p_impl while trying to hide CUDA code from the host compiler.
mutable std::shared_ptr<CUDAContext> cuctx_;
// cached value for CFS CPU limit. (used in containerized env)
std::int32_t cfs_cpu_count_; // NOLINT

View File

@ -149,18 +149,14 @@ class GradientBooster : public Model, public Configurable {
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \param approximate use a faster (inconsistent) approximation of SHAP values
* \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
* \param condition_feature feature to condition on (i.e. fix) during calculations
*/
virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned layer_begin, unsigned layer_end,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) = 0;
virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate = false) = 0;
virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) = 0;
/*!
* \brief dump the model in the requested format

View File

@ -78,7 +78,6 @@ public class BoosterTest {
put("num_round", round);
put("num_workers", 1);
put("tree_method", "gpu_hist");
put("predictor", "gpu_predictor");
put("max_bin", maxBin);
}
};

View File

@ -281,7 +281,6 @@ object GpuPreXGBoost extends PreXGBoostProvider {
// - predictor: Force to gpu predictor since native doesn't save predictor.
val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
booster.setParam("gpu_id", gpuId.toString)
booster.setParam("predictor", "gpu_predictor")
logger.info("GPU transform on device: " + gpuId)
boosterFlag.isGpuParamsSet = true;
}

View File

@ -2187,20 +2187,25 @@ class Booster:
base_margin: Any = None,
strict_shape: bool = False,
) -> NumpyOrCupy:
"""Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction
does not cache the prediction result.
"""Run prediction in-place when possible, Unlike :py:meth:`predict` method,
inplace prediction does not cache the prediction result.
Calling only ``inplace_predict`` in multiple threads is safe and lock
free. But the safety does not hold when used in conjunction with other
methods. E.g. you can't train the booster in one thread and perform
prediction in the other.
.. note::
If the device ordinal of the input data doesn't match the one configured for
the booster, data will be copied to the booster device.
.. code-block:: python
booster.set_param({"predictor": "gpu_predictor"})
booster.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
booster.inplace_predict(cupy_array)
booster.set_param({"predictor": "cpu_predictor"})
booster.set_param({"gpu_id": "-1", "tree_method": "hist"})
booster.inplace_predict(numpy_array)
.. versionadded:: 1.1.0
@ -2208,9 +2213,7 @@ class Booster:
Parameters
----------
data :
The input data, must not be a view for numpy array. Set
``predictor`` to ``gpu_predictor`` for running prediction on CuPy
array or CuDF DataFrame.
The input data.
iteration_range :
See :py:meth:`predict` for details.
predict_type :

View File

@ -277,9 +277,6 @@ __model_doc = f"""
Device ordinal.
validate_parameters : Optional[bool]
Give warnings for unknown parameter.
predictor : Optional[str]
Force XGBoost to use specific predictor, available choices are [cpu_predictor,
gpu_predictor].
enable_categorical : bool
.. versionadded:: 1.5.0
@ -652,7 +649,6 @@ class XGBModel(XGBModelBase):
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
feature_types: Optional[FeatureTypes] = None,
max_cat_to_onehot: Optional[int] = None,
@ -699,7 +695,6 @@ class XGBModel(XGBModelBase):
self.importance_type = importance_type
self.gpu_id = gpu_id
self.validate_parameters = validate_parameters
self.predictor = predictor
self.enable_categorical = enable_categorical
self.feature_types = feature_types
self.max_cat_to_onehot = max_cat_to_onehot
@ -1093,12 +1088,7 @@ class XGBModel(XGBModelBase):
return self
def _can_use_inplace_predict(self) -> bool:
# When predictor is explicitly set, using `inplace_predict` might result into
# error with incompatible data type.
# Inplace predict doesn't handle as many data types as DMatrix, but it's
# sufficient for dask interface where input is simpiler.
predictor = self.get_xgb_params().get("predictor", None)
if predictor in ("auto", None) and self.booster != "gblinear":
if self.booster != "gblinear":
return True
return False
@ -1124,9 +1114,9 @@ class XGBModel(XGBModelBase):
iteration_range: Optional[Tuple[int, int]] = None,
) -> ArrayLike:
"""Predict with `X`. If the model is trained with early stopping, then
:py:attr:`best_iteration` is used automatically. For tree models, when data is
on GPU, like cupy array or cuDF dataframe and `predictor` is not specified, the
prediction is run on GPU automatically, otherwise it will run on CPU.
:py:attr:`best_iteration` is used automatically. The estimator uses
`inplace_predict` by default and falls back to using :py:class:`DMatrix` if
devices between the data and the estimator don't match.
.. note:: This function is only thread safe for `gbtree` and `dart`.
@ -1588,7 +1578,9 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
) -> np.ndarray:
"""Predict the probability of each `X` example being of a given class. If the
model is trained with early stopping, then :py:attr:`best_iteration` is used
automatically.
automatically. The estimator uses `inplace_predict` by default and falls back to
using :py:class:`DMatrix` if devices between the data and the estimator don't
match.
.. note:: This function is only thread safe for `gbtree` and `dart`.

View File

@ -25,6 +25,7 @@ from typing import (
Set,
Tuple,
TypedDict,
TypeVar,
Union,
)
@ -711,6 +712,27 @@ def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:
)
M = TypeVar("M", xgb.Booster, xgb.XGBModel)
def set_ordinal(ordinal: int, booster: M) -> M:
"""Temporary solution for setting the device ordinal until we move away from
`gpu_id`.
"""
if ordinal < 0:
params = {"gpu_id": -1, "tree_method": "hist"}
else:
params = {"gpu_id": ordinal, "tree_method": "gpu_hist"}
if isinstance(booster, xgb.Booster):
booster.set_param(params)
elif isinstance(booster, xgb.XGBModel):
booster.set_params(**params)
return booster
def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.float64]:
"""Evaluation metric for xgb.train"""
label = dtrain.get_label()

View File

@ -1023,7 +1023,6 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,
const float **out_result) {
xgboost_CHECK_C_ARG_PTR(c_json_config);
auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
HostDeviceVector<float> *p_predt{nullptr};
auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));
@ -1042,6 +1041,7 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,
xgboost_CHECK_C_ARG_PTR(out_dim);
CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(),
learner->BoostedRounds(), &shape, out_dim);
CHECK_GE(p_predt->Size(), n_samples);
xgboost_CHECK_C_ARG_PTR(out_result);
xgboost_CHECK_C_ARG_PTR(out_shape);

View File

@ -92,7 +92,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,
API_END();
}
int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
char const *c_json_config, std::shared_ptr<DMatrix> p_m,
xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
const float **out_result) {
@ -107,7 +107,6 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
proxy->SetCUDAArray(c_array_interface);
auto config = Json::Load(StringView{c_json_config});
CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
auto *learner = static_cast<Learner *>(handle);
HostDeviceVector<float> *p_predt{nullptr};
@ -118,7 +117,13 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
RequiredArg<Integer>(config, "iteration_begin", __func__),
RequiredArg<Integer>(config, "iteration_end", __func__));
CHECK(p_predt);
CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
if (learner->Ctx()->IsCPU()) {
// Prediction using DMatrix as fallback.
CHECK(p_predt->HostCanRead() && !p_predt->DeviceCanRead());
} else {
CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
}
p_predt->SetDevice(proxy->DeviceIdx());
auto &shape = learner->GetThreadLocal().prediction_shape;
size_t n_samples = p_m->Info().num_row_;
@ -146,7 +151,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c
if (m) {
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
out_result);
}
@ -159,6 +164,6 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_js
p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
}
xgboost_CHECK_C_ARG_PTR(out_result);
return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
out_result);
}

View File

@ -6,6 +6,11 @@
#ifndef XGBOOST_COMMON_ERROR_MSG_H_
#define XGBOOST_COMMON_ERROR_MSG_H_
#include <cinttypes> // for uint64_t
#include <limits> // for numeric_limits
#include "xgboost/base.h" // for bst_feature_t
#include "xgboost/logging.h"
#include "xgboost/string_view.h" // for StringView
namespace xgboost::error {
@ -33,5 +38,14 @@ constexpr StringView InconsistentMaxBin() {
return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
"and consistent with the Booster being trained.";
}
constexpr StringView UnknownDevice() { return "Unknown device type."; }
inline void MaxFeatureSize(std::uint64_t n_features) {
auto max_n_features = std::numeric_limits<bst_feature_t>::max();
CHECK_LE(n_features, max_n_features)
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<bst_feature_t>::max() << " features or greater";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -7,7 +7,7 @@
#include <dmlc/data.h>
#include <algorithm>
#include <cstddef> // std::size_t
#include <cstddef> // for size_t
#include <functional>
#include <limits>
#include <map>
@ -17,6 +17,7 @@
#include <vector>
#include "../c_api/c_api_error.h"
#include "../common/error_msg.h" // for MaxFeatureSize
#include "../common/math.h"
#include "array_interface.h"
#include "arrow-cdi.h"
@ -300,9 +301,9 @@ class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
array_interface_ = ArrayInterface<2>(get<Object const>(j));
batch_ = ArrayAdapterBatch{array_interface_};
}
ArrayAdapterBatch const& Value() const override { return batch_; }
size_t NumRows() const { return array_interface_.Shape(0); }
size_t NumColumns() const { return array_interface_.Shape(1); }
[[nodiscard]] ArrayAdapterBatch const& Value() const override { return batch_; }
[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
private:
ArrayAdapterBatch batch_;

View File

@ -31,10 +31,10 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
dh::XGBCachingDeviceAllocator<char> alloc;
auto num_rows = [&]() {
return Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
};
auto num_cols = [&]() {
return Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
};
size_t row_stride = 0;
@ -74,7 +74,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
get_device());
auto* p_sketch = &sketch_containers.back();
proxy->Info().weights_.SetDevice(get_device());
Dispatch(proxy, [&](auto const& value) {
cuda_impl::Dispatch(proxy, [&](auto const& value) {
common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
});
}
@ -82,7 +82,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
accumulated_rows += batch_rows;
dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
row_stride = std::max(row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());
@ -136,14 +136,14 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
auto rows = num_rows();
dh::device_vector<size_t> row_counts(rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
Dispatch(proxy, [=](auto const& value) {
cuda_impl::Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, get_device(), missing);
});
auto is_dense = this->IsDense();
proxy->Info().feature_types.SetDevice(get_device());
auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
auto new_impl = Dispatch(proxy, [&](auto const& value) {
auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
return EllpackPageImpl(value, missing, get_device(), is_dense, row_counts_span,
d_feature_types, row_stride, rows, cuts);
});

View File

@ -1,14 +1,13 @@
/*!
* Copyright 2021 by Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
* \file proxy_dmatrix.cc
*/
#include "proxy_dmatrix.h"
namespace xgboost {
namespace data {
void DMatrixProxy::SetArrayData(char const *c_interface) {
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter(StringView{c_interface})};
namespace xgboost::data {
void DMatrixProxy::SetArrayData(StringView interface_str) {
std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{interface_str}};
this->batch_ = adapter;
this->Info().num_col_ = adapter->NumColumns();
this->Info().num_row_ = adapter->NumRows();
@ -25,5 +24,36 @@ void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,
this->Info().num_row_ = adapter->NumRows();
this->ctx_.gpu_id = Context::kCpuId;
}
} // namespace data
} // namespace xgboost
namespace cuda_impl {
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
std::shared_ptr<DMatrixProxy> proxy, float missing);
#if !defined(XGBOOST_USE_CUDA)
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *, std::shared_ptr<DMatrixProxy>,
float) {
return nullptr;
}
#endif // XGBOOST_USE_CUDA
} // namespace cuda_impl
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
std::shared_ptr<DMatrixProxy> proxy,
float missing) {
bool type_error{false};
std::shared_ptr<DMatrix> p_fmat{nullptr};
if (proxy->Ctx()->IsCPU()) {
p_fmat = data::HostAdapterDispatch<false>(
proxy.get(),
[&](auto const &adapter) {
auto p_fmat =
std::shared_ptr<DMatrix>(DMatrix::Create(adapter.get(), missing, ctx->Threads()));
return p_fmat;
},
&type_error);
} else {
p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
}
return p_fmat;
}
} // namespace xgboost::data

View File

@ -1,12 +1,11 @@
/*!
* Copyright 2020-2022, XGBoost contributors
/**
* Copyright 2020-2023, XGBoost contributors
*/
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"
#include "proxy_dmatrix.cuh"
#include "proxy_dmatrix.h"
namespace xgboost {
namespace data {
namespace xgboost::data {
void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
std::shared_ptr<data::CudfAdapter> adapter{new CudfAdapter{interface_str}};
auto const& value = adapter->Value();
@ -31,5 +30,15 @@ void DMatrixProxy::FromCudaArray(StringView interface_str) {
ctx_.gpu_id = dh::CurrentDevice();
}
}
} // namespace data
} // namespace xgboost
namespace cuda_impl {
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
std::shared_ptr<DMatrixProxy> proxy,
float missing) {
return Dispatch<false>(proxy.get(), [&](auto const& adapter) {
auto p_fmat = std::shared_ptr<DMatrix>{DMatrix::Create(adapter.get(), missing, ctx->Threads())};
return p_fmat;
});
}
} // namespace cuda_impl
} // namespace xgboost::data

View File

@ -6,19 +6,34 @@
#include "device_adapter.cuh"
#include "proxy_dmatrix.h"
namespace xgboost::data {
template <typename Fn>
namespace xgboost::data::cuda_impl {
template <bool get_value = true, typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
return fn(value);
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
return fn(value);
} else {
auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter());
return fn(value);
}
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
return fn(value);
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
return fn(value);
} else {
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter());
return fn(value);
}
} else {
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
return fn(value);
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
return fn(value);
} else {
auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter());
return fn(value);
}
}
}
} // namespace xgboost::data
} // namespace xgboost::data::cuda_impl

View File

@ -62,7 +62,7 @@ class DMatrixProxy : public DMatrix {
#endif // defined(XGBOOST_USE_CUDA)
}
void SetArrayData(char const* c_interface);
void SetArrayData(StringView interface_str);
void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
bst_feature_t n_features, bool on_host);
@ -114,28 +114,62 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {
return typed;
}
template <typename Fn>
/**
* @brief Dispatch function call based on input type.
*
* @tparam get_value Whether the funciton Fn accept an adapter batch or the adapter itself.
* @tparam Fn The type of the function to be dispatched.
*
* @param proxy The proxy object holding the reference to the input.
* @param fn The function to be dispatched.
* @param type_error[out] Set to ture if it's not null and the input data is not recognized by
* the host.
*
* @return The return value of the function being dispatched.
*/
template <bool get_value = true, typename Fn>
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
return fn(value);
} else {
auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter());
return fn(value);
}
if (type_error) {
*type_error = false;
}
return fn(value);
} else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
if constexpr (get_value) {
auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
return fn(value);
} else {
auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter());
return fn(value);
}
if (type_error) {
*type_error = false;
}
return fn(value);
} else {
if (type_error) {
*type_error = true;
} else {
LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
}
return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
if constexpr (get_value) {
return std::result_of_t<Fn(
decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
} else {
return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()))>();
}
}
}
/**
* @brief Create a `SimpleDMatrix` instance from a `DMatrixProxy`.
*/
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
std::shared_ptr<DMatrixProxy> proxy, float missing);
} // namespace xgboost::data
#endif // XGBOOST_DATA_PROXY_DMATRIX_H_

View File

@ -1,33 +1,31 @@
/*!
* Copyright 2021 XGBoost contributors
/**
* Copyright 2021-2023, XGBoost contributors
*/
#include "../common/device_helpers.cuh" // for CurrentDevice
#include "proxy_dmatrix.cuh" // for Dispatch, DMatrixProxy
#include "simple_dmatrix.cuh" // for CopyToSparsePage
#include "sparse_page_source.h"
#include "proxy_dmatrix.cuh"
#include "simple_dmatrix.cuh"
namespace xgboost {
namespace data {
#include "xgboost/data.h" // for SparsePage
namespace xgboost::data {
namespace detail {
std::size_t NSamplesDevice(DMatrixProxy *proxy) {
return Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
}
std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
return Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
}
} // namespace detail
void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
auto device = proxy->DeviceIdx();
if (device < 0) {
device = dh::CurrentDevice();
}
CHECK_GE(device, 0);
Dispatch(proxy, [&](auto const &value) {
CopyToSparsePage(value, device, missing, page);
});
cuda_impl::Dispatch(proxy,
[&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
}
} // namespace data
} // namespace xgboost
} // namespace xgboost::data

View File

@ -172,8 +172,7 @@ class GBLinear : public GradientBooster {
}
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
uint32_t layer_begin, uint32_t /*layer_end*/, bool, int,
unsigned) override {
bst_layer_t layer_begin, bst_layer_t /*layer_end*/, bool) override {
model_.LazyInitModel();
LinearCheckLayer(layer_begin);
auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);
@ -210,8 +209,8 @@ class GBLinear : public GradientBooster {
}
}
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
unsigned layer_begin, unsigned /*layer_end*/,
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t /*layer_end*/,
bool) override {
LinearCheckLayer(layer_begin);
std::vector<bst_float>& contribs = out_contribs->HostVector();

View File

@ -18,9 +18,11 @@
#include <vector>
#include "../common/common.h"
#include "../common/error_msg.h" // for UnknownDevice
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
#include "../data/proxy_dmatrix.h" // for DMatrixProxy, HostAdapterDispatch
#include "gbtree_model.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
@ -58,9 +60,8 @@ void GBTree::Configure(Args const& cfg) {
cpu_predictor_->Configure(cfg);
#if defined(XGBOOST_USE_CUDA)
auto n_gpus = common::AllVisibleGPUs();
if (!gpu_predictor_ && n_gpus != 0) {
gpu_predictor_ = std::unique_ptr<Predictor>(
Predictor::Create("gpu_predictor", this->ctx_));
if (!gpu_predictor_) {
gpu_predictor_ = std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", this->ctx_));
}
if (n_gpus != 0) {
gpu_predictor_->Configure(cfg);
@ -374,12 +375,7 @@ void GBTree::LoadConfig(Json const& in) {
// This would cause all trees to be pushed to trees_to_update
// e.g. updating a model, then saving and loading it would result in an empty model
tparam_.process_type = TreeProcessType::kDefault;
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
LOG(WARNING) << "Loading from a raw memory buffer on CPU only machine. "
"Changing predictor to auto.";
tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
}
std::int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
auto msg = StringView{
R"(
@ -505,8 +501,8 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, Gradien
out_model.param.num_parallel_tree = model_.param.num_parallel_tree;
}
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool,
bst_layer_t layer_begin, bst_layer_t layer_end) {
void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) const {
CHECK(configured_);
if (layer_end == 0) {
layer_end = this->BoostedRounds();
@ -526,7 +522,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool
CHECK_EQ(out_preds->version, 0);
}
auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat);
auto const& predictor = GetPredictor(is_training, &out_preds->predictions, p_fmat);
if (out_preds->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any
// tree is built at the 0^th iterator.
@ -546,68 +542,69 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool
}
}
std::unique_ptr<Predictor> const &
GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
DMatrix *f_dmat) const {
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) {
// dispatch to const function.
this->PredictBatchImpl(p_fmat, out_preds, is_training, layer_begin, layer_end);
}
void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
PredictionCacheEntry* out_preds, bst_layer_t layer_begin,
bst_layer_t layer_end) const {
CHECK(configured_);
if (tparam_.predictor != PredictorType::kAuto) {
if (tparam_.predictor == PredictorType::kGPUPredictor) {
#if defined(XGBOOST_USE_CUDA)
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
CHECK(gpu_predictor_);
return gpu_predictor_;
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
}
if (tparam_.predictor == PredictorType::kOneAPIPredictor) {
#if defined(XGBOOST_USE_ONEAPI)
CHECK(oneapi_predictor_);
return oneapi_predictor_;
#else
common::AssertOneAPISupport();
#endif // defined(XGBOOST_USE_ONEAPI)
}
CHECK(cpu_predictor_);
return cpu_predictor_;
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (p_m->Ctx()->Device() != this->ctx_->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
CHECK_EQ(out_preds->version, 0);
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
auto any_adapter = proxy->Adapter();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
return;
}
if (this->ctx_->IsCPU()) {
this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else if (p_m->Ctx()->IsCUDA()) {
CHECK(this->gpu_predictor_);
this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
} else {
LOG(FATAL) << error::UnknownDevice();
}
}
[[nodiscard]] std::unique_ptr<Predictor> const& GBTree::GetPredictor(
bool is_training, HostDeviceVector<float> const* out_pred, DMatrix* f_dmat) const {
CHECK(configured_);
// Data comes from SparsePageDMatrix. Since we are loading data in pages, no need to
// prevent data copy.
if (f_dmat && !f_dmat->SingleColBlock()) {
if (ctx_->IsCPU()) {
return cpu_predictor_;
} else {
#if defined(XGBOOST_USE_CUDA)
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
return gpu_predictor_;
#else
common::AssertGPUSupport();
return cpu_predictor_;
#endif // defined(XGBOOST_USE_CUDA)
CHECK(gpu_predictor_);
return gpu_predictor_;
}
}
// Data comes from Device DMatrix.
auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
!f_dmat->PageExists<SparsePage>();
auto is_ellpack =
f_dmat && f_dmat->PageExists<EllpackPage>() && !f_dmat->PageExists<SparsePage>();
// Data comes from device memory, like CuDF or CuPy.
auto is_from_device =
f_dmat && f_dmat->PageExists<SparsePage>() &&
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
auto is_from_device = f_dmat && f_dmat->PageExists<SparsePage>() &&
(*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
auto on_device = is_ellpack || is_from_device;
// Use GPU Predictor if data is already on device and gpu_id is set.
if (on_device && ctx_->gpu_id >= 0) {
#if defined(XGBOOST_USE_CUDA)
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
if (on_device && ctx_->IsCUDA()) {
common::AssertGPUSupport();
CHECK(gpu_predictor_);
return gpu_predictor_;
#else
LOG(FATAL) << "Data is on CUDA device, but XGBoost is not compiled with "
"CUDA support.";
return cpu_predictor_;
#endif // defined(XGBOOST_USE_CUDA)
}
// GPU_Hist by default has prediction cache calculated from quantile values,
@ -619,23 +616,19 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
if ((out_pred && out_pred->Size() == 0) && (model_.param.num_trees != 0) &&
// FIXME(trivialfis): Implement a better method for testing whether data
// is on device after DMatrix refactoring is done.
!on_device) {
!on_device && is_training) {
CHECK(cpu_predictor_);
return cpu_predictor_;
}
if (tparam_.tree_method == TreeMethod::kGPUHist) {
#if defined(XGBOOST_USE_CUDA)
CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
if (ctx_->IsCPU()) {
return cpu_predictor_;
} else {
common::AssertGPUSupport();
CHECK(gpu_predictor_);
return gpu_predictor_;
#else
common::AssertGPUSupport();
return cpu_predictor_;
#endif // defined(XGBOOST_USE_CUDA)
}
CHECK(cpu_predictor_);
return cpu_predictor_;
}
@ -750,7 +743,7 @@ class Dart : public GBTree {
bool training, unsigned layer_begin,
unsigned layer_end) const {
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
auto& predictor = this->GetPredictor(training, &p_out_preds->predictions, p_fmat);
CHECK(predictor);
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
model_);
@ -814,49 +807,46 @@ class Dart : public GBTree {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
auto n_groups = model_.learner_model_param->num_output_group;
std::vector<Predictor const*> predictors {
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
};
Predictor const* predictor{nullptr};
StringView msg{"Unsupported data type for inplace predict."};
if (ctx_->Device() != p_fmat->Ctx()->Device()) {
LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
<< "is running on: " << this->ctx_->DeviceName()
<< ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
auto any_adapter = proxy->Adapter();
auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
return;
}
StringView msg{"Unsupported data type for inplace predict."};
PredictionCacheEntry predts;
if (ctx_->gpu_id != Context::kCpuId) {
predts.predictions.SetDevice(ctx_->gpu_id);
}
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
auto get_predictor = [&]() -> Predictor const* {
if (ctx_->IsCPU()) {
return cpu_predictor_.get();
} else if (ctx_->IsCUDA()) {
CHECK(this->gpu_predictor_);
return gpu_predictor_.get();
} else {
LOG(FATAL) << error::UnknownDevice();
return nullptr;
}
};
auto predict_impl = [&](size_t i) {
predts.predictions.Fill(0);
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
bool success = false;
for (auto const& p : predictors) {
if (p && p->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)) {
success = true;
predictor = p;
break;
}
}
CHECK(success) << msg;
} else {
predictor = this->GetPredictor().get();
bool success = predictor->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor ? "cpu_predictor"
: "gpu_predictor");
}
bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
CHECK(success) << msg;
};
// Inplace predict is not used for training, so no need to drop tree.
for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
predict_impl(i);
if (i == tree_begin) {
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
}
// Multiple the tree weight
auto w = this->weight_drop_.at(i);
@ -886,25 +876,24 @@ class Dart : public GBTree {
std::vector<bst_float> *out_preds,
unsigned layer_begin, unsigned layer_end) override {
DropTrees(false);
auto &predictor = this->GetPredictor();
auto &predictor = this->GetPredictor(false);
uint32_t _, tree_end;
std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
predictor->PredictInstance(inst, out_preds, model_, tree_end);
}
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate, int,
unsigned) override {
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
CHECK(configured_);
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_,
approximate);
}
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
unsigned layer_begin, unsigned layer_end, bool approximate) override {
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
CHECK(configured_);
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end,

View File

@ -1,14 +1,11 @@
/*!
* Copyright 2021 by Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include "../common/device_helpers.cuh"
#include "xgboost/context.h"
#include "xgboost/linalg.h"
#include "xgboost/span.h"
namespace xgboost {
namespace gbm {
namespace xgboost::gbm {
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
bst_group_t n_groups, bst_group_t group_id,
HostDeviceVector<GradientPair> *out_gpair) {
@ -41,5 +38,4 @@ void GPUDartInplacePredictInc(common::Span<float> out_predts, common::Span<float
out_predts[offset] += (predts[offset] - base_score(0)) * tree_w;
});
}
} // namespace gbm
} // namespace xgboost
} // namespace xgboost::gbm

View File

@ -43,18 +43,10 @@ enum class TreeProcessType : int {
kDefault = 0,
kUpdate = 1
};
enum class PredictorType : int {
kAuto = 0,
kCPUPredictor,
kGPUPredictor,
kOneAPIPredictor
};
} // namespace xgboost
DECLARE_FIELD_ENUM_CLASS(xgboost::TreeMethod);
DECLARE_FIELD_ENUM_CLASS(xgboost::TreeProcessType);
DECLARE_FIELD_ENUM_CLASS(xgboost::PredictorType);
namespace xgboost::gbm {
/*! \brief training parameters */
@ -63,8 +55,6 @@ struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
std::string updater_seq;
/*! \brief type of boosting process to run */
TreeProcessType process_type;
// predictor type
PredictorType predictor;
// tree construction method
TreeMethod tree_method;
// declare parameters
@ -79,13 +69,6 @@ struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
.describe("Whether to run the normal boosting process that creates new trees,"\
" or to update the trees in an existing model.");
DMLC_DECLARE_ALIAS(updater_seq, updater);
DMLC_DECLARE_FIELD(predictor)
.set_default(PredictorType::kAuto)
.add_enum("auto", PredictorType::kAuto)
.add_enum("cpu_predictor", PredictorType::kCPUPredictor)
.add_enum("gpu_predictor", PredictorType::kGPUPredictor)
.add_enum("oneapi_predictor", PredictorType::kOneAPIPredictor)
.describe("Predictor algorithm type");
DMLC_DECLARE_FIELD(tree_method)
.set_default(TreeMethod::kAuto)
.add_enum("auto", TreeMethod::kAuto)
@ -206,15 +189,9 @@ class GBTree : public GradientBooster {
void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
PredictionCacheEntry* predt, ObjFunction const* obj) override;
bool UseGPU() const override {
return
tparam_.predictor == PredictorType::kGPUPredictor ||
tparam_.tree_method == TreeMethod::kGPUHist;
}
[[nodiscard]] bool UseGPU() const override { return tparam_.tree_method == TreeMethod::kGPUHist; }
GBTreeTrainParam const& GetTrainParam() const {
return tparam_;
}
[[nodiscard]] GBTreeTrainParam const& GetTrainParam() const { return tparam_; }
void Load(dmlc::Stream* fi) override { model_.Load(fi); }
void Save(dmlc::Stream* fo) const override {
@ -236,39 +213,14 @@ class GBTree : public GradientBooster {
return !model_.trees.empty() || !model_.trees_to_update.empty();
}
void PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
bst_layer_t layer_begin, bst_layer_t layer_end) const;
void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training,
bst_layer_t layer_begin, bst_layer_t layer_end) override;
void InplacePredict(std::shared_ptr<DMatrix> p_m, float missing, PredictionCacheEntry* out_preds,
bst_layer_t layer_begin, bst_layer_t layer_end) const override {
CHECK(configured_);
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
std::vector<Predictor const *> predictors{
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
};
StringView msg{"Unsupported data type for inplace predict."};
if (tparam_.predictor == PredictorType::kAuto) {
// Try both predictor implementations
for (auto const &p : predictors) {
if (p && p->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end)) {
return;
}
}
LOG(FATAL) << msg;
} else {
bool success = this->GetPredictor()->InplacePredict(p_m, model_, missing, out_preds,
tree_begin, tree_end);
CHECK(success) << msg << std::endl
<< "Current Predictor: "
<< (tparam_.predictor == PredictorType::kCPUPredictor
? "cpu_predictor"
: "gpu_predictor");
}
}
bst_layer_t layer_begin, bst_layer_t layer_end) const override;
void FeatureScore(std::string const& importance_type, common::Span<int32_t const> trees,
std::vector<bst_feature_t>* features,
@ -349,32 +301,29 @@ class GBTree : public GradientBooster {
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
"n_iteration), use model slicing instead.";
this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
this->GetPredictor(false)->PredictLeaf(p_fmat, out_preds, model_, tree_end);
}
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
uint32_t layer_begin, uint32_t layer_end, bool approximate,
int, unsigned) override {
void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
CHECK(configured_);
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictContribution(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr,
approximate);
}
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
uint32_t layer_begin, uint32_t layer_end, bool approximate) override {
void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
bst_layer_t layer_begin, bst_layer_t layer_end,
bool approximate) override {
CHECK(configured_);
auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_EQ(tree_begin, 0)
<< "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor()->PredictInteractionContributions(
p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: (0, "
"n_iteration), using model slicing instead.";
this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_,
tree_end, nullptr, approximate);
}
[[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
@ -390,8 +339,9 @@ class GBTree : public GradientBooster {
std::vector<HostDeviceVector<bst_node_t>>* out_position,
std::vector<std::unique_ptr<RegTree>>* ret);
std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
DMatrix* f_dmat = nullptr) const;
[[nodiscard]] std::unique_ptr<Predictor> const& GetPredictor(
bool is_training, HostDeviceVector<float> const* out_pred = nullptr,
DMatrix* f_dmat = nullptr) const;
// commit new trees all at once
virtual void CommitModel(TreesOneIter&& new_trees);
@ -410,9 +360,7 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<TreeUpdater>> updaters_;
// Predictors
std::unique_ptr<Predictor> cpu_predictor_;
#if defined(XGBOOST_USE_CUDA)
std::unique_ptr<Predictor> gpu_predictor_;
#endif // defined(XGBOOST_USE_CUDA)
std::unique_ptr<Predictor> gpu_predictor_{nullptr};
#if defined(XGBOOST_USE_ONEAPI)
std::unique_ptr<Predictor> oneapi_predictor_;
#endif // defined(XGBOOST_USE_ONEAPI)

View File

@ -40,6 +40,7 @@
#include "common/api_entry.h" // for XGBAPIThreadLocalEntry
#include "common/charconv.h" // for to_chars, to_chars_result, NumericLimits, from_...
#include "common/common.h" // for ToString, Split
#include "common/error_msg.h" // for MaxFeatureSize
#include "common/io.h" // for PeekableInStream, ReadAll, FixedSizeStream, Mem...
#include "common/observer.h" // for TrainingObserver
#include "common/random.h" // for GlobalRandom
@ -763,9 +764,7 @@ class LearnerConfiguration : public Learner {
CHECK(matrix.first.ptr);
CHECK(!matrix.second.ref.expired());
const uint64_t num_col = matrix.first.ptr->Info().num_col_;
CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
<< "Unfortunately, XGBoost does not support data matrices with "
<< std::numeric_limits<unsigned>::max() << " features or greater";
error::MaxFeatureSize(num_col);
num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
}
@ -1413,6 +1412,8 @@ class LearnerImpl : public LearnerIO {
this->CheckModelInitialized();
auto& out_predictions = this->GetThreadLocal().prediction_entry;
out_predictions.version = 0;
this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
if (type == PredictionType::kValue) {
obj_->PredTransform(&out_predictions.predictions);

View File

@ -577,8 +577,8 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
if (lj(0) >= Eps64()) {
tj_minus(i) = std::pow(lj(i) / lj(0), regularizer);
}
assert(!std::isinf(ti_plus(i)));
assert(!std::isinf(tj_minus(i)));
assert(!isinf(ti_plus(i)));
assert(!isinf(tj_minus(i)));
});
}
} // namespace cuda_impl

View File

@ -883,9 +883,8 @@ class CPUPredictor : public Predictor {
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
auto page = batch.GetView();
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
common::ParallelFor(batch.Size(), n_threads, [&](auto i) {
auto row_idx = batch.base_rowid + i;
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
if (feats.Size() == 0) {
feats.Init(num_feature);

View File

@ -226,9 +226,7 @@ struct GPUHistMakerDevice {
monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
}
~GPUHistMakerDevice() { // NOLINT
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
}
~GPUHistMakerDevice() = default;
void InitFeatureGroupsOnce() {
if (!feature_groups) {

View File

@ -25,6 +25,9 @@ class LintersPaths:
"tests/python/test_tree_regularization.py",
"tests/python/test_shap.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/test_gpu_prediction.py",
"tests/python-gpu/load_pickle.py",
"tests/python-gpu/test_gpu_pickling.py",
"tests/test_distributed/test_with_spark/",
"tests/test_distributed/test_gpu_with_spark/",
# demo
@ -68,6 +71,7 @@ class LintersPaths:
"tests/python/test_dt.py",
"tests/python/test_data_iterator.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/load_pickle.py",
"tests/test_distributed/test_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_spark/test_data.py",
"tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",

View File

@ -41,7 +41,6 @@ std::string GetModelStr() {
"num_class": "0",
"num_feature": "10",
"objective": "reg:linear",
"predictor": "gpu_predictor",
"tree_method": "gpu_hist",
"updater": "grow_gpu_hist"
},

View File

@ -1,17 +1,20 @@
/*!
* Copyright 2019-2022 XGBoost contributors
/**
* Copyright 2019-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/learner.h> // for Learner
#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <string> // for string
#include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy
#include "../../../src/gbm/gbtree.h"
#include "../filesystem.h" // dmlc::TemporaryDirectory
#include "../helpers.h"
#include "xgboost/base.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/predictor.h"
namespace xgboost {
@ -113,12 +116,11 @@ TEST(GBTree, WrongUpdater) {
#ifdef XGBOOST_USE_CUDA
TEST(GBTree, ChoosePredictor) {
// The test ensures data don't get pulled into device.
size_t constexpr kRows = 17;
size_t constexpr kCols = 15;
std::size_t constexpr kRows = 17, kCols = 15;
auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
auto& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
auto const& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
p_dmat->Info().labels.Reshape(kRows);
auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
@ -127,14 +129,13 @@ TEST(GBTree, ChoosePredictor) {
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
dmlc::TemporaryDirectory tempdir;
const std::string fname = tempdir.path + "/model_param.bst";
{
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
learner->Save(fo.get());
}
// a new learner
learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
{
@ -146,6 +147,8 @@ TEST(GBTree, ChoosePredictor) {
learner->UpdateOneIter(i, p_dmat);
}
ASSERT_TRUE(data.HostCanWrite());
ASSERT_FALSE(data.DeviceCanWrite());
ASSERT_FALSE(data.DeviceCanRead());
// pull data into device.
data.HostVector();
@ -232,14 +235,15 @@ TEST(Dart, JsonIO) {
namespace {
class Dart : public testing::TestWithParam<char const*> {
public:
void Run(std::string predictor) {
void Run(std::string device) {
size_t constexpr kRows = 16, kCols = 10;
HostDeviceVector<float> data;
auto rng = RandomDataGenerator(kRows, kCols, 0);
if (predictor == "gpu_predictor") {
rng.Device(0);
Context ctx;
if (device == "GPU") {
ctx = MakeCUDACtx(0);
}
auto rng = RandomDataGenerator(kRows, kCols, 0).Device(ctx.gpu_id);
auto array_str = rng.GenerateArrayInterface(&data);
auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols);
@ -258,14 +262,14 @@ class Dart : public testing::TestWithParam<char const*> {
learner->UpdateOneIter(i, p_mat);
}
learner->SetParam("predictor", predictor);
ConfigLearnerByCtx(&ctx, learner.get());
HostDeviceVector<float> predts_training;
learner->Predict(p_mat, false, &predts_training, 0, 0, true);
HostDeviceVector<float>* inplace_predts;
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
if (predictor == "gpu_predictor") {
if (ctx.IsCUDA()) {
x->SetCUDAArray(array_str.c_str());
} else {
x->SetArrayData(array_str.c_str());
@ -295,10 +299,9 @@ class Dart : public testing::TestWithParam<char const*> {
TEST_P(Dart, Prediction) { this->Run(GetParam()); }
#if defined(XGBOOST_USE_CUDA)
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart,
testing::Values("auto", "cpu_predictor", "gpu_predictor"));
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU", "GPU"));
#else
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("auto", "cpu_predictor"));
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU"));
#endif // defined(XGBOOST_USE_CUDA)

View File

@ -0,0 +1,88 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include <xgboost/context.h> // for Context
#include <xgboost/learner.h> // for Learner
#include <xgboost/string_view.h> // for StringView
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <string> // for string
#include "../../../src/data/adapter.h" // for ArrayAdapter
#include "../../../src/data/device_adapter.cuh" // for CupyAdapter
#include "../../../src/data/proxy_dmatrix.h" // for DMatrixProxy
#include "../helpers.h" // for RandomDataGenerator
namespace xgboost {
void TestInplaceFallback(Context const* ctx) {
// prepare data
bst_row_t n_samples{1024};
bst_feature_t n_features{32};
HostDeviceVector<float> X_storage;
// use a different device than the learner
std::int32_t data_ordinal = ctx->IsCPU() ? 0 : -1;
auto X = RandomDataGenerator{n_samples, n_features, 0.0}
.Device(data_ordinal)
.GenerateArrayInterface(&X_storage);
HostDeviceVector<float> y_storage;
auto y = RandomDataGenerator{n_samples, 1u, 0.0}.GenerateArrayInterface(&y_storage);
std::shared_ptr<DMatrix> Xy;
if (data_ordinal == Context::kCpuId) {
auto X_adapter = data::ArrayAdapter{StringView{X}};
Xy.reset(DMatrix::Create(&X_adapter, std::numeric_limits<float>::quiet_NaN(), ctx->Threads()));
} else {
auto X_adapter = data::CupyAdapter{StringView{X}};
Xy.reset(DMatrix::Create(&X_adapter, std::numeric_limits<float>::quiet_NaN(), ctx->Threads()));
}
Xy->SetInfo("label", y);
// learner is configured to the device specified by ctx
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
ConfigLearnerByCtx(ctx, learner.get());
for (std::int32_t i = 0; i < 3; ++i) {
learner->UpdateOneIter(i, Xy);
}
std::shared_ptr<DMatrix> p_m{new data::DMatrixProxy};
auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
if (data_ordinal == Context::kCpuId) {
proxy->SetArrayData(StringView{X});
} else {
proxy->SetCUDAArray(X.c_str());
}
HostDeviceVector<float>* out_predt{nullptr};
ConsoleLogger::Configure(Args{{"verbosity", "1"}});
// test whether the warning is raised
::testing::internal::CaptureStderr();
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&out_predt, 0, 0);
auto output = testing::internal::GetCapturedStderr();
std::cout << "output:" << output << std::endl;
ASSERT_NE(output.find("Falling back"), std::string::npos);
// test when the contexts match
Context new_ctx = *proxy->Ctx();
ASSERT_NE(new_ctx.gpu_id, ctx->gpu_id);
ConfigLearnerByCtx(&new_ctx, learner.get());
HostDeviceVector<float>* out_predt_1{nullptr};
// no warning is raised
::testing::internal::CaptureStderr();
learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&out_predt_1, 0, 0);
output = testing::internal::GetCapturedStderr();
ASSERT_TRUE(output.empty());
ASSERT_EQ(out_predt->ConstHostVector(), out_predt_1->ConstHostVector());
}
TEST(GBTree, InplacePredictFallback) {
auto ctx = MakeCUDACtx(0);
TestInplaceFallback(&ctx);
}
} // namespace xgboost

View File

@ -395,6 +395,9 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
for (auto const& page : out->GetBatches<SparsePage>()) {
page.data.SetDevice(device_);
page.offset.SetDevice(device_);
// pull to device
page.data.ConstDeviceSpan();
page.offset.ConstDeviceSpan();
}
}
if (!ft_.empty()) {

View File

@ -183,7 +183,7 @@ class SimpleRealUniformDistribution {
for (size_t k = m; k != 0; --k) {
sum_value += static_cast<ResultT>((*rng)() - rng->Min()) * r_k;
r_k *= r;
r_k *= static_cast<ResultT>(r);
}
ResultT res = sum_value / r_k;
@ -322,15 +322,14 @@ inline std::shared_ptr<DMatrix> EmptyDMatrix() {
return RandomDataGenerator{0, 0, 0.0}.GenerateDMatrix();
}
inline std::vector<float>
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
std::vector<float> x(n);
std::mt19937 rng(0);
std::uniform_int_distribution<size_t> dist(0, num_categories - 1);
std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
// Make sure each category is present
for(size_t i = 0; i < num_categories; i++) {
x[i] = i;
for (size_t i = 0; i < num_categories; i++) {
x[i] = static_cast<decltype(x)::value_type>(i);
}
return x;
}
@ -549,4 +548,15 @@ class DeclareUnifiedDistributedTest(MetricTest) : public ::testing::Test {
}
};
// A temporary solution before we move away from gpu_id.
inline void ConfigLearnerByCtx(Context const* ctx, Learner* learner) {
if (ctx->IsCPU()) {
learner->SetParam("tree_method", "hist");
} else {
learner->SetParam("tree_method", "gpu_hist");
}
learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
learner->Configure();
ASSERT_EQ(learner->Ctx()->gpu_id, ctx->gpu_id);
}
} // namespace xgboost

View File

@ -122,11 +122,13 @@ TEST(CpuPredictor, BasicColumnSplit) {
}
TEST(CpuPredictor, IterationRange) {
TestIterationRange("cpu_predictor");
Context ctx;
TestIterationRange(&ctx);
}
TEST(CpuPredictor, IterationRangeColmnSplit) {
TestIterationRangeColumnSplit("cpu_predictor");
Context ctx;
TestIterationRangeColumnSplit(&ctx);
}
TEST(CpuPredictor, ExternalMemory) {
@ -139,7 +141,8 @@ TEST(CpuPredictor, ExternalMemory) {
TEST(CpuPredictor, InplacePredict) {
bst_row_t constexpr kRows{128};
bst_feature_t constexpr kCols{64};
auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(-1);
Context ctx;
auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(ctx.gpu_id);
{
HostDeviceVector<float> data;
gen.GenerateDense(&data);
@ -149,7 +152,7 @@ TEST(CpuPredictor, InplacePredict) {
std::string arr_str;
Json::Dump(array_interface, &arr_str);
x->SetArrayData(arr_str.data());
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
TestInplacePrediction(&ctx, x, kRows, kCols);
}
{
@ -166,50 +169,50 @@ TEST(CpuPredictor, InplacePredict) {
Json::Dump(col_interface, &col_str);
std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy};
x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true);
TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
TestInplacePrediction(&ctx, x, kRows, kCols);
}
}
namespace {
void TestUpdatePredictionCache(bool use_subsampling) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
std::size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam{MakeMP(kCols, .0, kClasses)};
Context ctx;
std::unique_ptr<gbm::GBTree> gbm;
gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &ctx, &mparam)));
std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor";
Args args{{"tree_method", "hist"}};
if (use_subsampling) {
cfg["subsample"] = "0.5";
args.emplace_back("subsample", "0.5");
}
Args args = {cfg.cbegin(), cfg.cend()};
gbm->Configure(args);
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
HostDeviceVector<GradientPair> gpair;
auto& h_gpair = gpair.HostVector();
h_gpair.resize(kRows*kClasses);
for (size_t i = 0; i < kRows*kClasses; ++i) {
h_gpair.resize(kRows * kClasses);
for (size_t i = 0; i < kRows * kClasses; ++i) {
h_gpair[i] = {static_cast<float>(i), 1};
}
PredictionCacheEntry predtion_cache;
predtion_cache.predictions.Resize(kRows*kClasses, 0);
// after one training iteration predtion_cache is filled with cached in QuantileHistMaker::Builder prediction values
predtion_cache.predictions.Resize(kRows * kClasses, 0);
// after one training iteration predtion_cache is filled with cached in QuantileHistMaker
// prediction values
gbm->DoBoost(dmat.get(), &gpair, &predtion_cache, nullptr);
PredictionCacheEntry out_predictions;
// perform fair prediction on the same input data, should be equal to cached result
// perform prediction from scratch on the same input data, should be equal to cached result
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0);
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
std::vector<float>& predtion_cache_from_train = predtion_cache.predictions.HostVector();
for (size_t i = 0; i < out_predictions_h.size(); ++i) {
ASSERT_NEAR(out_predictions_h[i], predtion_cache_from_train[i], kRtEps);
}
}
} // namespace
TEST(CPUPredictor, GHistIndex) {
size_t constexpr kRows{128}, kCols{16}, kBins{64};
@ -223,19 +226,23 @@ TEST(CPUPredictor, GHistIndex) {
}
TEST(CPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("cpu_predictor");
Context ctx;
TestCategoricalPrediction(&ctx, false);
}
TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
TestCategoricalPredictionColumnSplit("cpu_predictor");
Context ctx;
TestCategoricalPredictionColumnSplit(&ctx);
}
TEST(CPUPredictor, CategoricalPredictLeaf) {
TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
Context ctx;
TestCategoricalPredictLeaf(&ctx, false);
}
TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
TestCategoricalPredictLeafColumnSplit(StringView{"cpu_predictor"});
Context ctx;
TestCategoricalPredictLeafColumnSplit(&ctx);
}
TEST(CpuPredictor, UpdatePredictionCache) {
@ -244,21 +251,25 @@ TEST(CpuPredictor, UpdatePredictionCache) {
}
TEST(CpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("cpu_predictor");
Context ctx;
TestPredictionWithLesserFeatures(&ctx);
}
TEST(CpuPredictor, LesserFeaturesColumnSplit) {
TestPredictionWithLesserFeaturesColumnSplit("cpu_predictor");
Context ctx;
TestPredictionWithLesserFeaturesColumnSplit(&ctx);
}
TEST(CpuPredictor, Sparse) {
TestSparsePrediction(0.2, "cpu_predictor");
TestSparsePrediction(0.8, "cpu_predictor");
Context ctx;
TestSparsePrediction(&ctx, 0.2);
TestSparsePrediction(&ctx, 0.8);
}
TEST(CpuPredictor, SparseColumnSplit) {
TestSparsePredictionColumnSplit(0.2, "cpu_predictor");
TestSparsePredictionColumnSplit(0.8, "cpu_predictor");
Context ctx;
TestSparsePredictionColumnSplit(&ctx, 0.2);
TestSparsePredictionColumnSplit(&ctx, 0.8);
}
TEST(CpuPredictor, Multi) {
@ -266,4 +277,6 @@ TEST(CpuPredictor, Multi) {
ctx.nthread = 1;
TestVectorLeafPrediction(&ctx);
}
TEST(CpuPredictor, Access) { TestPredictionDeviceAccess(); }
} // namespace xgboost

View File

@ -15,8 +15,7 @@
#include "../helpers.h"
#include "test_predictor.h"
namespace xgboost {
namespace predictor {
namespace xgboost::predictor {
TEST(GPUPredictor, Basic) {
auto cpu_lparam = MakeCUDACtx(-1);
@ -120,13 +119,14 @@ TEST(GPUPredictor, MGPUBasicColumnSplit) {
}
TEST(GPUPredictor, EllpackBasic) {
size_t constexpr kCols {8};
size_t constexpr kCols{8};
auto ctx = MakeCUDACtx(0);
for (size_t bins = 2; bins < 258; bins += 16) {
size_t rows = bins * 16;
auto p_m = RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix();
ASSERT_FALSE(p_m->PageExists<SparsePage>());
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
}
}
@ -181,29 +181,32 @@ TEST(GPUPredictor, ExternalMemoryTest) {
}
TEST(GPUPredictor, InplacePredictCupy) {
auto ctx = MakeCUDACtx(0);
size_t constexpr kRows{128}, kCols{64};
RandomDataGenerator gen(kRows, kCols, 0.5);
gen.Device(0);
gen.Device(ctx.gpu_id);
HostDeviceVector<float> data;
std::string interface_str = gen.GenerateArrayInterface(&data);
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
TestInplacePrediction(&ctx, p_fmat, kRows, kCols);
}
TEST(GPUPredictor, InplacePredictCuDF) {
auto ctx = MakeCUDACtx(0);
size_t constexpr kRows{128}, kCols{64};
RandomDataGenerator gen(kRows, kCols, 0.5);
gen.Device(0);
gen.Device(ctx.gpu_id);
std::vector<HostDeviceVector<float>> storage(kCols);
auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
TestInplacePrediction(&ctx, p_fmat, kRows, kCols);
}
TEST(GpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("gpu_predictor");
auto ctx = MakeCUDACtx(0);
TestPredictionWithLesserFeatures(&ctx);
}
// Very basic test of empty model
@ -268,15 +271,18 @@ TEST(GPUPredictor, Shap) {
}
TEST(GPUPredictor, IterationRange) {
TestIterationRange("gpu_predictor");
auto ctx = MakeCUDACtx(0);
TestIterationRange(&ctx);
}
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
auto ctx = MakeCUDACtx(0);
TestCategoricalPrediction(&ctx, false);
}
TEST(GPUPredictor, CategoricalPredictLeaf) {
TestCategoricalPredictLeaf(StringView{"gpu_predictor"});
auto ctx = MakeCUDACtx(0);
TestCategoricalPredictLeaf(&ctx, false);
}
TEST(GPUPredictor, PredictLeafBasic) {
@ -300,8 +306,8 @@ TEST(GPUPredictor, PredictLeafBasic) {
}
TEST(GPUPredictor, Sparse) {
TestSparsePrediction(0.2, "gpu_predictor");
TestSparsePrediction(0.8, "gpu_predictor");
auto ctx = MakeCUDACtx(0);
TestSparsePrediction(&ctx, 0.2);
TestSparsePrediction(&ctx, 0.8);
}
} // namespace predictor
} // namespace xgboost
} // namespace xgboost::predictor

View File

@ -8,9 +8,11 @@
#include <xgboost/data.h> // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/predictor.h> // for PredictionCacheEntry, Predictor, Predic...
#include <xgboost/string_view.h> // for StringView
#include <algorithm> // for max
#include <limits> // for numeric_limits
#include <memory> // for shared_ptr
#include <unordered_map> // for unordered_map
#include "../../../src/common/bitfield.h" // for LBitField32
@ -51,7 +53,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
size_t constexpr kIters = 3;
std::unique_ptr<Learner> learner;
auto train = [&](std::string predictor) {
auto train = [&](Context const& ctx) {
p_hist->Info().labels.Reshape(rows, 1);
auto &h_label = p_hist->Info().labels.Data()->HostVector();
@ -65,7 +67,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
learner->SetParam("num_feature", std::to_string(kCols));
learner->SetParam("num_class", std::to_string(kClasses));
learner->SetParam("max_bin", std::to_string(bins));
learner->SetParam("predictor", predictor);
ConfigLearnerByCtx(&ctx, learner.get());
learner->Configure();
for (size_t i = 0; i < kIters; ++i) {
@ -77,7 +79,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
learner.reset(Learner::Create({}));
learner->LoadModel(model);
learner->SetParam("predictor", predictor);
ConfigLearnerByCtx(&ctx, learner.get());
learner->Configure();
HostDeviceVector<float> from_full;
@ -93,16 +95,16 @@ void TestTrainingPrediction(size_t rows, size_t bins,
};
if (tree_method == "gpu_hist") {
train("gpu_predictor");
train(MakeCUDACtx(0));
} else {
train("cpu_predictor");
train(Context{});
}
}
void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
bst_feature_t cols, int32_t device) {
size_t constexpr kClasses { 4 };
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
bst_feature_t cols) {
std::size_t constexpr kClasses { 4 };
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->gpu_id);
std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);
std::unique_ptr<Learner> learner {
@ -113,12 +115,14 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
learner->SetParam("num_class", std::to_string(kClasses));
learner->SetParam("seed", "0");
learner->SetParam("subsample", "0.5");
learner->SetParam("gpu_id", std::to_string(device));
learner->SetParam("predictor", predictor);
learner->SetParam("tree_method", "hist");
for (int32_t it = 0; it < 4; ++it) {
learner->UpdateOneIter(it, m);
}
learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
learner->Configure();
HostDeviceVector<float> *p_out_predictions_0{nullptr};
learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
&p_out_predictions_0, 0, 2);
@ -154,40 +158,79 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
}
namespace {
std::unique_ptr<Learner> LearnerForTest(std::shared_ptr<DMatrix> dmat, size_t iters,
size_t forest = 1) {
std::unique_ptr<Learner> LearnerForTest(Context const *ctx, std::shared_ptr<DMatrix> dmat,
size_t iters, size_t forest = 1) {
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
learner->SetParams(Args{{"num_parallel_tree", std::to_string(forest)}});
for (size_t i = 0; i < iters; ++i) {
learner->UpdateOneIter(i, dmat);
}
ConfigLearnerByCtx(ctx, learner.get());
return learner;
}
void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &predictor_name,
size_t rows, std::shared_ptr<DMatrix> const &m_test,
std::shared_ptr<DMatrix> const &m_invalid) {
void VerifyPredictionWithLesserFeatures(Learner *learner, bst_row_t kRows,
std::shared_ptr<DMatrix> m_test,
std::shared_ptr<DMatrix> m_invalid) {
HostDeviceVector<float> prediction;
learner->SetParam("predictor", predictor_name);
learner->Configure();
Json config{Object()};
learner->SaveConfig(&config);
ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]),
predictor_name);
learner->Predict(m_test, false, &prediction, 0, 0);
ASSERT_EQ(prediction.Size(), rows);
ASSERT_EQ(prediction.Size(), kRows);
ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
}
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner, size_t rows,
std::shared_ptr<DMatrix> m_test,
std::shared_ptr<DMatrix> m_invalid) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
VerifyPredictionWithLesserFeatures(learner, rows, sliced_test, sliced_invalid);
}
} // anonymous namespace
void TestPredictionWithLesserFeatures(Context const *ctx) {
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
auto learner = LearnerForTest(ctx, m_train, kIters);
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
VerifyPredictionWithLesserFeatures(learner.get(), kRows, m_test, m_invalid);
}
void TestPredictionDeviceAccess() {
Context ctx;
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
auto learner = LearnerForTest(&ctx, m_train, kIters);
HostDeviceVector<float> from_cpu;
{
ASSERT_EQ(from_cpu.DeviceIdx(), Context::kCpuId);
Context cpu_ctx;
ConfigLearnerByCtx(&cpu_ctx, learner.get());
learner->Predict(m_test, false, &from_cpu, 0, 0);
ASSERT_TRUE(from_cpu.HostCanWrite());
ASSERT_FALSE(from_cpu.DeviceCanRead());
}
#if defined(XGBOOST_USE_CUDA)
HostDeviceVector<float> from_cpu;
learner->SetParam("predictor", "cpu_predictor");
learner->Predict(m_test, false, &from_cpu, 0, 0);
HostDeviceVector<float> from_cuda;
learner->SetParam("predictor", "gpu_predictor");
learner->Predict(m_test, false, &from_cuda, 0, 0);
{
Context cuda_ctx = MakeCUDACtx(0);
ConfigLearnerByCtx(&cuda_ctx, learner.get());
learner->Predict(m_test, false, &from_cuda, 0, 0);
ASSERT_EQ(from_cuda.DeviceIdx(), 0);
ASSERT_TRUE(from_cuda.DeviceCanWrite());
ASSERT_FALSE(from_cuda.HostCanRead());
}
auto const &h_cpu = from_cpu.ConstHostVector();
auto const &h_gpu = from_cuda.ConstHostVector();
@ -196,41 +239,17 @@ void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &pre
}
#endif // defined(XGBOOST_USE_CUDA)
}
} // anonymous namespace
void TestPredictionWithLesserFeatures(std::string predictor_name) {
void TestPredictionWithLesserFeaturesColumnSplit(Context const *ctx) {
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
auto learner = LearnerForTest(m_train, kIters);
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
VerifyPredictionWithLesserFeatures(learner.get(), predictor_name, kRows, m_test, m_invalid);
}
namespace {
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner,
std::string const &predictor_name, size_t rows,
std::shared_ptr<DMatrix> m_test,
std::shared_ptr<DMatrix> m_invalid) {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
VerifyPredictionWithLesserFeatures(learner, predictor_name, rows, sliced_test, sliced_invalid);
}
} // anonymous namespace
void TestPredictionWithLesserFeaturesColumnSplit(std::string predictor_name) {
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
auto learner = LearnerForTest(m_train, kIters);
auto learner = LearnerForTest(ctx, m_train, kIters);
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
learner.get(), predictor_name, kRows, m_test, m_invalid);
learner.get(), kRows, m_test, m_invalid);
}
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
@ -252,7 +271,7 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
model->CommitModelGroup(std::move(trees), 0);
}
void TestCategoricalPrediction(std::string name, bool is_column_split) {
void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;
@ -262,13 +281,10 @@ void TestCategoricalPrediction(std::string name, bool is_column_split) {
float left_weight = 1.3f;
float right_weight = 1.7f;
Context ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&mparam, &ctx);
gbm::GBTreeModel model(&mparam, ctx);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};
std::vector<float> row(kCols);
row[split_ind] = split_cat;
@ -298,12 +314,12 @@ void TestCategoricalPrediction(std::string name, bool is_column_split) {
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
}
void TestCategoricalPredictionColumnSplit(std::string name) {
void TestCategoricalPredictionColumnSplit(Context const *ctx) {
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, name, true);
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, ctx, true);
}
void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;
@ -314,14 +330,10 @@ void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
float left_weight = 1.3f;
float right_weight = 1.7f;
Context ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model(&mparam, &ctx);
gbm::GBTreeModel model(&mparam, ctx);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
ctx.gpu_id = 0;
std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};
std::vector<float> row(kCols);
row[split_ind] = split_cat;
@ -346,19 +358,21 @@ void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}
void TestCategoricalPredictLeafColumnSplit(StringView name) {
void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, name, true);
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, ctx, true);
}
void TestIterationRange(std::string name) {
void TestIterationRange(Context const* ctx) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
auto learner = LearnerForTest(dmat, kIters, kForest);
learner->SetParams(Args{{"predictor", name}});
auto dmat = RandomDataGenerator(kRows, kCols, 0)
.Device(ctx->gpu_id)
.GenerateDMatrix(true, true, kClasses);
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
bool bound = false;
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
bst_layer_t lend{3};
std::unique_ptr<Learner> sliced{learner->Slice(0, lend, 1, &bound)};
ASSERT_FALSE(bound);
HostDeviceVector<float> out_predt_sliced;
@ -366,11 +380,8 @@ void TestIterationRange(std::string name) {
// margin
{
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
false, false);
learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
false, false);
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
learner->Predict(dmat, true, &out_predt_ranged, 0, lend, false, false, false, false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
@ -380,11 +391,8 @@ void TestIterationRange(std::string name) {
// SHAP
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
true, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
false, false);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, true, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, true, false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
@ -394,10 +402,8 @@ void TestIterationRange(std::string name) {
// SHAP interaction
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
false, true);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, false, false, true);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
@ -406,10 +412,8 @@ void TestIterationRange(std::string name) {
// Leaf
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
false, false);
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, true, false, false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
@ -456,11 +460,16 @@ void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *s
}
} // anonymous namespace
void TestIterationRangeColumnSplit(std::string name) {
void TestIterationRangeColumnSplit(Context const* ctx) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
auto learner = LearnerForTest(dmat, kIters, kForest);
learner->SetParams(Args{{"predictor", name}});
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
if (ctx->IsCPU()) {
learner->SetParams(Args{{"gpu_id", std::to_string(-1)}});
} else {
learner->SetParams(Args{{"gpu_id", std::to_string(0)}});
}
bool bound = false;
std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};
@ -488,10 +497,10 @@ void TestIterationRangeColumnSplit(std::string name) {
leaf_ranged, leaf_sliced);
}
void TestSparsePrediction(float sparsity, std::string predictor) {
void TestSparsePrediction(Context const *ctx, float sparsity) {
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
auto learner = LearnerForTest(Xy, kIters);
auto learner = LearnerForTest(ctx, Xy, kIters);
HostDeviceVector<float> sparse_predt;
@ -501,11 +510,14 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
learner.reset(Learner::Create({Xy}));
learner->LoadModel(model);
learner->SetParam("predictor", predictor);
if (ctx->IsCUDA()) {
learner->SetParam("tree_method", "gpu_hist");
learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
}
learner->Predict(Xy, false, &sparse_predt, 0, 0);
HostDeviceVector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
auto& h_with_nan = with_nan.HostVector();
auto &h_with_nan = with_nan.HostVector();
for (auto const &page : Xy->GetBatches<SparsePage>()) {
auto batch = page.GetView();
for (size_t i = 0; i < batch.Size(); ++i) {
@ -516,7 +528,8 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
}
}
learner->SetParam("predictor", "cpu_predictor");
learner->SetParam("tree_method", "hist");
learner->SetParam("gpu_id", "-1");
// Xcode_12.4 doesn't compile with `std::make_shared`.
auto dense = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
auto array_interface = GetArrayInterface(&with_nan, kRows, kCols);
@ -527,8 +540,8 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
learner->InplacePredict(dense, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
&p_dense_predt, 0, 0);
auto const& dense_predt = *p_dense_predt;
if (predictor == "cpu_predictor") {
auto const &dense_predt = *p_dense_predt;
if (ctx->IsCPU()) {
ASSERT_EQ(dense_predt.HostVector(), sparse_predt.HostVector());
} else {
auto const &h_dense = dense_predt.HostVector();
@ -556,10 +569,10 @@ void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
}
} // anonymous namespace
void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity) {
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
auto learner = LearnerForTest(Xy, kIters);
auto learner = LearnerForTest(ctx, Xy, kIters);
HostDeviceVector<float> sparse_predt;
@ -569,7 +582,7 @@ void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
learner.reset(Learner::Create({Xy}));
learner->LoadModel(model);
learner->SetParam("predictor", predictor);
ConfigLearnerByCtx(ctx, learner.get());
learner->Predict(Xy, false, &sparse_predt, 0, 0);
auto constexpr kWorldSize = 2;

View File

@ -31,8 +31,17 @@ inline gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context
return model;
}
inline auto CreatePredictorForTest(Context const* ctx) {
if (ctx->IsCPU()) {
return Predictor::Create("cpu_predictor", ctx);
} else {
return Predictor::Create("gpu_predictor", ctx);
}
}
// fixme: cpu test
template <typename Page>
void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
void TestPredictionFromGradientIndex(Context const* ctx, size_t rows, size_t cols,
std::shared_ptr<DMatrix> p_hist) {
constexpr size_t kClasses { 3 };
@ -40,12 +49,10 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
auto cuda_ctx = MakeCUDACtx(0);
std::unique_ptr<Predictor> predictor =
std::unique_ptr<Predictor>(Predictor::Create(name, &cuda_ctx));
std::unique_ptr<Predictor>(CreatePredictorForTest(&cuda_ctx));
predictor->Configure({});
Context ctx;
ctx.UpdateAllowUnknown(Args{});
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, kClasses);
gbm::GBTreeModel model = CreateTestModel(&mparam, ctx, kClasses);
{
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
@ -81,28 +88,30 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method,
std::shared_ptr<DMatrix> p_full,
std::shared_ptr<DMatrix> p_hist);
void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
bst_feature_t cols, int32_t device = -1);
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
bst_feature_t cols);
void TestPredictionWithLesserFeatures(std::string preditor_name);
void TestPredictionWithLesserFeatures(Context const* ctx);
void TestPredictionWithLesserFeaturesColumnSplit(std::string preditor_name);
void TestPredictionDeviceAccess();
void TestCategoricalPrediction(std::string name, bool is_column_split = false);
void TestCategoricalPrediction(Context const* ctx, bool is_column_split);
void TestCategoricalPredictionColumnSplit(std::string name);
void TestCategoricalPredictionColumnSplit(Context const* ctx);
void TestCategoricalPredictLeaf(StringView name, bool is_column_split = false);
void TestPredictionWithLesserFeaturesColumnSplit(Context const* ctx);
void TestCategoricalPredictLeafColumnSplit(StringView name);
void TestCategoricalPredictLeaf(Context const* ctx, bool is_column_split);
void TestIterationRange(std::string name);
void TestCategoricalPredictLeafColumnSplit(Context const* ctx);
void TestIterationRangeColumnSplit(std::string name);
void TestIterationRange(Context const* ctx);
void TestSparsePrediction(float sparsity, std::string predictor);
void TestIterationRangeColumnSplit(Context const* ctx);
void TestSparsePredictionColumnSplit(float sparsity, std::string predictor);
void TestSparsePrediction(Context const* ctx, float sparsity);
void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity);
void TestVectorLeafPrediction(Context const* ctx);
} // namespace xgboost

View File

@ -342,16 +342,6 @@ TEST(Learner, GPUConfiguration) {
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->Ctx()->gpu_id, 0);
}
{
// With CPU algorithm but GPU Predictor, this is to simulate when
// XGBoost is only used for prediction, so tree method is not
// specified.
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"},
Arg{"predictor", "gpu_predictor"}});
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->Ctx()->gpu_id, 0);
}
}
#endif // defined(XGBOOST_USE_CUDA)

View File

@ -698,10 +698,6 @@ TEST_F(MultiClassesSerializationTest, GpuHist) {
{"seed", "0"},
{"nthread", "1"},
{"max_depth", std::to_string(kClasses)},
// Somehow rebuilding the cache can generate slightly
// different result (1e-7) with CPU predictor for some
// entries.
{"predictor", "gpu_predictor"},
// Mitigate the difference caused by hardware fused multiply
// add to tree weight during update prediction cache.
{"learning_rate", "1.0"},

View File

@ -1,5 +1,5 @@
'''Loading a pickled model generated by test_pickling.py, only used by
`test_gpu_with_dask.py`'''
"""Loading a pickled model generated by test_pickling.py, only used by
`test_gpu_with_dask.py`"""
import json
import os
@ -12,9 +12,9 @@ from xgboost import testing as tm
class TestLoadPickle:
def test_load_pkl(self):
'''Test whether prediction is correct.'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_load_pkl(self) -> None:
"""Test whether prediction is correct."""
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
bst = load_pickle(model_path)
x, y = build_dataset()
if isinstance(bst, xgb.Booster):
@ -28,46 +28,42 @@ class TestLoadPickle:
assert len(res) == 10
def test_predictor_type_is_auto(self):
'''Under invalid CUDA_VISIBLE_DEVICES, predictor should be set to
auto'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_context_is_removed(self) -> None:
"""Under invalid CUDA_VISIBLE_DEVICES, context should reset"""
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'auto'
assert config["learner"]["generic_param"]["gpu_id"] == "-1"
def test_predictor_type_is_gpu(self):
'''When CUDA_VISIBLE_DEVICES is not specified, keep using
`gpu_predictor`'''
assert 'CUDA_VISIBLE_DEVICES' not in os.environ.keys()
def test_context_is_preserved(self) -> None:
"""Test the device context is preserved after pickling."""
assert "CUDA_VISIBLE_DEVICES" not in os.environ.keys()
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
assert config["learner"]["generic_param"]["gpu_id"] == "0"
def test_wrap_gpu_id(self):
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
def test_wrap_gpu_id(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['generic_param']['gpu_id'] == '0'
assert config["learner"]["generic_param"]["gpu_id"] == "0"
x, y = build_dataset()
test_x = xgb.DMatrix(x)
res = bst.predict(test_x)
assert len(res) == 10
def test_training_on_cpu_only_env(self):
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_training_on_cpu_only_env(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
rng = np.random.RandomState(1994)
X = rng.randn(10, 10)
y = rng.randn(10)
with tm.captured_output() as (out, err):
# Test no thrust exception is thrown
with pytest.raises(xgb.core.XGBoostError):
xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(X, y))
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))
assert out.getvalue().find('No visible GPU is found') != -1
assert out.getvalue().find("No visible GPU is found") != -1

View File

@ -203,7 +203,7 @@ class TestQuantileDMatrix:
np.testing.assert_equal(h_ret.indices, d_ret.indices)
booster = xgb.train(
{"tree_method": "gpu_hist", "predictor": "gpu_predictor"}, dtrain=d_m
{"tree_method": "gpu_hist", "gpu_id": "0"}, dtrain=d_m
)
np.testing.assert_allclose(

View File

@ -221,9 +221,10 @@ Arrow specification.'''
def test_specified_device(self):
import cupy as cp
cp.cuda.runtime.setDevice(0)
dtrain = dmatrix_from_cupy(
np.float32, xgb.QuantileDMatrix, np.nan)
with pytest.raises(xgb.core.XGBoostError):
dtrain = dmatrix_from_cupy(np.float32, xgb.QuantileDMatrix, np.nan)
with pytest.raises(
xgb.core.XGBoostError, match="Data is resided on a different device"
):
xgb.train(
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
)

View File

@ -1,5 +1,4 @@
'''Test model IO with pickle.'''
import json
"""Test model IO with pickle."""
import os
import pickle
import subprocess
@ -11,49 +10,48 @@ import xgboost as xgb
from xgboost import XGBClassifier
from xgboost import testing as tm
model_path = './model.pkl'
model_path = "./model.pkl"
pytestmark = tm.timeout(30)
def build_dataset():
N = 10
x = np.linspace(0, N*N, N*N)
x = np.linspace(0, N * N, N * N)
x = x.reshape((N, N))
y = np.linspace(0, N, N)
return x, y
def save_pickle(bst, path):
with open(path, 'wb') as fd:
with open(path, "wb") as fd:
pickle.dump(bst, fd)
def load_pickle(path):
with open(path, 'rb') as fd:
with open(path, "rb") as fd:
bst = pickle.load(fd)
return bst
class TestPickling:
args_template = [
"pytest",
"--verbose",
"-s",
"--fulltrace"]
args_template = ["pytest", "--verbose", "-s", "--fulltrace"]
def run_pickling(self, bst) -> None:
save_pickle(bst, model_path)
args = [
"pytest", "--verbose", "-s", "--fulltrace",
"./tests/python-gpu/load_pickle.py::TestLoadPickle::test_load_pkl"
"pytest",
"--verbose",
"-s",
"--fulltrace",
"./tests/python-gpu/load_pickle.py::TestLoadPickle::test_load_pkl",
]
command = ''
command = ""
for arg in args:
command += arg
command += ' '
command += " "
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
env = os.environ.copy()
# Passing new_environment directly to `env' argument results
# in failure on Windows:
@ -72,7 +70,7 @@ class TestPickling:
x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y)
param = {'tree_method': 'gpu_hist', "gpu_id": 0}
param = {"tree_method": "gpu_hist", "gpu_id": 0}
bst = xgb.train(param, train_x)
self.run_pickling(bst)
@ -91,43 +89,46 @@ class TestPickling:
X, y = build_dataset()
dtrain = xgb.DMatrix(X, y)
bst = xgb.train({'tree_method': 'gpu_hist',
'gpu_id': 1},
dtrain, num_boost_round=6)
bst = xgb.train(
{"tree_method": "gpu_hist", "gpu_id": 1}, dtrain, num_boost_round=6
)
model_path = 'model.pkl'
model_path = "model.pkl"
save_pickle(bst, model_path)
cuda_environment = {'CUDA_VISIBLE_DEVICES': '0'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "0"}
env = os.environ.copy()
env.update(cuda_environment)
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
"./tests/python-gpu/" "load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
)
status = subprocess.call(args, env=env)
assert status == 0
os.remove(model_path)
def test_pickled_predictor(self):
x, y = build_dataset()
def test_pickled_context(self):
x, y = tm.make_sparse_regression(10, 10, sparsity=0.8, as_dense=True)
train_x = xgb.DMatrix(x, label=y)
param = {'tree_method': 'gpu_hist',
'verbosity': 1, 'predictor': 'gpu_predictor'}
param = {"tree_method": "gpu_hist", "verbosity": 1}
bst = xgb.train(param, train_x)
config = json.loads(bst.save_config())
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
with tm.captured_output() as (out, err):
bst.inplace_predict(x)
# The warning is redirected to Python callback, so it's printed in stdout
# instead of stderr.
stdout = out.getvalue()
assert stdout.find("mismatched devices") != -1
save_pickle(bst, model_path)
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
root = tm.project_root(__file__)
path = os.path.join(root, "tests", "python-gpu", "load_pickle.py")
args.append(path + "::TestLoadPickle::test_context_is_removed")
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
env = os.environ.copy()
env.update(cuda_environment)
@ -138,25 +139,29 @@ class TestPickling:
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
"load_pickle.py::TestLoadPickle::test_context_is_preserved"
)
# Load in environment that has GPU.
env = os.environ.copy()
assert 'CUDA_VISIBLE_DEVICES' not in env.keys()
assert "CUDA_VISIBLE_DEVICES" not in env.keys()
status = subprocess.call(args, env=env)
assert status == 0
os.remove(model_path)
@pytest.mark.skipif(**tm.no_sklearn())
def test_predict_sklearn_pickle(self):
def test_predict_sklearn_pickle(self) -> None:
from sklearn.datasets import load_digits
x, y = load_digits(return_X_y=True)
kwargs = {'tree_method': 'gpu_hist',
'predictor': 'gpu_predictor',
'objective': 'binary:logistic',
'n_estimators': 10}
kwargs = {
"tree_method": "gpu_hist",
"objective": "binary:logistic",
"gpu_id": 0,
"n_estimators": 10,
}
model = XGBClassifier(**kwargs)
model.fit(x, y)
@ -165,24 +170,25 @@ class TestPickling:
del model
# load model
model: xgb.XGBClassifier = load_pickle("model.pkl")
model = load_pickle("model.pkl")
os.remove("model.pkl")
gpu_pred = model.predict(x, output_margin=True)
# Switch to CPU predictor
bst = model.get_booster()
bst.set_param({'predictor': 'cpu_predictor'})
tm.set_ordinal(-1, bst)
cpu_pred = model.predict(x, output_margin=True)
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
def test_training_on_cpu_only_env(self):
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
env = os.environ.copy()
env.update(cuda_environment)
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env")
"load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env"
)
status = subprocess.call(args, env=env)
assert status == 0

View File

@ -1,4 +1,5 @@
import sys
from copy import copy
import numpy as np
import pytest
@ -11,8 +12,10 @@ from xgboost.compat import PANDAS_INSTALLED
if PANDAS_INSTALLED:
from hypothesis.extra.pandas import column, data_frames, range_indexes
else:
def noop(*args, **kwargs):
pass
column, data_frames, range_indexes = noop, noop, noop
sys.path.append("tests/python")
@ -21,16 +24,20 @@ from test_predict import run_threaded_predict # noqa
rng = np.random.RandomState(1994)
shap_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(1, 11),
'max_leaves': strategies.integers(0, 256),
'num_parallel_tree': strategies.sampled_from([1, 10]),
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
shap_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 11),
"max_leaves": strategies.integers(0, 256),
"num_parallel_tree": strategies.sampled_from([1, 10]),
}
).filter(lambda x: x["max_depth"] > 0 or x["max_leaves"] > 0)
predict_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(1, 8),
'num_parallel_tree': strategies.sampled_from([1, 4]),
})
predict_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 8),
"num_parallel_tree": strategies.sampled_from([1, 4]),
}
)
pytestmark = tm.timeout(20)
@ -47,43 +54,45 @@ class TestGPUPredict:
# with 5000 rows is 0.04.
for num_rows in test_num_rows:
for num_cols in test_num_cols:
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
watchlist = [(dtrain, 'train'), (dval, 'validation')]
dtrain = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
dval = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
dtest = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
watchlist = [(dtrain, "train"), (dval, "validation")]
res = {}
param = {
"objective": "binary:logistic",
"predictor": "gpu_predictor",
'eval_metric': 'logloss',
'tree_method': 'gpu_hist',
'max_depth': 1
"eval_metric": "logloss",
"tree_method": "gpu_hist",
"gpu_id": 0,
"max_depth": 1,
}
bst = xgb.train(param, dtrain, iterations, evals=watchlist,
evals_result=res)
assert self.non_increasing(res["train"]["logloss"])
bst = xgb.train(
param, dtrain, iterations, evals=watchlist, evals_result=res
)
assert tm.non_increasing(res["train"]["logloss"], tolerance=0.001)
gpu_pred_train = bst.predict(dtrain, output_margin=True)
gpu_pred_test = bst.predict(dtest, output_margin=True)
gpu_pred_val = bst.predict(dval, output_margin=True)
param["predictor"] = "cpu_predictor"
bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
bst.set_param({"gpu_id": -1, "tree_method": "hist"})
bst_cpu = copy(bst)
cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
cpu_pred_val = bst_cpu.predict(dval, output_margin=True)
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train,
rtol=1e-6)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val,
rtol=1e-6)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test,
rtol=1e-6)
def non_increasing(self, L):
return all((y - x) < 0.001 for x, y in zip(L, L[1:]))
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-6)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-6)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-6)
# Test case for a bug where multiple batch predictions made on a
# test set produce incorrect results
@ -94,26 +103,22 @@ class TestGPUPredict:
n = 1000
X, y = make_regression(n, random_state=rng)
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=123)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)
params = {}
params["tree_method"] = "gpu_hist"
bst = xgb.train(params, dtrain)
params['predictor'] = "gpu_predictor"
bst_gpu_predict = xgb.train(params, dtrain)
tm.set_ordinal(0, bst)
# Don't reuse the DMatrix for prediction, otherwise the result is cached.
predict_gpu_0 = bst.predict(xgb.DMatrix(X_test))
predict_gpu_1 = bst.predict(xgb.DMatrix(X_test))
tm.set_ordinal(-1, bst)
predict_cpu = bst.predict(xgb.DMatrix(X_test))
params['predictor'] = "cpu_predictor"
bst_cpu_predict = xgb.train(params, dtrain)
predict0 = bst_gpu_predict.predict(dtest)
predict1 = bst_gpu_predict.predict(dtest)
cpu_predict = bst_cpu_predict.predict(dtest)
assert np.allclose(predict0, predict1)
assert np.allclose(predict0, cpu_predict)
assert np.allclose(predict_gpu_0, predict_gpu_1)
assert np.allclose(predict_gpu_0, predict_cpu)
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn(self):
@ -121,30 +126,31 @@ class TestGPUPredict:
tr_size = 2500
X = np.random.rand(m, n)
y = 200 * np.matmul(X, np.arange(-3, -3 + n))
y = y.reshape(y.size)
X_train, y_train = X[:tr_size, :], y[:tr_size]
X_test, y_test = X[tr_size:, :], y[tr_size:]
# First with cpu_predictor
params = {'tree_method': 'gpu_hist',
'predictor': 'cpu_predictor',
'n_jobs': -1,
'seed': 123}
m = xgb.XGBRegressor(**params).fit(X_train, y_train)
cpu_train_score = m.score(X_train, y_train)
cpu_test_score = m.score(X_test, y_test)
# Now with gpu_predictor
params['predictor'] = 'gpu_predictor'
params = {
"tree_method": "gpu_hist",
"gpu_id": "0",
"n_jobs": -1,
"seed": 123,
}
m = xgb.XGBRegressor(**params).fit(X_train, y_train)
gpu_train_score = m.score(X_train, y_train)
gpu_test_score = m.score(X_test, y_test)
# Now with cpu
m = tm.set_ordinal(-1, m)
cpu_train_score = m.score(X_train, y_train)
cpu_test_score = m.score(X_test, y_test)
assert np.allclose(cpu_train_score, gpu_train_score)
assert np.allclose(cpu_test_score, gpu_test_score)
def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
import cupy as cp
dtrain.set_info(base_margin=base_margin)
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
@ -152,10 +158,11 @@ class TestGPUPredict:
def run_inplace_predict_cupy(self, device: int) -> None:
import cupy as cp
cp.cuda.runtime.setDevice(device)
rows = 1000
cols = 10
missing = 11 # set to integer for testing
missing = 11 # set to integer for testing
cp_rng = cp.random.RandomState(1994)
cp.random.set_random_state(cp_rng)
@ -168,7 +175,7 @@ class TestGPUPredict:
dtrain = xgb.DMatrix(X, y)
booster = xgb.train(
{'tree_method': 'gpu_hist', "gpu_id": device}, dtrain, num_boost_round=10
{"tree_method": "gpu_hist", "gpu_id": device}, dtrain, num_boost_round=10
)
test = xgb.DMatrix(X[:10, ...], missing=missing)
@ -186,7 +193,7 @@ class TestGPUPredict:
# Don't do this on Windows, see issue #5793
if sys.platform.startswith("win"):
pytest.skip(
'Multi-threaded in-place prediction with cuPy is not working on Windows'
"Multi-threaded in-place prediction with cuPy is not working on Windows"
)
for i in range(10):
run_threaded_predict(X, rows, predict_dense)
@ -205,9 +212,10 @@ class TestGPUPredict:
)
reg.fit(X, y)
reg = tm.set_ordinal(device, reg)
gpu_predt = reg.predict(X)
reg.set_params(predictor="cpu_predictor")
cpu_predt = reg.predict(X)
reg = tm.set_ordinal(-1, reg)
cpu_predt = reg.predict(cp.asnumpy(X))
np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6)
cp.cuda.runtime.setDevice(0)
@ -215,11 +223,11 @@ class TestGPUPredict:
def test_inplace_predict_cupy(self):
self.run_inplace_predict_cupy(0)
@pytest.mark.xfail
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_inplace_predict_cupy_specified_device(self):
import cupy as cp
n_devices = cp.cuda.runtime.getDeviceCount()
for d in range(n_devices):
self.run_inplace_predict_cupy(d)
@ -230,6 +238,7 @@ class TestGPUPredict:
import cudf
import cupy as cp
import pandas as pd
rows = 1000
cols = 10
rng = np.random.RandomState(1994)
@ -241,8 +250,7 @@ class TestGPUPredict:
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({'tree_method': 'gpu_hist'},
dtrain, num_boost_round=10)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
test = xgb.DMatrix(X)
predt_from_array = booster.inplace_predict(X)
predt_from_dmatrix = booster.predict(test)
@ -272,11 +280,12 @@ class TestGPUPredict:
def test_shap(self, num_rounds, dataset, param):
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param.update({"tree_method": "gpu_hist", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst = tm.set_ordinal(0, bst)
shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
@ -289,31 +298,35 @@ class TestGPUPredict:
def test_shap_interactions(self, num_rounds, dataset, param):
if dataset.name.endswith("-l1"): # not supported by the exact tree method
return
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param.update({"tree_method": "hist", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst = tm.set_ordinal(0, bst)
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3, 1e-3)
assert np.allclose(
np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3,
1e-3,
)
def test_shap_categorical(self):
X, y = tm.make_categorical(100, 20, 7, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)
booster.set_param({"predictor": "gpu_predictor"})
booster = tm.set_ordinal(0, booster)
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
)
booster.set_param({"predictor": "cpu_predictor"})
booster = tm.set_ordinal(-1, booster)
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
@ -321,18 +334,20 @@ class TestGPUPredict:
)
def test_predict_leaf_basic(self):
gpu_leaf = run_predict_leaf('gpu_predictor')
cpu_leaf = run_predict_leaf('cpu_predictor')
gpu_leaf = run_predict_leaf(0)
cpu_leaf = run_predict_leaf(-1)
np.testing.assert_equal(gpu_leaf, cpu_leaf)
def run_predict_leaf_booster(self, param, num_rounds, dataset):
param = dataset.set_params(param)
m = dataset.get_dmat()
booster = xgb.train(param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds)
booster.set_param({'predictor': 'cpu_predictor'})
booster = xgb.train(
param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds
)
booster = tm.set_ordinal(-1, booster)
cpu_leaf = booster.predict(m, pred_leaf=True)
booster.set_param({'predictor': 'gpu_predictor'})
booster = tm.set_ordinal(0, booster)
gpu_leaf = booster.predict(m, pred_leaf=True)
np.testing.assert_equal(cpu_leaf, gpu_leaf)
@ -344,8 +359,8 @@ class TestGPUPredict:
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
return
param['booster'] = 'gbtree'
param['tree_method'] = 'gpu_hist'
param["booster"] = "gbtree"
param["tree_method"] = "gpu_hist"
self.run_predict_leaf_booster(param, 10, dataset)
@given(predict_parameter_strategy, tm.make_dataset_strategy())
@ -355,42 +370,61 @@ class TestGPUPredict:
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
return
param['booster'] = 'dart'
param['tree_method'] = 'gpu_hist'
param["booster"] = "dart"
param["tree_method"] = "gpu_hist"
self.run_predict_leaf_booster(param, 10, dataset)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
@given(df=data_frames([column('x0', elements=strategies.integers(min_value=0, max_value=3)),
column('x1', elements=strategies.integers(min_value=0, max_value=5))],
index=range_indexes(min_size=20, max_size=50)))
@given(
df=data_frames(
[
column("x0", elements=strategies.integers(min_value=0, max_value=3)),
column("x1", elements=strategies.integers(min_value=0, max_value=5)),
],
index=range_indexes(min_size=20, max_size=50),
)
)
@settings(deadline=None, max_examples=20, print_blob=True)
def test_predict_categorical_split(self, df):
from sklearn.metrics import mean_squared_error
df = df.astype('category')
x0, x1 = df['x0'].to_numpy(), df['x1'].to_numpy()
df = df.astype("category")
x0, x1 = df["x0"].to_numpy(), df["x1"].to_numpy()
y = (x0 * 10 - 20) + (x1 - 2)
dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)
params = {
'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor',
'max_depth': 3, 'learning_rate': 1.0, 'base_score': 0.0, 'eval_metric': 'rmse'
"tree_method": "gpu_hist",
"max_depth": 3,
"learning_rate": 1.0,
"base_score": 0.0,
"eval_metric": "rmse",
"gpu_id": "0",
}
eval_history = {}
bst = xgb.train(params, dtrain, num_boost_round=5, evals=[(dtrain, 'train')],
verbose_eval=False, evals_result=eval_history)
bst = xgb.train(
params,
dtrain,
num_boost_round=5,
evals=[(dtrain, "train")],
verbose_eval=False,
evals_result=eval_history,
)
bst = tm.set_ordinal(0, bst)
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
np.testing.assert_almost_equal(
rmse, eval_history["train"]["rmse"][-1], decimal=5
)
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.parametrize("n_classes", [2, 3])
def test_predict_dart(self, n_classes):
import cupy as cp
from sklearn.datasets import make_classification
n_samples = 1000
X_, y_ = make_classification(
n_samples=n_samples, n_informative=5, n_classes=n_classes
@ -403,7 +437,7 @@ class TestGPUPredict:
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
"objective": "binary:logistic"
"objective": "binary:logistic",
}
else:
params = {
@ -411,15 +445,18 @@ class TestGPUPredict:
"booster": "dart",
"rate_drop": 0.5,
"objective": "multi:softprob",
"num_class": n_classes
"num_class": n_classes,
}
booster = xgb.train(params, Xy, num_boost_round=32)
# predictor=auto
# auto (GPU)
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
# CPU
booster = tm.set_ordinal(-1, booster)
cpu_inplace = booster.inplace_predict(X_)
booster.set_param({"predictor": "cpu_predictor"})
cpu_copied = booster.predict(Xy)
copied = cp.array(copied)
@ -427,7 +464,8 @@ class TestGPUPredict:
cp.testing.assert_allclose(cpu_copied, copied, atol=1e-6)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
booster.set_param({"predictor": "gpu_predictor"})
# GPU
booster = tm.set_ordinal(0, booster)
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
@ -437,12 +475,11 @@ class TestGPUPredict:
@pytest.mark.skipif(**tm.no_cupy())
def test_dtypes(self):
import cupy as cp
rows = 1000
cols = 10
rng = cp.random.RandomState(1994)
orig = rng.randint(low=0, high=127, size=rows * cols).reshape(
rows, cols
)
orig = rng.randint(low=0, high=127, size=rows * cols).reshape(rows, cols)
y = rng.randint(low=0, high=127, size=rows)
dtrain = xgb.DMatrix(orig, label=y)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain)
@ -450,19 +487,16 @@ class TestGPUPredict:
predt_orig = booster.inplace_predict(orig)
# all primitive types in numpy
for dtype in [
cp.signedinteger,
cp.byte,
cp.short,
cp.intc,
cp.int_,
cp.longlong,
cp.unsignedinteger,
cp.ubyte,
cp.ushort,
cp.uintc,
cp.uint,
cp.ulonglong,
cp.floating,
cp.half,
cp.single,
cp.double,
@ -472,9 +506,7 @@ class TestGPUPredict:
cp.testing.assert_allclose(predt, predt_orig)
# boolean
orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape(
rows, cols
)
orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape(rows, cols)
predt_orig = booster.inplace_predict(orig)
for dtype in [cp.bool8, cp.bool_]:
X = cp.array(orig, dtype=dtype)

View File

@ -29,7 +29,6 @@ def comp_training_with_rank_objective(
"booster": "gbtree",
"tree_method": "gpu_hist",
"gpu_id": 0,
"predictor": "gpu_predictor",
}
num_trees = 100
@ -54,7 +53,6 @@ def comp_training_with_rank_objective(
"booster": "gbtree",
"tree_method": "hist",
"gpu_id": -1,
"predictor": "cpu_predictor",
}
cpu_params["objective"] = rank_objective
cpu_params["eval_metric"] = metric_name

View File

@ -260,7 +260,6 @@ class TestGPUUpdaters:
"seed": 66,
"subsample": 0.5,
"gamma": 0.2,
"predictor": "auto",
"eval_metric": "auc",
},
num_boost_round=150,

View File

@ -28,7 +28,7 @@ def run_threaded_predict(X, rows, predict_func):
assert f.result()
def run_predict_leaf(predictor):
def run_predict_leaf(gpu_id: int) -> np.ndarray:
rows = 100
cols = 4
classes = 5
@ -42,13 +42,13 @@ def run_predict_leaf(predictor):
{
"num_parallel_tree": num_parallel_tree,
"num_class": classes,
"predictor": predictor,
"tree_method": "hist",
},
m,
num_boost_round=num_boost_round,
)
booster = tm.set_ordinal(gpu_id, booster)
empty = xgb.DMatrix(np.ones(shape=(0, cols)))
empty_leaf = booster.predict(empty, pred_leaf=True)
assert empty_leaf.shape[0] == 0
@ -74,13 +74,14 @@ def run_predict_leaf(predictor):
# When there's only 1 tree, the output is a 1 dim vector
booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m)
booster = tm.set_ordinal(gpu_id, booster)
assert booster.predict(m, pred_leaf=True).shape == (rows,)
return leaf
def test_predict_leaf():
run_predict_leaf("cpu_predictor")
def test_predict_leaf() -> None:
run_predict_leaf(-1)
def test_predict_shape():

View File

@ -274,7 +274,7 @@ class TestTreeMethod:
) -> None:
parameters: Dict[str, Any] = {"tree_method": tree_method}
cat, label = tm.make_categorical(
n_samples=rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
@ -294,7 +294,9 @@ class TestTreeMethod:
y_predt = booster.predict(Xy)
rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
np.testing.assert_allclose(
rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5
)
# Test with OHE split
run(self.USE_ONEHOT)
@ -311,10 +313,8 @@ class TestTreeMethod:
by_etl_results: Dict[str, Dict[str, List[float]]] = {}
by_builtin_results: Dict[str, Dict[str, List[float]]] = {}
predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
parameters: Dict[str, Any] = {
"tree_method": tree_method,
"predictor": predictor,
# Use one-hot exclusively
"max_cat_to_onehot": self.USE_ONEHOT
}

View File

@ -1418,23 +1418,6 @@ def test_categorical():
np.testing.assert_allclose(predt_cat, predt_enc)
def test_prediction_config():
reg = xgb.XGBRegressor()
assert reg._can_use_inplace_predict() is True
reg.set_params(predictor="cpu_predictor")
assert reg._can_use_inplace_predict() is False
reg.set_params(predictor="auto")
assert reg._can_use_inplace_predict() is True
reg.set_params(predictor=None)
assert reg._can_use_inplace_predict() is True
reg.set_params(booster="gblinear")
assert reg._can_use_inplace_predict() is False
def test_evaluation_metric():
from sklearn.datasets import load_diabetes, load_digits
from sklearn.metrics import mean_absolute_error