[breaking] Remove the predictor param, allow fallback to prediction using DMatrix. (#9129)
- A `DeviceOrd` struct is implemented to indicate the device. It will eventually replace the `gpu_id` parameter.
- The `predictor` parameter is removed.
- Fall back to `DMatrix` when `inplace_predict` is not available.
- The heuristic for choosing a predictor is only used during training.
parent 3a0f787703 · commit 39390cc2ee
@@ -45,7 +45,7 @@ XGBoost makes use of `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as

.. code-block:: python

  model.set_param({"predictor": "gpu_predictor"})
  model.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
  shap_values = model.predict(dtrain, pred_contribs=True)
  shap_interaction_values = model.predict(dtrain, pred_interactions=True)
@@ -199,18 +199,6 @@ Parameters for Tree Booster

- Maximum number of discrete bins to bucket continuous features.
- Increasing this number improves the optimality of splits at the cost of higher computation time.

* ``predictor``, [default= ``auto``]

  - The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.

    - ``auto``: Configure predictor based on heuristics.
    - ``cpu_predictor``: Multicore CPU prediction algorithm.
    - ``gpu_predictor``: Prediction using GPU. Used when ``tree_method`` is ``gpu_hist``.
      When ``predictor`` is set to default value ``auto``, the ``gpu_hist`` tree method is
      able to provide GPU based prediction without copying training data to GPU memory.
      If ``gpu_predictor`` is explicitly specified, then all data is copied into GPU, only
      recommended for performing prediction tasks.

* ``num_parallel_tree``, [default=1]

  - Number of parallel trees constructed during each iteration. This option is used to support boosted random forest.
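With ``predictor`` removed, the prediction device now follows the booster's own device configuration rather than a separate parameter. A minimal migration sketch (self-contained, CPU-only, with made-up data):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X, y = np.random.randn(128, 4), np.random.randn(128)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=4)

    # Before this change: booster.set_param({"predictor": "cpu_predictor"})
    # After: the device used for prediction follows gpu_id/tree_method.
    booster.set_param({"gpu_id": "-1", "tree_method": "hist"})  # stay on CPU
    print(booster.predict(dtrain)[:3])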
@@ -87,15 +87,6 @@ with the native Python interface :py:meth:`xgboost.Booster.predict` and

behavior. Also the ``save_best`` parameter from :py:obj:`xgboost.callback.EarlyStopping`
might be useful.

*********
Predictor
*********

There are 2 predictors in XGBoost (3 if you have the one-api plugin enabled), namely
``cpu_predictor`` and ``gpu_predictor``. The default option is ``auto`` so that XGBoost
can employ some heuristics for saving GPU memory during training. They might have slightly
different outputs due to floating point errors.

***********
Base Margin
@@ -134,15 +125,6 @@ it. Be aware that the output of in-place prediction depends on input data type; when the

input is on GPU, the output is :py:obj:`cupy.ndarray`, otherwise a :py:obj:`numpy.ndarray`
is returned.

****************
Categorical Data
****************

Other than users performing encoding, XGBoost has experimental support for categorical
data using ``gpu_hist`` and ``gpu_predictor``. No special operation needs to be done on
input test data since the information about categories is encoded into the model during
training.

*************
Thread Safety
*************
@@ -159,7 +141,6 @@ instance we might accidentally call ``clf.set_params()`` inside a predict function:

    def predict_fn(clf: xgb.XGBClassifier, X):
        X = preprocess(X)
        clf.set_params(predictor="gpu_predictor")  # NOT safe!
        clf.set_params(n_jobs=1)  # NOT safe!
        return clf.predict_proba(X, iteration_range=(0, 10))
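For contrast, a thread-safe counterpart configures the estimator once, outside the worker threads; ``preprocess`` is the same placeholder as above:

.. code-block:: python

    def predict_fn(clf: xgb.XGBClassifier, X):
        # No set_params() here: mutating the estimator is what breaks thread safety.
        X = preprocess(X)
        return clf.predict_proba(X, iteration_range=(0, 10))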
@@ -148,8 +148,8 @@ Also for inplace prediction:

.. code-block:: python

    booster.set_param({'predictor': 'gpu_predictor'})
    # where X is a dask DataFrame or dask Array containing cupy or cuDF backed data.
    # where X is a dask DataFrame or dask Array backed by cupy or cuDF.
    booster.set_param({"gpu_id": "0"})
    prediction = xgb.dask.inplace_predict(client, booster, X)

When input is ``da.Array`` object, output is always ``da.Array``. However, if the input
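A slightly fuller sketch of the new dask flow, assuming a ``dask_cuda`` GPU cluster, a ``booster`` trained with ``gpu_hist``, and ``X`` a dask Array backed by cupy:

.. code-block:: python

    from dask_cuda import LocalCUDACluster
    from distributed import Client
    import xgboost as xgb

    with Client(LocalCUDACluster()) as client:
        booster.set_param({"gpu_id": "0"})
        prediction = xgb.dask.inplace_predict(client, booster, X)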
@@ -173,7 +173,6 @@ Will print out something similar to (not actual output as it's too long for demonstration):

    "gradient_booster": {
      "gbtree_train_param": {
        "num_parallel_tree": "1",
        "predictor": "gpu_predictor",
        "process_type": "default",
        "tree_method": "gpu_hist",
        "updater": "grow_gpu_hist",
@@ -10,6 +10,7 @@

#include <dmlc/omp.h>

#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
@@ -112,7 +113,7 @@ using bst_row_t = std::size_t;  // NOLINT

/*! \brief Type for tree node index. */
using bst_node_t = std::int32_t;    // NOLINT
/*! \brief Type for ranking group index. */
using bst_group_t = std::uint32_t;  // NOLINT
using bst_group_t = std::uint32_t;  // NOLINT
/**
 * \brief Type for indexing into output targets.
 */

@@ -125,6 +126,10 @@ using bst_layer_t = std::int32_t;  // NOLINT

 * \brief Type for indexing trees.
 */
using bst_tree_t = std::int32_t;  // NOLINT
/**
 * @brief Ordinal of a CUDA device.
 */
using bst_d_ordinal_t = std::int16_t;  // NOLINT

namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation
@@ -1067,6 +1067,9 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle, DMatrixHandle dmat,

/**
 * \brief Inplace prediction from CPU dense matrix.
 *
 * \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
 *       prediction with DMatrix with a performance warning.
 *
 * \param handle Booster handle.
 * \param values JSON encoded __array_interface__ to values.
 * \param config See \ref XGBoosterPredictFromDMatrix for more info.

@@ -1091,6 +1094,9 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, char const *values,

/**
 * \brief Inplace prediction from CPU CSR matrix.
 *
 * \note If the booster is configured to run on a CUDA device, XGBoost falls back to run
 *       prediction with DMatrix with a performance warning.
 *
 * \param handle Booster handle.
 * \param indptr JSON encoded __array_interface__ to row pointer in CSR.
 * \param indices JSON encoded __array_interface__ to column indices in CSR.

@@ -1116,6 +1122,9 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, char const *indptr, char

/**
 * \brief Inplace prediction from CUDA Dense matrix (cupy in Python).
 *
 * \note If the booster is configured to run on a CPU, XGBoost falls back to run
 *       prediction with DMatrix with a performance warning.
 *
 * \param handle Booster handle
 * \param values JSON encoded __cuda_array_interface__ to values.
 * \param config See \ref XGBoosterPredictFromDMatrix for more info.

@@ -1137,6 +1146,9 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *values,

/**
 * \brief Inplace prediction from CUDA dense dataframe (cuDF in Python).
 *
 * \note If the booster is configured to run on a CPU, XGBoost falls back to run
 *       prediction with DMatrix with a performance warning.
 *
 * \param handle Booster handle
 * \param values List of __cuda_array_interface__ for all columns encoded in JSON list.
 * \param config See \ref XGBoosterPredictFromDMatrix for more info.
@@ -1,20 +1,79 @@

/*!
 * Copyright 2014-2022 by Contributors
/**
 * Copyright 2014-2023, XGBoost Contributors
 * \file context.h
 */
#ifndef XGBOOST_CONTEXT_H_
#define XGBOOST_CONTEXT_H_

#include <xgboost/logging.h>
#include <xgboost/parameter.h>
#include <xgboost/base.h>       // for bst_d_ordinal_t
#include <xgboost/logging.h>    // for CHECK_GE
#include <xgboost/parameter.h>  // for XGBoostParameter

#include <memory>  // std::shared_ptr
#include <string>
#include <cstdint>  // for int16_t, int32_t, int64_t
#include <memory>   // for shared_ptr
#include <string>   // for string, to_string

namespace xgboost {

struct CUDAContext;

/**
 * @brief A type for device ordinal. The type is packed into 32-bit for efficient use in
 *        viewing types like `linalg::TensorView`.
 */
struct DeviceOrd {
  enum Type : std::int16_t { kCPU = 0, kCUDA = 1 } device{kCPU};
  // CUDA device ordinal.
  bst_d_ordinal_t ordinal{-1};

  [[nodiscard]] bool IsCUDA() const { return device == kCUDA; }
  [[nodiscard]] bool IsCPU() const { return device == kCPU; }

  DeviceOrd() = default;
  constexpr DeviceOrd(Type type, bst_d_ordinal_t ord) : device{type}, ordinal{ord} {}

  DeviceOrd(DeviceOrd const& that) = default;
  DeviceOrd& operator=(DeviceOrd const& that) = default;
  DeviceOrd(DeviceOrd&& that) = default;
  DeviceOrd& operator=(DeviceOrd&& that) = default;

  /**
   * @brief Constructor for CPU.
   */
  [[nodiscard]] constexpr static auto CPU() { return DeviceOrd{kCPU, -1}; }
  /**
   * @brief Constructor for CUDA device.
   *
   * @param ordinal CUDA device ordinal.
   */
  [[nodiscard]] static auto CUDA(bst_d_ordinal_t ordinal) { return DeviceOrd{kCUDA, ordinal}; }

  [[nodiscard]] bool operator==(DeviceOrd const& that) const {
    return device == that.device && ordinal == that.ordinal;
  }
  [[nodiscard]] bool operator!=(DeviceOrd const& that) const { return !(*this == that); }
  /**
   * @brief Get a string representation of the device and the ordinal.
   */
  [[nodiscard]] std::string Name() const {
    switch (device) {
      case DeviceOrd::kCPU:
        return "CPU";
      case DeviceOrd::kCUDA:
        return "CUDA:" + std::to_string(ordinal);
      default: {
        LOG(FATAL) << "Unknown device.";
        return "";
      }
    }
  }
};

static_assert(sizeof(DeviceOrd) == sizeof(std::int32_t));

/**
 * @brief Runtime context for XGBoost. Contains information like threads and device.
 */
struct Context : public XGBoostParameter<Context> {
 public:
  // Constant representing the device ID of CPU.
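For readers more familiar with Python, here is a rough sketch of the same idea (hypothetical code, not part of XGBoost's API): a device is a (type, ordinal) pair with a compact string name.

.. code-block:: python

    import enum
    from dataclasses import dataclass

    class DeviceType(enum.IntEnum):
        CPU = 0
        CUDA = 1

    @dataclass(frozen=True)
    class DeviceOrd:
        device: DeviceType = DeviceType.CPU
        ordinal: int = -1  # CUDA device ordinal; -1 for CPU

        def name(self) -> str:
            return "CPU" if self.device == DeviceType.CPU else f"CUDA:{self.ordinal}"

    assert DeviceOrd().name() == "CPU"
    assert DeviceOrd(DeviceType.CUDA, 0).name() == "CUDA:0"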
@@ -36,29 +95,59 @@ struct Context : public XGBoostParameter<Context> {

  // fail when gpu_id is invalid
  bool fail_on_invalid_gpu_id{false};
  bool validate_parameters{false};

  /*!
   * \brief Configure the parameter `gpu_id'.
  /**
   * @brief Configure the parameter `gpu_id'.
   *
   * \param require_gpu Whether GPU is explicitly required from user.
   * @param require_gpu Whether GPU is explicitly required by the user through other
   *                    configurations.
   */
  void ConfigureGpuId(bool require_gpu);
  /*!
   * Return automatically chosen threads.
  /**
   * @brief Returns the automatically chosen number of threads based on the `nthread`
   *        parameter and the system setting.
   */
  std::int32_t Threads() const;

  bool IsCPU() const { return gpu_id == kCpuId; }
  bool IsCUDA() const { return !IsCPU(); }

  CUDAContext const* CUDACtx() const;
  // Make a CUDA context based on the current context.
  Context MakeCUDA(std::int32_t device = 0) const {
  [[nodiscard]] std::int32_t Threads() const;
  /**
   * @brief Is XGBoost running on CPU?
   */
  [[nodiscard]] bool IsCPU() const { return gpu_id == kCpuId; }
  /**
   * @brief Is XGBoost running on a CUDA device?
   */
  [[nodiscard]] bool IsCUDA() const { return !IsCPU(); }
  /**
   * @brief Get the current device and ordinal.
   */
  [[nodiscard]] DeviceOrd Device() const {
    return IsCPU() ? DeviceOrd::CPU() : DeviceOrd::CUDA(static_cast<bst_d_ordinal_t>(gpu_id));
  }
  /**
   * @brief Get the CUDA device ordinal. -1 if XGBoost is running on CPU.
   */
  [[nodiscard]] bst_d_ordinal_t Ordinal() const { return this->gpu_id; }
  /**
   * @brief Name of the current device.
   */
  [[nodiscard]] std::string DeviceName() const { return Device().Name(); }
  /**
   * @brief Get a CUDA device context for allocator and stream.
   */
  [[nodiscard]] CUDAContext const* CUDACtx() const;
  /**
   * @brief Make a CUDA context based on the current context.
   *
   * @param ordinal The CUDA device ordinal.
   */
  [[nodiscard]] Context MakeCUDA(std::int32_t ordinal = 0) const {
    Context ctx = *this;
    ctx.gpu_id = device;
    CHECK_GE(ordinal, 0);
    ctx.gpu_id = ordinal;
    return ctx;
  }
  Context MakeCPU() const {
  /**
   * @brief Make a CPU context based on the current context.
   */
  [[nodiscard]] Context MakeCPU() const {
    Context ctx = *this;
    ctx.gpu_id = kCpuId;
    return ctx;
@@ -87,9 +176,9 @@ struct Context : public XGBoostParameter<Context> {

  }

 private:
  // mutable for lazy initialization for cuda context to avoid initializing CUDA at load.
  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define p_impl
  // while trying to hide CUDA code from host compiler.
  // mutable for lazy cuda context initialization. This avoids initializing CUDA at load.
  // shared_ptr is used instead of unique_ptr as with unique_ptr it's difficult to define
  // p_impl while trying to hide CUDA code from the host compiler.
  mutable std::shared_ptr<CUDAContext> cuctx_;
  // cached value for CFS CPU limit. (used in containerized env)
  std::int32_t cfs_cpu_count_;  // NOLINT
@@ -149,18 +149,14 @@ class GradientBooster : public Model, public Configurable {

   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   * \param approximate use a faster (inconsistent) approximation of SHAP values
   * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
   * \param condition_feature feature to condition on (i.e. fix) during calculations
   */
  virtual void PredictContribution(DMatrix* dmat,
                                   HostDeviceVector<bst_float>* out_contribs,
                                   unsigned layer_begin, unsigned layer_end,
                                   bool approximate = false, int condition = 0,
                                   unsigned condition_feature = 0) = 0;
  virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
                                   bst_layer_t layer_begin, bst_layer_t layer_end,
                                   bool approximate = false) = 0;

  virtual void PredictInteractionContributions(
      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
      unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
  virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
                                               bst_layer_t layer_begin, bst_layer_t layer_end,
                                               bool approximate) = 0;

  /*!
   * \brief dump the model in the requested format
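At the Python level the signature change is internal; contribution predictions keep the same entry points. A runnable sketch:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X, y = np.random.randn(64, 3), np.random.randn(64)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=3)

    contribs = booster.predict(dtrain, pred_contribs=True)   # SHAP values
    inter = booster.predict(dtrain, pred_interactions=True)  # SHAP interactions
    assert contribs.shape == (64, 4)   # n_features + bias column
    assert inter.shape == (64, 4, 4)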
@@ -78,7 +78,6 @@ public class BoosterTest {

        put("num_round", round);
        put("num_workers", 1);
        put("tree_method", "gpu_hist");
        put("predictor", "gpu_predictor");
        put("max_bin", maxBin);
      }
    };

@@ -281,7 +281,6 @@ object GpuPreXGBoost extends PreXGBoostProvider {

      // - predictor: Force to gpu predictor since native doesn't save predictor.
      val gpuId = if (!isLocal) XGBoost.getGPUAddrFromResources else 0
      booster.setParam("gpu_id", gpuId.toString)
      booster.setParam("predictor", "gpu_predictor")
      logger.info("GPU transform on device: " + gpuId)
      boosterFlag.isGpuParamsSet = true;
    }
@@ -2187,20 +2187,25 @@ class Booster:

        base_margin: Any = None,
        strict_shape: bool = False,
    ) -> NumpyOrCupy:
        """Run prediction in-place, Unlike :py:meth:`predict` method, inplace prediction
        does not cache the prediction result.
        """Run prediction in-place when possible. Unlike the :py:meth:`predict` method,
        inplace prediction does not cache the prediction result.

        Calling only ``inplace_predict`` in multiple threads is safe and lock
        free. But the safety does not hold when used in conjunction with other
        methods. E.g. you can't train the booster in one thread and perform
        prediction in the other.

        .. note::

            If the device ordinal of the input data doesn't match the one configured for
            the booster, data will be copied to the booster device.

        .. code-block:: python

            booster.set_param({"predictor": "gpu_predictor"})
            booster.set_param({"gpu_id": "0", "tree_method": "gpu_hist"})
            booster.inplace_predict(cupy_array)

            booster.set_param({"predictor": "cpu_predictor"})
            booster.set_param({"gpu_id": "-1", "tree_method": "hist"})
            booster.inplace_predict(numpy_array)

        .. versionadded:: 1.1.0
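A runnable sketch of the happy path on a CPU-only setup (the fallback path would need an actual device mismatch):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X, y = np.random.randn(64, 3), np.random.randn(64)
    booster = xgb.train({"tree_method": "hist", "gpu_id": -1},
                        xgb.DMatrix(X, label=y), num_boost_round=2)

    # Same device: true in-place prediction, no DMatrix is constructed.
    out = booster.inplace_predict(X)
    # With gpu_id=0 instead, the numpy input would be copied / routed through
    # DMatrix-based prediction with a performance warning.
    assert out.shape == (64,)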
@@ -2208,9 +2213,7 @@ class Booster:

        Parameters
        ----------
        data :
            The input data, must not be a view for numpy array. Set
            ``predictor`` to ``gpu_predictor`` for running prediction on CuPy
            array or CuDF DataFrame.
            The input data.
        iteration_range :
            See :py:meth:`predict` for details.
        predict_type :

@@ -277,9 +277,6 @@ __model_doc = f"""

        Device ordinal.
    validate_parameters : Optional[bool]
        Give warnings for unknown parameter.
    predictor : Optional[str]
        Force XGBoost to use specific predictor, available choices are [cpu_predictor,
        gpu_predictor].
    enable_categorical : bool

        .. versionadded:: 1.5.0

@@ -652,7 +649,6 @@ class XGBModel(XGBModelBase):

        importance_type: Optional[str] = None,
        gpu_id: Optional[int] = None,
        validate_parameters: Optional[bool] = None,
        predictor: Optional[str] = None,
        enable_categorical: bool = False,
        feature_types: Optional[FeatureTypes] = None,
        max_cat_to_onehot: Optional[int] = None,

@@ -699,7 +695,6 @@ class XGBModel(XGBModelBase):

        self.importance_type = importance_type
        self.gpu_id = gpu_id
        self.validate_parameters = validate_parameters
        self.predictor = predictor
        self.enable_categorical = enable_categorical
        self.feature_types = feature_types
        self.max_cat_to_onehot = max_cat_to_onehot
@@ -1093,12 +1088,7 @@ class XGBModel(XGBModelBase):

        return self

    def _can_use_inplace_predict(self) -> bool:
        # When predictor is explicitly set, using `inplace_predict` might result into
        # error with incompatible data type.
        # Inplace predict doesn't handle as many data types as DMatrix, but it's
        # sufficient for the dask interface where input is simpler.
        predictor = self.get_xgb_params().get("predictor", None)
        if predictor in ("auto", None) and self.booster != "gblinear":
        if self.booster != "gblinear":
            return True
        return False
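After the change the check depends only on the booster type, so ``gblinear`` still routes through :py:class:`DMatrix`. A small sketch exercising the (private) helper shown above:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X, y = np.random.randn(32, 2), np.random.randn(32)
    tree = xgb.XGBRegressor(n_estimators=2).fit(X, y)
    linear = xgb.XGBRegressor(n_estimators=2, booster="gblinear").fit(X, y)

    assert tree._can_use_inplace_predict()        # in-place path
    assert not linear._can_use_inplace_predict()  # DMatrix fallback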
@@ -1124,9 +1114,9 @@ class XGBModel(XGBModelBase):

        iteration_range: Optional[Tuple[int, int]] = None,
    ) -> ArrayLike:
        """Predict with `X`. If the model is trained with early stopping, then
        :py:attr:`best_iteration` is used automatically. For tree models, when data is
        on GPU, like cupy array or cuDF dataframe and `predictor` is not specified, the
        prediction is run on GPU automatically, otherwise it will run on CPU.
        :py:attr:`best_iteration` is used automatically. The estimator uses
        `inplace_predict` by default and falls back to using :py:class:`DMatrix` if
        devices between the data and the estimator don't match.

        .. note:: This function is only thread safe for `gbtree` and `dart`.

@@ -1588,7 +1578,9 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):

    ) -> np.ndarray:
        """Predict the probability of each `X` example being of a given class. If the
        model is trained with early stopping, then :py:attr:`best_iteration` is used
        automatically.
        automatically. The estimator uses `inplace_predict` by default and falls back to
        using :py:class:`DMatrix` if devices between the data and the estimator don't
        match.

        .. note:: This function is only thread safe for `gbtree` and `dart`.
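A short sketch of the documented behavior on matching devices:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 4)
    y = np.random.randint(0, 2, size=100)
    clf = xgb.XGBClassifier(n_estimators=4, tree_method="hist").fit(X, y)

    proba = clf.predict_proba(X)  # in-place prediction, no DMatrix built
    assert np.allclose(proba.sum(axis=1), 1.0)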
@@ -25,6 +25,7 @@ from typing import (

    Set,
    Tuple,
    TypedDict,
    TypeVar,
    Union,
)
@@ -711,6 +712,27 @@ def predictor_equal(lhs: xgb.DMatrix, rhs: xgb.DMatrix) -> bool:

    )


M = TypeVar("M", xgb.Booster, xgb.XGBModel)


def set_ordinal(ordinal: int, booster: M) -> M:
    """Temporary solution for setting the device ordinal until we move away from
    `gpu_id`.

    """
    if ordinal < 0:
        params = {"gpu_id": -1, "tree_method": "hist"}
    else:
        params = {"gpu_id": ordinal, "tree_method": "gpu_hist"}

    if isinstance(booster, xgb.Booster):
        booster.set_param(params)
    elif isinstance(booster, xgb.XGBModel):
        booster.set_params(**params)

    return booster
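A usage sketch for the helper (``xgb`` and ``np`` are the module-level imports of this testing module):

.. code-block:: python

    reg = set_ordinal(-1, xgb.XGBRegressor(n_estimators=2))  # CPU: hist
    dtrain = xgb.DMatrix(np.random.randn(16, 2), label=np.random.randn(16))
    booster = set_ordinal(-1, xgb.train({}, dtrain, num_boost_round=2))
    # set_ordinal(0, booster) would select CUDA ordinal 0 with gpu_hist.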
def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, np.float64]:
    """Evaluation metric for xgb.train"""
    label = dtrain.get_label()
@@ -1023,7 +1023,6 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,

                        const float **out_result) {
  xgboost_CHECK_C_ARG_PTR(c_json_config);
  auto config = Json::Load(StringView{c_json_config});
  CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";

  HostDeviceVector<float> *p_predt{nullptr};
  auto type = PredictionType(RequiredArg<Integer>(config, "type", __func__));

@@ -1042,6 +1041,7 @@ void InplacePredictImpl(std::shared_ptr<DMatrix> p_m, char const *c_json_config,

  xgboost_CHECK_C_ARG_PTR(out_dim);
  CalcPredictShape(strict_shape, type, n_samples, n_features, chunksize, learner->Groups(),
                   learner->BoostedRounds(), &shape, out_dim);
  CHECK_GE(p_predt->Size(), n_samples);

  xgboost_CHECK_C_ARG_PTR(out_result);
  xgboost_CHECK_C_ARG_PTR(out_shape);
@@ -92,7 +92,7 @@ XGB_DLL int XGDMatrixCreateFromCudaArrayInterface(char const *data,

  API_END();
}

int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,
int InplacePreidctCUDA(BoosterHandle handle, char const *c_array_interface,
                       char const *c_json_config, std::shared_ptr<DMatrix> p_m,
                       xgboost::bst_ulong const **out_shape, xgboost::bst_ulong *out_dim,
                       const float **out_result) {

@@ -107,7 +107,6 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,

  proxy->SetCUDAArray(c_array_interface);

  auto config = Json::Load(StringView{c_json_config});
  CHECK_EQ(get<Integer const>(config["cache_id"]), 0) << "Cache ID is not supported yet";
  auto *learner = static_cast<Learner *>(handle);

  HostDeviceVector<float> *p_predt{nullptr};

@@ -118,7 +117,13 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_array_interface,

                           RequiredArg<Integer>(config, "iteration_begin", __func__),
                           RequiredArg<Integer>(config, "iteration_end", __func__));
  CHECK(p_predt);
  CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
  if (learner->Ctx()->IsCPU()) {
    // Prediction using DMatrix as fallback.
    CHECK(p_predt->HostCanRead() && !p_predt->DeviceCanRead());
  } else {
    CHECK(p_predt->DeviceCanRead() && !p_predt->HostCanRead());
  }
  p_predt->SetDevice(proxy->DeviceIdx());

  auto &shape = learner->GetThreadLocal().prediction_shape;
  size_t n_samples = p_m->Info().num_row_;

@@ -146,7 +151,7 @@ XGB_DLL int XGBoosterPredictFromCudaColumnar(BoosterHandle handle, char const *c

  if (m) {
    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
  return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
                            out_result);
}

@@ -159,6 +164,6 @@ XGB_DLL int XGBoosterPredictFromCudaArray(BoosterHandle handle, char const *c_js

    p_m = *static_cast<std::shared_ptr<DMatrix> *>(m);
  }
  xgboost_CHECK_C_ARG_PTR(out_result);
  return InplacePreidctCuda(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
  return InplacePreidctCUDA(handle, c_json_strs, c_json_config, p_m, out_shape, out_dim,
                            out_result);
}
@@ -6,6 +6,11 @@

#ifndef XGBOOST_COMMON_ERROR_MSG_H_
#define XGBOOST_COMMON_ERROR_MSG_H_

#include <cinttypes>  // for uint64_t
#include <limits>     // for numeric_limits

#include "xgboost/base.h"  // for bst_feature_t
#include "xgboost/logging.h"
#include "xgboost/string_view.h"  // for StringView

namespace xgboost::error {

@@ -33,5 +38,14 @@ constexpr StringView InconsistentMaxBin() {

  return "Inconsistent `max_bin`. `max_bin` should be the same across different QuantileDMatrix, "
         "and consistent with the Booster being trained.";
}

constexpr StringView UnknownDevice() { return "Unknown device type."; }

inline void MaxFeatureSize(std::uint64_t n_features) {
  auto max_n_features = std::numeric_limits<bst_feature_t>::max();
  CHECK_LE(n_features, max_n_features)
      << "Unfortunately, XGBoost does not support data matrices with "
      << std::numeric_limits<bst_feature_t>::max() << " features or greater";
}
}  // namespace xgboost::error
#endif  // XGBOOST_COMMON_ERROR_MSG_H_
@@ -7,7 +7,7 @@

#include <dmlc/data.h>

#include <algorithm>
#include <cstddef>  // std::size_t
#include <cstddef>  // for size_t
#include <functional>
#include <limits>
#include <map>

@@ -17,6 +17,7 @@

#include <vector>

#include "../c_api/c_api_error.h"
#include "../common/error_msg.h"  // for MaxFeatureSize
#include "../common/math.h"
#include "array_interface.h"
#include "arrow-cdi.h"

@@ -300,9 +301,9 @@ class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {

    array_interface_ = ArrayInterface<2>(get<Object const>(j));
    batch_ = ArrayAdapterBatch{array_interface_};
  }
  ArrayAdapterBatch const& Value() const override { return batch_; }
  size_t NumRows() const { return array_interface_.Shape(0); }
  size_t NumColumns() const { return array_interface_.Shape(1); }
  [[nodiscard]] ArrayAdapterBatch const& Value() const override { return batch_; }
  [[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
  [[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }

 private:
  ArrayAdapterBatch batch_;
@@ -31,10 +31,10 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,

  dh::XGBCachingDeviceAllocator<char> alloc;

  auto num_rows = [&]() {
    return Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
    return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumRows(); });
  };
  auto num_cols = [&]() {
    return Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
    return cuda_impl::Dispatch(proxy, [](auto const& value) { return value.NumCols(); });
  };

  size_t row_stride = 0;

@@ -74,7 +74,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,

                                 get_device());
    auto* p_sketch = &sketch_containers.back();
    proxy->Info().weights_.SetDevice(get_device());
    Dispatch(proxy, [&](auto const& value) {
    cuda_impl::Dispatch(proxy, [&](auto const& value) {
      common::AdapterDeviceSketch(value, p.max_bin, proxy->Info(), missing, p_sketch);
    });
  }

@@ -82,7 +82,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,

    accumulated_rows += batch_rows;
    dh::device_vector<size_t> row_counts(batch_rows + 1, 0);
    common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
    row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
    row_stride = std::max(row_stride, cuda_impl::Dispatch(proxy, [=](auto const& value) {
                    return GetRowCounts(value, row_counts_span, get_device(), missing);
                  }));
    nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(), row_counts.end());

@@ -136,14 +136,14 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,

    auto rows = num_rows();
    dh::device_vector<size_t> row_counts(rows + 1, 0);
    common::Span<size_t> row_counts_span(row_counts.data().get(), row_counts.size());
    Dispatch(proxy, [=](auto const& value) {
    cuda_impl::Dispatch(proxy, [=](auto const& value) {
      return GetRowCounts(value, row_counts_span, get_device(), missing);
    });
    auto is_dense = this->IsDense();

    proxy->Info().feature_types.SetDevice(get_device());
    auto d_feature_types = proxy->Info().feature_types.ConstDeviceSpan();
    auto new_impl = Dispatch(proxy, [&](auto const& value) {
    auto new_impl = cuda_impl::Dispatch(proxy, [&](auto const& value) {
      return EllpackPageImpl(value, missing, get_device(), is_dense, row_counts_span,
                             d_feature_types, row_stride, rows, cuts);
    });
@@ -1,14 +1,13 @@

/*!
 * Copyright 2021 by Contributors
/**
 * Copyright 2021-2023, XGBoost Contributors
 * \file proxy_dmatrix.cc
 */

#include "proxy_dmatrix.h"

namespace xgboost {
namespace data {
void DMatrixProxy::SetArrayData(char const *c_interface) {
  std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter(StringView{c_interface})};
namespace xgboost::data {
void DMatrixProxy::SetArrayData(StringView interface_str) {
  std::shared_ptr<ArrayAdapter> adapter{new ArrayAdapter{interface_str}};
  this->batch_ = adapter;
  this->Info().num_col_ = adapter->NumColumns();
  this->Info().num_row_ = adapter->NumRows();

@@ -25,5 +24,36 @@ void DMatrixProxy::SetCSRData(char const *c_indptr, char const *c_indices,

  this->Info().num_row_ = adapter->NumRows();
  this->ctx_.gpu_id = Context::kCpuId;
}
}  // namespace data
}  // namespace xgboost

namespace cuda_impl {
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                std::shared_ptr<DMatrixProxy> proxy, float missing);
#if !defined(XGBOOST_USE_CUDA)
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *, std::shared_ptr<DMatrixProxy>,
                                                float) {
  return nullptr;
}
#endif  // XGBOOST_USE_CUDA
}  // namespace cuda_impl

std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const *ctx,
                                                std::shared_ptr<DMatrixProxy> proxy,
                                                float missing) {
  bool type_error{false};
  std::shared_ptr<DMatrix> p_fmat{nullptr};
  if (proxy->Ctx()->IsCPU()) {
    p_fmat = data::HostAdapterDispatch<false>(
        proxy.get(),
        [&](auto const &adapter) {
          auto p_fmat =
              std::shared_ptr<DMatrix>(DMatrix::Create(adapter.get(), missing, ctx->Threads()));
          return p_fmat;
        },
        &type_error);
  } else {
    p_fmat = cuda_impl::CreateDMatrixFromProxy(ctx, proxy, missing);
  }

  return p_fmat;
}
}  // namespace xgboost::data
@@ -1,12 +1,11 @@

/*!
 * Copyright 2020-2022, XGBoost contributors
/**
 * Copyright 2020-2023, XGBoost contributors
 */
#include "proxy_dmatrix.h"
#include "device_adapter.cuh"
#include "proxy_dmatrix.cuh"
#include "proxy_dmatrix.h"

namespace xgboost {
namespace data {

namespace xgboost::data {
void DMatrixProxy::FromCudaColumnar(StringView interface_str) {
  std::shared_ptr<data::CudfAdapter> adapter{new CudfAdapter{interface_str}};
  auto const& value = adapter->Value();

@@ -31,5 +30,15 @@ void DMatrixProxy::FromCudaArray(StringView interface_str) {

    ctx_.gpu_id = dh::CurrentDevice();
  }
}
}  // namespace data
}  // namespace xgboost

namespace cuda_impl {
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
                                                std::shared_ptr<DMatrixProxy> proxy,
                                                float missing) {
  return Dispatch<false>(proxy.get(), [&](auto const& adapter) {
    auto p_fmat = std::shared_ptr<DMatrix>{DMatrix::Create(adapter.get(), missing, ctx->Threads())};
    return p_fmat;
  });
}
}  // namespace cuda_impl
}  // namespace xgboost::data
@@ -6,19 +6,34 @@

#include "device_adapter.cuh"
#include "proxy_dmatrix.h"

namespace xgboost::data {
template <typename Fn>
namespace xgboost::data::cuda_impl {
template <bool get_value = true, typename Fn>
decltype(auto) Dispatch(DMatrixProxy const* proxy, Fn fn) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CupyAdapter>)) {
    auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
    return fn(value);
    if constexpr (get_value) {
      auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter())->Value();
      return fn(value);
    } else {
      auto value = std::any_cast<std::shared_ptr<CupyAdapter>>(proxy->Adapter());
      return fn(value);
    }
  } else if (proxy->Adapter().type() == typeid(std::shared_ptr<CudfAdapter>)) {
    auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
    return fn(value);
    if constexpr (get_value) {
      auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
      return fn(value);
    } else {
      auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter());
      return fn(value);
    }
  } else {
    LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
    auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
    return fn(value);
    if constexpr (get_value) {
      auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter())->Value();
      return fn(value);
    } else {
      auto value = std::any_cast<std::shared_ptr<CudfAdapter>>(proxy->Adapter());
      return fn(value);
    }
  }
}
}  // namespace xgboost::data
}  // namespace xgboost::data::cuda_impl
@@ -62,7 +62,7 @@ class DMatrixProxy : public DMatrix {

#endif  // defined(XGBOOST_USE_CUDA)
  }

  void SetArrayData(char const* c_interface);
  void SetArrayData(StringView interface_str);
  void SetCSRData(char const* c_indptr, char const* c_indices, char const* c_values,
                  bst_feature_t n_features, bool on_host);
@@ -114,28 +114,62 @@ inline DMatrixProxy* MakeProxy(DMatrixHandle proxy) {

  return typed;
}

template <typename Fn>
/**
 * @brief Dispatch function call based on input type.
 *
 * @tparam get_value Whether the function Fn accepts an adapter batch or the adapter itself.
 * @tparam Fn        The type of the function to be dispatched.
 *
 * @param proxy           The proxy object holding the reference to the input.
 * @param fn              The function to be dispatched.
 * @param type_error[out] Set to true if it's not null and the input data is not recognized by
 *                        the host.
 *
 * @return The return value of the function being dispatched.
 */
template <bool get_value = true, typename Fn>
decltype(auto) HostAdapterDispatch(DMatrixProxy const* proxy, Fn fn, bool* type_error = nullptr) {
  if (proxy->Adapter().type() == typeid(std::shared_ptr<CSRArrayAdapter>)) {
    auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
    if constexpr (get_value) {
      auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter())->Value();
      return fn(value);
    } else {
      auto value = std::any_cast<std::shared_ptr<CSRArrayAdapter>>(proxy->Adapter());
      return fn(value);
    }
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else if (proxy->Adapter().type() == typeid(std::shared_ptr<ArrayAdapter>)) {
    auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
    if constexpr (get_value) {
      auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter())->Value();
      return fn(value);
    } else {
      auto value = std::any_cast<std::shared_ptr<ArrayAdapter>>(proxy->Adapter());
      return fn(value);
    }
    if (type_error) {
      *type_error = false;
    }
    return fn(value);
  } else {
    if (type_error) {
      *type_error = true;
    } else {
      LOG(FATAL) << "Unknown type: " << proxy->Adapter().type().name();
    }
    return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
    if constexpr (get_value) {
      return std::result_of_t<Fn(
          decltype(std::declval<std::shared_ptr<ArrayAdapter>>()->Value()))>();
    } else {
      return std::result_of_t<Fn(decltype(std::declval<std::shared_ptr<ArrayAdapter>>()))>();
    }
  }
}

/**
 * @brief Create a `SimpleDMatrix` instance from a `DMatrixProxy`.
 */
std::shared_ptr<DMatrix> CreateDMatrixFromProxy(Context const* ctx,
                                                std::shared_ptr<DMatrixProxy> proxy, float missing);
}  // namespace xgboost::data
#endif  // XGBOOST_DATA_PROXY_DMATRIX_H_
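The new ``get_value`` switch lets callers receive either the adapter's current batch or the adapter object itself, which the fallback path needs to build a ``SimpleDMatrix`` from the proxy. A hypothetical Python rendering of the pattern, for intuition only:

.. code-block:: python

    def host_adapter_dispatch(adapter, fn, get_value=True):
        # Mirrors HostAdapterDispatch: hand `fn` either the batch or the adapter.
        if get_value:
            return fn(adapter.value())  # the adapter batch, e.g. for row counting
        return fn(adapter)              # the whole adapter, e.g. for DMatrix creation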
@@ -1,33 +1,31 @@

/*!
 * Copyright 2021 XGBoost contributors
/**
 * Copyright 2021-2023, XGBoost contributors
 */
#include "../common/device_helpers.cuh"  // for CurrentDevice
#include "proxy_dmatrix.cuh"             // for Dispatch, DMatrixProxy
#include "simple_dmatrix.cuh"            // for CopyToSparsePage
#include "sparse_page_source.h"
#include "proxy_dmatrix.cuh"
#include "simple_dmatrix.cuh"

namespace xgboost {
namespace data {
#include "xgboost/data.h"  // for SparsePage

namespace xgboost::data {
namespace detail {
std::size_t NSamplesDevice(DMatrixProxy *proxy) {
  return Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
  return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumRows(); });
}

std::size_t NFeaturesDevice(DMatrixProxy *proxy) {
  return Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
  return cuda_impl::Dispatch(proxy, [](auto const &value) { return value.NumCols(); });
}
}  // namespace detail

void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) {
void DevicePush(DMatrixProxy *proxy, float missing, SparsePage *page) {
  auto device = proxy->DeviceIdx();
  if (device < 0) {
    device = dh::CurrentDevice();
  }
  CHECK_GE(device, 0);

  Dispatch(proxy, [&](auto const &value) {
    CopyToSparsePage(value, device, missing, page);
  });
  cuda_impl::Dispatch(proxy,
                      [&](auto const &value) { CopyToSparsePage(value, device, missing, page); });
}
}  // namespace data
}  // namespace xgboost
}  // namespace xgboost::data
@@ -172,8 +172,7 @@ class GBLinear : public GradientBooster {

  }

  void PredictContribution(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
                           uint32_t layer_begin, uint32_t /*layer_end*/, bool, int,
                           unsigned) override {
                           bst_layer_t layer_begin, bst_layer_t /*layer_end*/, bool) override {
    model_.LazyInitModel();
    LinearCheckLayer(layer_begin);
    auto base_margin = p_fmat->Info().base_margin_.View(Context::kCpuId);

@@ -210,8 +209,8 @@ class GBLinear : public GradientBooster {

    }
  }

  void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
                                       unsigned layer_begin, unsigned /*layer_end*/,
  void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                                       bst_layer_t layer_begin, bst_layer_t /*layer_end*/,
                                       bool) override {
    LinearCheckLayer(layer_begin);
    std::vector<bst_float>& contribs = out_contribs->HostVector();
@@ -18,9 +18,11 @@

#include <vector>

#include "../common/common.h"
#include "../common/error_msg.h"  // for UnknownDevice
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "../common/timer.h"
#include "../data/proxy_dmatrix.h"  // for DMatrixProxy, HostAdapterDispatch
#include "gbtree_model.h"
#include "xgboost/base.h"
#include "xgboost/data.h"

@@ -58,9 +60,8 @@ void GBTree::Configure(Args const& cfg) {

  cpu_predictor_->Configure(cfg);
#if defined(XGBOOST_USE_CUDA)
  auto n_gpus = common::AllVisibleGPUs();
  if (!gpu_predictor_ && n_gpus != 0) {
    gpu_predictor_ = std::unique_ptr<Predictor>(
        Predictor::Create("gpu_predictor", this->ctx_));
  if (!gpu_predictor_) {
    gpu_predictor_ = std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", this->ctx_));
  }
  if (n_gpus != 0) {
    gpu_predictor_->Configure(cfg);
@@ -374,12 +375,7 @@ void GBTree::LoadConfig(Json const& in) {

  // This would cause all trees to be pushed to trees_to_update
  // e.g. updating a model, then saving and loading it would result in an empty model
  tparam_.process_type = TreeProcessType::kDefault;
  int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
  if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
    LOG(WARNING) << "Loading from a raw memory buffer on CPU only machine.  "
                    "Changing predictor to auto.";
    tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
  }
  std::int32_t const n_gpus = xgboost::common::AllVisibleGPUs();

  auto msg = StringView{
      R"(

@@ -505,8 +501,8 @@ void GBTree::Slice(bst_layer_t begin, bst_layer_t end, bst_layer_t step, GradientBooster

  out_model.param.num_parallel_tree = model_.param.num_parallel_tree;
}

void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool,
                          bst_layer_t layer_begin, bst_layer_t layer_end) {
void GBTree::PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                              bst_layer_t layer_begin, bst_layer_t layer_end) const {
  CHECK(configured_);
  if (layer_end == 0) {
    layer_end = this->BoostedRounds();

@@ -526,7 +522,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool

    CHECK_EQ(out_preds->version, 0);
  }

  auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat);
  auto const& predictor = GetPredictor(is_training, &out_preds->predictions, p_fmat);
  if (out_preds->version == 0) {
    // out_preds->Size() can be non-zero as it's initialized here before any
    // tree is built at the 0^th iterator.
@@ -546,68 +542,69 @@ void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool

  }
}

std::unique_ptr<Predictor> const &
GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
                     DMatrix *f_dmat) const {
void GBTree::PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                          bst_layer_t layer_begin, bst_layer_t layer_end) {
  // dispatch to const function.
  this->PredictBatchImpl(p_fmat, out_preds, is_training, layer_begin, layer_end);
}

void GBTree::InplacePredict(std::shared_ptr<DMatrix> p_m, float missing,
                            PredictionCacheEntry* out_preds, bst_layer_t layer_begin,
                            bst_layer_t layer_end) const {
  CHECK(configured_);
  if (tparam_.predictor != PredictorType::kAuto) {
    if (tparam_.predictor == PredictorType::kGPUPredictor) {
#if defined(XGBOOST_USE_CUDA)
      CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
      CHECK(gpu_predictor_);
      return gpu_predictor_;
#else
      common::AssertGPUSupport();
#endif  // defined(XGBOOST_USE_CUDA)
    }
    if (tparam_.predictor == PredictorType::kOneAPIPredictor) {
#if defined(XGBOOST_USE_ONEAPI)
      CHECK(oneapi_predictor_);
      return oneapi_predictor_;
#else
      common::AssertOneAPISupport();
#endif  // defined(XGBOOST_USE_ONEAPI)
    }
    CHECK(cpu_predictor_);
    return cpu_predictor_;
  auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
  CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
  if (p_m->Ctx()->Device() != this->ctx_->Device()) {
    LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
                 << "is running on: " << this->ctx_->DeviceName()
                 << ", while the input data is on: " << p_m->Ctx()->DeviceName() << ".";
    CHECK_EQ(out_preds->version, 0);
    auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
    auto any_adapter = proxy->Adapter();
    auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
    this->PredictBatchImpl(p_fmat.get(), out_preds, false, layer_begin, layer_end);
    return;
  }

  if (this->ctx_->IsCPU()) {
    this->cpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
  } else if (p_m->Ctx()->IsCUDA()) {
    CHECK(this->gpu_predictor_);
    this->gpu_predictor_->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end);
  } else {
    LOG(FATAL) << error::UnknownDevice();
  }
}

[[nodiscard]] std::unique_ptr<Predictor> const& GBTree::GetPredictor(
    bool is_training, HostDeviceVector<float> const* out_pred, DMatrix* f_dmat) const {
  CHECK(configured_);

  // Data comes from SparsePageDMatrix. Since we are loading data in pages, no need to
  // prevent data copy.
  if (f_dmat && !f_dmat->SingleColBlock()) {
    if (ctx_->IsCPU()) {
      return cpu_predictor_;
    } else {
#if defined(XGBOOST_USE_CUDA)
      CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
      return gpu_predictor_;
#else
      common::AssertGPUSupport();
      return cpu_predictor_;
#endif  // defined(XGBOOST_USE_CUDA)
      CHECK(gpu_predictor_);
      return gpu_predictor_;
    }
  }

  // Data comes from Device DMatrix.
  auto is_ellpack = f_dmat && f_dmat->PageExists<EllpackPage>() &&
                    !f_dmat->PageExists<SparsePage>();
  auto is_ellpack =
      f_dmat && f_dmat->PageExists<EllpackPage>() && !f_dmat->PageExists<SparsePage>();
  // Data comes from device memory, like CuDF or CuPy.
  auto is_from_device =
      f_dmat && f_dmat->PageExists<SparsePage>() &&
      (*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
  auto is_from_device = f_dmat && f_dmat->PageExists<SparsePage>() &&
                        (*(f_dmat->GetBatches<SparsePage>().begin())).data.DeviceCanRead();
  auto on_device = is_ellpack || is_from_device;

  // Use GPU Predictor if data is already on device and gpu_id is set.
  if (on_device && ctx_->gpu_id >= 0) {
#if defined(XGBOOST_USE_CUDA)
    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
  if (on_device && ctx_->IsCUDA()) {
    common::AssertGPUSupport();
    CHECK(gpu_predictor_);
    return gpu_predictor_;
#else
    LOG(FATAL) << "Data is on CUDA device, but XGBoost is not compiled with "
                  "CUDA support.";
    return cpu_predictor_;
#endif  // defined(XGBOOST_USE_CUDA)
  }

  // GPU_Hist by default has prediction cache calculated from quantile values,
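The device-mismatch branch above is the heart of the new fallback. Reduced to runnable Python pseudocode (hypothetical names, for intuition only):

.. code-block:: python

    def inplace_predict(booster_device, data_device, predict_inplace, predict_via_dmatrix):
        # Mirrors GBTree::InplacePredict: mismatched devices warn and fall back.
        if booster_device != data_device:
            print(f"Warning: XGBoost is running on: {booster_device}, "
                  f"while the input data is on: {data_device}; falling back to DMatrix.")
            return predict_via_dmatrix()
        return predict_inplace()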
@@ -619,23 +616,19 @@ GBTree::GetPredictor

  if ((out_pred && out_pred->Size() == 0) && (model_.param.num_trees != 0) &&
      // FIXME(trivialfis): Implement a better method for testing whether data
      // is on device after DMatrix refactoring is done.
      !on_device) {
      !on_device && is_training) {
    CHECK(cpu_predictor_);
    return cpu_predictor_;
  }

  if (tparam_.tree_method == TreeMethod::kGPUHist) {
#if defined(XGBOOST_USE_CUDA)
    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
  if (ctx_->IsCPU()) {
    return cpu_predictor_;
  } else {
    common::AssertGPUSupport();
    CHECK(gpu_predictor_);
    return gpu_predictor_;
#else
    common::AssertGPUSupport();
    return cpu_predictor_;
#endif  // defined(XGBOOST_USE_CUDA)
  }

  CHECK(cpu_predictor_);
  return cpu_predictor_;
}
@@ -750,7 +743,7 @@ class Dart : public GBTree {

                        bool training, unsigned layer_begin,
                        unsigned layer_end) const {
    CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
    auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
    auto& predictor = this->GetPredictor(training, &p_out_preds->predictions, p_fmat);
    CHECK(predictor);
    predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
                                  model_);

@@ -814,49 +807,46 @@ class Dart : public GBTree {

    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    auto n_groups = model_.learner_model_param->num_output_group;

    std::vector<Predictor const*> predictors {
      cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
      gpu_predictor_.get()
#endif  // defined(XGBOOST_USE_CUDA)
    };
    Predictor const* predictor{nullptr};
    StringView msg{"Unsupported data type for inplace predict."};
    if (ctx_->Device() != p_fmat->Ctx()->Device()) {
      LOG(WARNING) << "Falling back to prediction using DMatrix due to mismatched devices. XGBoost "
                   << "is running on: " << this->ctx_->DeviceName()
                   << ", while the input data is on: " << p_fmat->Ctx()->DeviceName() << ".";
      auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_fmat);
      auto any_adapter = proxy->Adapter();
      auto p_fmat = data::CreateDMatrixFromProxy(ctx_, proxy, missing);
      this->PredictBatchImpl(p_fmat.get(), p_out_preds, false, layer_begin, layer_end);
      return;
    }

    StringView msg{"Unsupported data type for inplace predict."};
    PredictionCacheEntry predts;
    if (ctx_->gpu_id != Context::kCpuId) {
      predts.predictions.SetDevice(ctx_->gpu_id);
    }
    predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);

    auto get_predictor = [&]() -> Predictor const* {
      if (ctx_->IsCPU()) {
        return cpu_predictor_.get();
      } else if (ctx_->IsCUDA()) {
        CHECK(this->gpu_predictor_);
        return gpu_predictor_.get();
      } else {
        LOG(FATAL) << error::UnknownDevice();
        return nullptr;
      }
    };
    auto predict_impl = [&](size_t i) {
      predts.predictions.Fill(0);
      if (tparam_.predictor == PredictorType::kAuto) {
        // Try both predictor implementations
        bool success = false;
        for (auto const& p : predictors) {
          if (p && p->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)) {
            success = true;
            predictor = p;
            break;
          }
        }
        CHECK(success) << msg;
      } else {
        predictor = this->GetPredictor().get();
        bool success = predictor->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1);
        CHECK(success) << msg << std::endl
                       << "Current Predictor: "
                       << (tparam_.predictor == PredictorType::kCPUPredictor ? "cpu_predictor"
                                                                             : "gpu_predictor");
      }
      bool success{get_predictor()->InplacePredict(p_fmat, model_, missing, &predts, i, i + 1)};
      CHECK(success) << msg;
    };

    // Inplace predict is not used for training, so no need to drop tree.
    for (bst_tree_t i = tree_begin; i < tree_end; ++i) {
      predict_impl(i);
      if (i == tree_begin) {
        predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
        get_predictor()->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions, model_);
      }
      // Multiply by the tree weight
      auto w = this->weight_drop_.at(i);
@@ -886,25 +876,24 @@ class Dart : public GBTree {

                        std::vector<bst_float> *out_preds,
                        unsigned layer_begin, unsigned layer_end) override {
    DropTrees(false);
    auto &predictor = this->GetPredictor();
    auto &predictor = this->GetPredictor(false);
    uint32_t _, tree_end;
    std::tie(_, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
    predictor->PredictInstance(inst, out_preds, model_, tree_end);
  }

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           unsigned layer_begin, unsigned layer_end, bool approximate, int,
                           unsigned) override {
  void PredictContribution(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
                           bst_layer_t layer_begin, bst_layer_t layer_end,
                           bool approximate) override {
    CHECK(configured_);
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_, tree_end, &weight_drop_,
                                        approximate);
  }

  void PredictInteractionContributions(
      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
      unsigned layer_begin, unsigned layer_end, bool approximate) override {
  void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                                       bst_layer_t layer_begin, bst_layer_t layer_end,
                                       bool approximate) override {
    CHECK(configured_);
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_, tree_end,
@ -1,14 +1,11 @@
/*!
 * Copyright 2021 by Contributors
/**
 * Copyright 2021-2023, XGBoost Contributors
 */
#include "../common/device_helpers.cuh"
#include "xgboost/context.h"
#include "xgboost/linalg.h"
#include "xgboost/span.h"

namespace xgboost {
namespace gbm {

namespace xgboost::gbm {
void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
                     bst_group_t n_groups, bst_group_t group_id,
                     HostDeviceVector<GradientPair> *out_gpair) {

@ -41,5 +38,4 @@ void GPUDartInplacePredictInc(common::Span<float> out_predts, common::Span<float
    out_predts[offset] += (predts[offset] - base_score(0)) * tree_w;
  });
}
}  // namespace gbm
}  // namespace xgboost
}  // namespace xgboost::gbm
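For context, `GPUDartInplacePredictInc` scales each DART tree's raw contribution (its prediction minus the global base score) by the tree weight before accumulating. A scalar sketch of the same arithmetic, with illustrative names; the real version is a CUDA kernel over spans:

.. code-block:: cpp

    #include <cstddef>
    #include <vector>

    // Scalar equivalent of the accumulation in GPUDartInplacePredictInc:
    // out += (pred - base_score) * tree_weight, applied element-wise.
    // base_score and tree_w come from the trained model.
    void DartInplacePredictInc(std::vector<float> *out_predts,
                               std::vector<float> const &predts, float tree_w,
                               float base_score) {
      for (std::size_t i = 0; i < out_predts->size(); ++i) {
        (*out_predts)[i] += (predts[i] - base_score) * tree_w;
      }
    }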
102 src/gbm/gbtree.h
@ -43,18 +43,10 @@ enum class TreeProcessType : int {
  kDefault = 0,
  kUpdate = 1
};

enum class PredictorType : int {
  kAuto = 0,
  kCPUPredictor,
  kGPUPredictor,
  kOneAPIPredictor
};
}  // namespace xgboost

DECLARE_FIELD_ENUM_CLASS(xgboost::TreeMethod);
DECLARE_FIELD_ENUM_CLASS(xgboost::TreeProcessType);
DECLARE_FIELD_ENUM_CLASS(xgboost::PredictorType);

namespace xgboost::gbm {
/*! \brief training parameters */

@ -63,8 +55,6 @@ struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
  std::string updater_seq;
  /*! \brief type of boosting process to run */
  TreeProcessType process_type;
  // predictor type
  PredictorType predictor;
  // tree construction method
  TreeMethod tree_method;
  // declare parameters

@ -79,13 +69,6 @@ struct GBTreeTrainParam : public XGBoostParameter<GBTreeTrainParam> {
        .describe("Whether to run the normal boosting process that creates new trees,"\
                  " or to update the trees in an existing model.");
    DMLC_DECLARE_ALIAS(updater_seq, updater);
    DMLC_DECLARE_FIELD(predictor)
        .set_default(PredictorType::kAuto)
        .add_enum("auto", PredictorType::kAuto)
        .add_enum("cpu_predictor", PredictorType::kCPUPredictor)
        .add_enum("gpu_predictor", PredictorType::kGPUPredictor)
        .add_enum("oneapi_predictor", PredictorType::kOneAPIPredictor)
        .describe("Predictor algorithm type");
    DMLC_DECLARE_FIELD(tree_method)
        .set_default(TreeMethod::kAuto)
        .add_enum("auto", TreeMethod::kAuto)
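With the `predictor` field gone from `GBTreeTrainParam`, device placement is driven entirely by `gpu_id` and `tree_method`. A minimal sketch of configuring a learner after this change, using only `SetParam` calls that appear elsewhere in this patch; the helper name is illustrative:

.. code-block:: cpp

    #include <xgboost/learner.h>  // for xgboost::Learner

    // Sketch: there is no "predictor" knob anymore. The device follows
    // gpu_id/tree_method, and prediction falls back to a DMatrix when the
    // in-place path is unavailable.
    void ConfigureForGPU(xgboost::Learner *learner) {
      learner->SetParam("tree_method", "gpu_hist");
      learner->SetParam("gpu_id", "0");
      learner->Configure();
    }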
@ -206,15 +189,9 @@ class GBTree : public GradientBooster {
  void DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
               PredictionCacheEntry* predt, ObjFunction const* obj) override;

  bool UseGPU() const override {
    return
        tparam_.predictor == PredictorType::kGPUPredictor ||
        tparam_.tree_method == TreeMethod::kGPUHist;
  }
  [[nodiscard]] bool UseGPU() const override { return tparam_.tree_method == TreeMethod::kGPUHist; }

  GBTreeTrainParam const& GetTrainParam() const {
    return tparam_;
  }
  [[nodiscard]] GBTreeTrainParam const& GetTrainParam() const { return tparam_; }

  void Load(dmlc::Stream* fi) override { model_.Load(fi); }
  void Save(dmlc::Stream* fo) const override {

@ -236,39 +213,14 @@ class GBTree : public GradientBooster {
    return !model_.trees.empty() || !model_.trees_to_update.empty();
  }

  void PredictBatchImpl(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool is_training,
                        bst_layer_t layer_begin, bst_layer_t layer_end) const;

  void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training,
                    bst_layer_t layer_begin, bst_layer_t layer_end) override;

  void InplacePredict(std::shared_ptr<DMatrix> p_m, float missing, PredictionCacheEntry* out_preds,
                      bst_layer_t layer_begin, bst_layer_t layer_end) const override {
    CHECK(configured_);
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
    std::vector<Predictor const *> predictors{
        cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
        gpu_predictor_.get()
#endif  // defined(XGBOOST_USE_CUDA)
    };
    StringView msg{"Unsupported data type for inplace predict."};
    if (tparam_.predictor == PredictorType::kAuto) {
      // Try both predictor implementations
      for (auto const &p : predictors) {
        if (p && p->InplacePredict(p_m, model_, missing, out_preds, tree_begin, tree_end)) {
          return;
        }
      }
      LOG(FATAL) << msg;
    } else {
      bool success = this->GetPredictor()->InplacePredict(p_m, model_, missing, out_preds,
                                                          tree_begin, tree_end);
      CHECK(success) << msg << std::endl
                     << "Current Predictor: "
                     << (tparam_.predictor == PredictorType::kCPUPredictor
                             ? "cpu_predictor"
                             : "gpu_predictor");
    }
  }
                      bst_layer_t layer_begin, bst_layer_t layer_end) const override;

  void FeatureScore(std::string const& importance_type, common::Span<int32_t const> trees,
                    std::vector<bst_feature_t>* features,

@ -349,32 +301,29 @@ class GBTree : public GradientBooster {
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0) << "Predict leaf supports only iteration end: (0, "
                               "n_iteration), use model slicing instead.";
    this->GetPredictor()->PredictLeaf(p_fmat, out_preds, model_, tree_end);
    this->GetPredictor(false)->PredictLeaf(p_fmat, out_preds, model_, tree_end);
  }

  void PredictContribution(DMatrix* p_fmat,
                           HostDeviceVector<bst_float>* out_contribs,
                           uint32_t layer_begin, uint32_t layer_end, bool approximate,
                           int, unsigned) override {
  void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                           bst_layer_t layer_begin, bst_layer_t layer_end,
                           bool approximate) override {
    CHECK(configured_);
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0)
        << "Predict contribution supports only iteration end: (0, "
           "n_iteration), using model slicing instead.";
    this->GetPredictor()->PredictContribution(
        p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
    CHECK_EQ(tree_begin, 0) << "Predict contribution supports only iteration end: (0, "
                               "n_iteration), using model slicing instead.";
    this->GetPredictor(false)->PredictContribution(p_fmat, out_contribs, model_, tree_end, nullptr,
                                                   approximate);
  }

  void PredictInteractionContributions(
      DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
      uint32_t layer_begin, uint32_t layer_end, bool approximate) override {
  void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
                                       bst_layer_t layer_begin, bst_layer_t layer_end,
                                       bool approximate) override {
    CHECK(configured_);
    auto [tree_begin, tree_end] = detail::LayerToTree(model_, layer_begin, layer_end);
    CHECK_EQ(tree_begin, 0)
        << "Predict interaction contribution supports only iteration end: (0, "
           "n_iteration), using model slicing instead.";
    this->GetPredictor()->PredictInteractionContributions(
        p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
    CHECK_EQ(tree_begin, 0) << "Predict interaction contribution supports only iteration end: (0, "
                               "n_iteration), using model slicing instead.";
    this->GetPredictor(false)->PredictInteractionContributions(p_fmat, out_contribs, model_,
                                                               tree_end, nullptr, approximate);
  }

  [[nodiscard]] std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,

@ -390,8 +339,9 @@ class GBTree : public GradientBooster {
                 std::vector<HostDeviceVector<bst_node_t>>* out_position,
                 std::vector<std::unique_ptr<RegTree>>* ret);

  std::unique_ptr<Predictor> const& GetPredictor(HostDeviceVector<float> const* out_pred = nullptr,
                                                 DMatrix* f_dmat = nullptr) const;
  [[nodiscard]] std::unique_ptr<Predictor> const& GetPredictor(
      bool is_training, HostDeviceVector<float> const* out_pred = nullptr,
      DMatrix* f_dmat = nullptr) const;

  // commit new trees all at once
  virtual void CommitModel(TreesOneIter&& new_trees);

@ -410,9 +360,7 @@ class GBTree : public GradientBooster {
  std::vector<std::unique_ptr<TreeUpdater>> updaters_;
  // Predictors
  std::unique_ptr<Predictor> cpu_predictor_;
#if defined(XGBOOST_USE_CUDA)
  std::unique_ptr<Predictor> gpu_predictor_;
#endif  // defined(XGBOOST_USE_CUDA)
  std::unique_ptr<Predictor> gpu_predictor_{nullptr};
#if defined(XGBOOST_USE_ONEAPI)
  std::unique_ptr<Predictor> oneapi_predictor_;
#endif  // defined(XGBOOST_USE_ONEAPI)
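The `is_training` flag added to `GetPredictor` confines the predictor-choice heuristic to training, as the commit message states. Below is a simplified sketch of what that gating might look like; the real logic lives in `GBTree::GetPredictor` and considers more state (cached predictions, the data matrix) than shown here:

.. code-block:: cpp

    // Illustrative only: outside training, the predictor simply follows the
    // configured device; during training, data placement is also considered so
    // training data is not copied to the GPU merely for prediction.
    Predictor const *ChoosePredictor(bool is_training, bool device_is_gpu,
                                     bool data_on_device, Predictor const *cpu,
                                     Predictor const *gpu) {
      if (!is_training) {
        return (device_is_gpu && gpu != nullptr) ? gpu : cpu;
      }
      return (device_is_gpu && data_on_device && gpu != nullptr) ? gpu : cpu;
    }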
@ -40,6 +40,7 @@
#include "common/api_entry.h"  // for XGBAPIThreadLocalEntry
#include "common/charconv.h"   // for to_chars, to_chars_result, NumericLimits, from_...
#include "common/common.h"     // for ToString, Split
#include "common/error_msg.h"  // for MaxFeatureSize
#include "common/io.h"         // for PeekableInStream, ReadAll, FixedSizeStream, Mem...
#include "common/observer.h"   // for TrainingObserver
#include "common/random.h"     // for GlobalRandom

@ -763,9 +764,7 @@ class LearnerConfiguration : public Learner {
      CHECK(matrix.first.ptr);
      CHECK(!matrix.second.ref.expired());
      const uint64_t num_col = matrix.first.ptr->Info().num_col_;
      CHECK_LE(num_col, static_cast<uint64_t>(std::numeric_limits<unsigned>::max()))
          << "Unfortunately, XGBoost does not support data matrices with "
          << std::numeric_limits<unsigned>::max() << " features or greater";
      error::MaxFeatureSize(num_col);
      num_feature = std::max(num_feature, static_cast<uint32_t>(num_col));
    }

@ -1413,6 +1412,8 @@ class LearnerImpl : public LearnerIO {
    this->CheckModelInitialized();

    auto& out_predictions = this->GetThreadLocal().prediction_entry;
    out_predictions.version = 0;

    this->gbm_->InplacePredict(p_m, missing, &out_predictions, iteration_begin, iteration_end);
    if (type == PredictionType::kValue) {
      obj_->PredTransform(&out_predictions.predictions);
@ -577,8 +577,8 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
      if (lj(0) >= Eps64()) {
        tj_minus(i) = std::pow(lj(i) / lj(0), regularizer);
      }
      assert(!std::isinf(ti_plus(i)));
      assert(!std::isinf(tj_minus(i)));
      assert(!isinf(ti_plus(i)));
      assert(!isinf(tj_minus(i)));
    });
}
}  // namespace cuda_impl
@ -883,9 +883,8 @@ class CPUPredictor : public Predictor {
    for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
      auto page = batch.GetView();
      // parallel over local batch
      const auto nsize = static_cast<bst_omp_uint>(batch.Size());
      common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
        auto row_idx = static_cast<size_t>(batch.base_rowid + i);
      common::ParallelFor(batch.Size(), n_threads, [&](auto i) {
        auto row_idx = batch.base_rowid + i;
        RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
        if (feats.Size() == 0) {
          feats.Init(num_feature);
@ -226,9 +226,7 @@ struct GPUHistMakerDevice {
    monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id));
  }

  ~GPUHistMakerDevice() {  // NOLINT
    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
  }
  ~GPUHistMakerDevice() = default;

  void InitFeatureGroupsOnce() {
    if (!feature_groups) {
@ -25,6 +25,9 @@ class LintersPaths:
        "tests/python/test_tree_regularization.py",
        "tests/python/test_shap.py",
        "tests/python-gpu/test_gpu_data_iterator.py",
        "tests/python-gpu/test_gpu_prediction.py",
        "tests/python-gpu/load_pickle.py",
        "tests/python-gpu/test_gpu_pickling.py",
        "tests/test_distributed/test_with_spark/",
        "tests/test_distributed/test_gpu_with_spark/",
        # demo

@ -68,6 +71,7 @@ class LintersPaths:
        "tests/python/test_dt.py",
        "tests/python/test_data_iterator.py",
        "tests/python-gpu/test_gpu_data_iterator.py",
        "tests/python-gpu/load_pickle.py",
        "tests/test_distributed/test_with_spark/test_data.py",
        "tests/test_distributed/test_gpu_with_spark/test_data.py",
        "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
@ -41,7 +41,6 @@ std::string GetModelStr() {
      "num_class": "0",
      "num_feature": "10",
      "objective": "reg:linear",
      "predictor": "gpu_predictor",
      "tree_method": "gpu_hist",
      "updater": "grow_gpu_hist"
    },
@ -1,17 +1,20 @@
/*!
 * Copyright 2019-2022 XGBoost contributors
/**
 * Copyright 2019-2023, XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/context.h>
#include <xgboost/host_device_vector.h>  // for HostDeviceVector
#include <xgboost/learner.h>             // for Learner

#include "../../../src/data/adapter.h"
#include "../../../src/data/proxy_dmatrix.h"
#include <limits>  // for numeric_limits
#include <memory>  // for shared_ptr
#include <string>  // for string

#include "../../../src/data/proxy_dmatrix.h"  // for DMatrixProxy
#include "../../../src/gbm/gbtree.h"
#include "../filesystem.h"  // dmlc::TemporaryDirectory
#include "../helpers.h"
#include "xgboost/base.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
#include "xgboost/predictor.h"

namespace xgboost {
@ -113,12 +116,11 @@ TEST(GBTree, WrongUpdater) {
#ifdef XGBOOST_USE_CUDA
TEST(GBTree, ChoosePredictor) {
  // The test ensures data doesn't get pulled into the device.
  size_t constexpr kRows = 17;
  size_t constexpr kCols = 15;
  std::size_t constexpr kRows = 17, kCols = 15;

  auto p_dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();

  auto& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
  auto const& data = (*(p_dmat->GetBatches<SparsePage>().begin())).data;
  p_dmat->Info().labels.Reshape(kRows);

  auto learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));

@ -127,14 +129,13 @@ TEST(GBTree, ChoosePredictor) {
    learner->UpdateOneIter(i, p_dmat);
  }
  ASSERT_TRUE(data.HostCanWrite());

  dmlc::TemporaryDirectory tempdir;
  const std::string fname = tempdir.path + "/model_param.bst";

  {
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
    learner->Save(fo.get());
  }

  // a new learner
  learner = std::unique_ptr<Learner>(Learner::Create({p_dmat}));
  {

@ -146,6 +147,8 @@ TEST(GBTree, ChoosePredictor) {
    learner->UpdateOneIter(i, p_dmat);
  }
  ASSERT_TRUE(data.HostCanWrite());
  ASSERT_FALSE(data.DeviceCanWrite());
  ASSERT_FALSE(data.DeviceCanRead());

  // pull data into device.
  data.HostVector();
@ -232,14 +235,15 @@ TEST(Dart, JsonIO) {
namespace {
class Dart : public testing::TestWithParam<char const*> {
 public:
  void Run(std::string predictor) {
  void Run(std::string device) {
    size_t constexpr kRows = 16, kCols = 10;

    HostDeviceVector<float> data;
    auto rng = RandomDataGenerator(kRows, kCols, 0);
    if (predictor == "gpu_predictor") {
      rng.Device(0);
    Context ctx;
    if (device == "GPU") {
      ctx = MakeCUDACtx(0);
    }
    auto rng = RandomDataGenerator(kRows, kCols, 0).Device(ctx.gpu_id);
    auto array_str = rng.GenerateArrayInterface(&data);
    auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols);

@ -258,14 +262,14 @@ class Dart : public testing::TestWithParam<char const*> {
      learner->UpdateOneIter(i, p_mat);
    }

    learner->SetParam("predictor", predictor);
    ConfigLearnerByCtx(&ctx, learner.get());

    HostDeviceVector<float> predts_training;
    learner->Predict(p_mat, false, &predts_training, 0, 0, true);

    HostDeviceVector<float>* inplace_predts;
    std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy{}};
    if (predictor == "gpu_predictor") {
    if (ctx.IsCUDA()) {
      x->SetCUDAArray(array_str.c_str());
    } else {
      x->SetArrayData(array_str.c_str());

@ -295,10 +299,9 @@ class Dart : public testing::TestWithParam<char const*> {
TEST_P(Dart, Prediction) { this->Run(GetParam()); }

#if defined(XGBOOST_USE_CUDA)
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart,
                         testing::Values("auto", "cpu_predictor", "gpu_predictor"));
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU", "GPU"));
#else
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("auto", "cpu_predictor"));
INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU"));
#endif  // defined(XGBOOST_USE_CUDA)
88 tests/cpp/gbm/test_gbtree.cu
@ -0,0 +1,88 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#include <xgboost/context.h>      // for Context
#include <xgboost/learner.h>      // for Learner
#include <xgboost/string_view.h>  // for StringView

#include <limits>  // for numeric_limits
#include <memory>  // for shared_ptr
#include <string>  // for string

#include "../../../src/data/adapter.h"           // for ArrayAdapter
#include "../../../src/data/device_adapter.cuh"  // for CupyAdapter
#include "../../../src/data/proxy_dmatrix.h"     // for DMatrixProxy
#include "../helpers.h"                          // for RandomDataGenerator

namespace xgboost {
void TestInplaceFallback(Context const* ctx) {
  // prepare data
  bst_row_t n_samples{1024};
  bst_feature_t n_features{32};
  HostDeviceVector<float> X_storage;
  // use a different device than the learner
  std::int32_t data_ordinal = ctx->IsCPU() ? 0 : -1;
  auto X = RandomDataGenerator{n_samples, n_features, 0.0}
               .Device(data_ordinal)
               .GenerateArrayInterface(&X_storage);
  HostDeviceVector<float> y_storage;
  auto y = RandomDataGenerator{n_samples, 1u, 0.0}.GenerateArrayInterface(&y_storage);

  std::shared_ptr<DMatrix> Xy;
  if (data_ordinal == Context::kCpuId) {
    auto X_adapter = data::ArrayAdapter{StringView{X}};
    Xy.reset(DMatrix::Create(&X_adapter, std::numeric_limits<float>::quiet_NaN(), ctx->Threads()));
  } else {
    auto X_adapter = data::CupyAdapter{StringView{X}};
    Xy.reset(DMatrix::Create(&X_adapter, std::numeric_limits<float>::quiet_NaN(), ctx->Threads()));
  }

  Xy->SetInfo("label", y);

  // learner is configured to the device specified by ctx
  std::unique_ptr<Learner> learner{Learner::Create({Xy})};
  ConfigLearnerByCtx(ctx, learner.get());
  for (std::int32_t i = 0; i < 3; ++i) {
    learner->UpdateOneIter(i, Xy);
  }

  std::shared_ptr<DMatrix> p_m{new data::DMatrixProxy};
  auto proxy = std::dynamic_pointer_cast<data::DMatrixProxy>(p_m);
  if (data_ordinal == Context::kCpuId) {
    proxy->SetArrayData(StringView{X});
  } else {
    proxy->SetCUDAArray(X.c_str());
  }

  HostDeviceVector<float>* out_predt{nullptr};
  ConsoleLogger::Configure(Args{{"verbosity", "1"}});
  // test whether the warning is raised
  ::testing::internal::CaptureStderr();
  learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
                          &out_predt, 0, 0);
  auto output = testing::internal::GetCapturedStderr();
  std::cout << "output:" << output << std::endl;
  ASSERT_NE(output.find("Falling back"), std::string::npos);

  // test when the contexts match
  Context new_ctx = *proxy->Ctx();
  ASSERT_NE(new_ctx.gpu_id, ctx->gpu_id);

  ConfigLearnerByCtx(&new_ctx, learner.get());
  HostDeviceVector<float>* out_predt_1{nullptr};
  // no warning is raised
  ::testing::internal::CaptureStderr();
  learner->InplacePredict(p_m, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
                          &out_predt_1, 0, 0);
  output = testing::internal::GetCapturedStderr();

  ASSERT_TRUE(output.empty());

  ASSERT_EQ(out_predt->ConstHostVector(), out_predt_1->ConstHostVector());
}

TEST(GBTree, InplacePredictFallback) {
  auto ctx = MakeCUDACtx(0);
  TestInplaceFallback(&ctx);
}
}  // namespace xgboost
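The new test pins down the user-visible contract: when the input's device differs from the learner's, `InplacePredict` emits a "Falling back" warning and routes through a `DMatrix` instead of failing; once the devices agree, no warning is produced and the results match exactly. The triggering call, excerpted (not standalone) from the test above:

.. code-block:: cpp

    // p_m holds data on a device that differs from the learner's; the call
    // succeeds anyway, logging "Falling back" instead of raising an error.
    HostDeviceVector<float>* out_predt{nullptr};
    learner->InplacePredict(p_m, PredictionType::kValue,
                            std::numeric_limits<float>::quiet_NaN(), &out_predt,
                            /*iteration_begin=*/0, /*iteration_end=*/0);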
@ -395,6 +395,9 @@ std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(bool with_label, b
    for (auto const& page : out->GetBatches<SparsePage>()) {
      page.data.SetDevice(device_);
      page.offset.SetDevice(device_);
      // pull to device
      page.data.ConstDeviceSpan();
      page.offset.ConstDeviceSpan();
    }
  }
  if (!ft_.empty()) {
@ -183,7 +183,7 @@ class SimpleRealUniformDistribution {

    for (size_t k = m; k != 0; --k) {
      sum_value += static_cast<ResultT>((*rng)() - rng->Min()) * r_k;
      r_k *= r;
      r_k *= static_cast<ResultT>(r);
    }

    ResultT res = sum_value / r_k;
@ -322,15 +322,14 @@ inline std::shared_ptr<DMatrix> EmptyDMatrix() {
  return RandomDataGenerator{0, 0, 0.0}.GenerateDMatrix();
}

inline std::vector<float>
GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
  std::vector<float> x(n);
  std::mt19937 rng(0);
  std::uniform_int_distribution<size_t> dist(0, num_categories - 1);
  std::generate(x.begin(), x.end(), [&]() { return dist(rng); });
  // Make sure each category is present
  for(size_t i = 0; i < num_categories; i++) {
    x[i] = i;
  for (size_t i = 0; i < num_categories; i++) {
    x[i] = static_cast<decltype(x)::value_type>(i);
  }
  return x;
}

@ -549,4 +548,15 @@ class DeclareUnifiedDistributedTest(MetricTest) : public ::testing::Test {
  }
};

// A temporary solution before we move away from gpu_id.
inline void ConfigLearnerByCtx(Context const* ctx, Learner* learner) {
  if (ctx->IsCPU()) {
    learner->SetParam("tree_method", "hist");
  } else {
    learner->SetParam("tree_method", "gpu_hist");
  }
  learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
  learner->Configure();
  ASSERT_EQ(learner->Ctx()->gpu_id, ctx->gpu_id);
}
}  // namespace xgboost
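`ConfigLearnerByCtx` is the test-suite stand-in for the removed `predictor` parameter: every test that used to set `predictor` now builds a `Context` and routes it through this helper. Typical usage, mirroring the call sites later in the patch:

.. code-block:: cpp

    // Configure a learner for the GPU via a context instead of a predictor name.
    auto ctx = MakeCUDACtx(0);
    std::unique_ptr<Learner> learner{Learner::Create({p_dmat})};
    ConfigLearnerByCtx(&ctx, learner.get());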
@ -122,11 +122,13 @@ TEST(CpuPredictor, BasicColumnSplit) {
}

TEST(CpuPredictor, IterationRange) {
  TestIterationRange("cpu_predictor");
  Context ctx;
  TestIterationRange(&ctx);
}

TEST(CpuPredictor, IterationRangeColmnSplit) {
  TestIterationRangeColumnSplit("cpu_predictor");
  Context ctx;
  TestIterationRangeColumnSplit(&ctx);
}

TEST(CpuPredictor, ExternalMemory) {

@ -139,7 +141,8 @@ TEST(CpuPredictor, ExternalMemory) {
TEST(CpuPredictor, InplacePredict) {
  bst_row_t constexpr kRows{128};
  bst_feature_t constexpr kCols{64};
  auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(-1);
  Context ctx;
  auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(ctx.gpu_id);
  {
    HostDeviceVector<float> data;
    gen.GenerateDense(&data);

@ -149,7 +152,7 @@ TEST(CpuPredictor, InplacePredict) {
    std::string arr_str;
    Json::Dump(array_interface, &arr_str);
    x->SetArrayData(arr_str.data());
    TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
    TestInplacePrediction(&ctx, x, kRows, kCols);
  }

  {

@ -166,50 +169,50 @@ TEST(CpuPredictor, InplacePredict) {
    Json::Dump(col_interface, &col_str);
    std::shared_ptr<data::DMatrixProxy> x{new data::DMatrixProxy};
    x->SetCSRData(rptr_str.data(), col_str.data(), data_str.data(), kCols, true);
    TestInplacePrediction(x, "cpu_predictor", kRows, kCols, Context::kCpuId);
    TestInplacePrediction(&ctx, x, kRows, kCols);
  }
}

namespace {
void TestUpdatePredictionCache(bool use_subsampling) {
  size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
  std::size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
  LearnerModelParam mparam{MakeMP(kCols, .0, kClasses)};
  Context ctx;

  std::unique_ptr<gbm::GBTree> gbm;
  gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &ctx, &mparam)));
  std::map<std::string, std::string> cfg;
  cfg["tree_method"] = "hist";
  cfg["predictor"] = "cpu_predictor";
  Args args{{"tree_method", "hist"}};
  if (use_subsampling) {
    cfg["subsample"] = "0.5";
    args.emplace_back("subsample", "0.5");
  }
  Args args = {cfg.cbegin(), cfg.cend()};
  gbm->Configure(args);

  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);

  HostDeviceVector<GradientPair> gpair;
  auto& h_gpair = gpair.HostVector();
  h_gpair.resize(kRows*kClasses);
  for (size_t i = 0; i < kRows*kClasses; ++i) {
  h_gpair.resize(kRows * kClasses);
  for (size_t i = 0; i < kRows * kClasses; ++i) {
    h_gpair[i] = {static_cast<float>(i), 1};
  }

  PredictionCacheEntry predtion_cache;
  predtion_cache.predictions.Resize(kRows*kClasses, 0);
  // after one training iteration predtion_cache is filled with cached in QuantileHistMaker::Builder prediction values
  predtion_cache.predictions.Resize(kRows * kClasses, 0);
  // after one training iteration predtion_cache is filled with the prediction
  // values cached in QuantileHistMaker
  gbm->DoBoost(dmat.get(), &gpair, &predtion_cache, nullptr);

  PredictionCacheEntry out_predictions;
  // perform fair prediction on the same input data, should be equal to cached result
  // perform prediction from scratch on the same input data, should be equal to cached result
  gbm->PredictBatch(dmat.get(), &out_predictions, false, 0, 0);

  std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
  std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();
  std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
  std::vector<float>& predtion_cache_from_train = predtion_cache.predictions.HostVector();
  for (size_t i = 0; i < out_predictions_h.size(); ++i) {
    ASSERT_NEAR(out_predictions_h[i], predtion_cache_from_train[i], kRtEps);
  }
}
}  // namespace

TEST(CPUPredictor, GHistIndex) {
  size_t constexpr kRows{128}, kCols{16}, kBins{64};

@ -223,19 +226,23 @@ TEST(CPUPredictor, GHistIndex) {
}

TEST(CPUPredictor, CategoricalPrediction) {
  TestCategoricalPrediction("cpu_predictor");
  Context ctx;
  TestCategoricalPrediction(&ctx, false);
}

TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
  TestCategoricalPredictionColumnSplit("cpu_predictor");
  Context ctx;
  TestCategoricalPredictionColumnSplit(&ctx);
}

TEST(CPUPredictor, CategoricalPredictLeaf) {
  TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
  Context ctx;
  TestCategoricalPredictLeaf(&ctx, false);
}

TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
  TestCategoricalPredictLeafColumnSplit(StringView{"cpu_predictor"});
  Context ctx;
  TestCategoricalPredictLeafColumnSplit(&ctx);
}

TEST(CpuPredictor, UpdatePredictionCache) {

@ -244,21 +251,25 @@ TEST(CpuPredictor, UpdatePredictionCache) {
}

TEST(CpuPredictor, LesserFeatures) {
  TestPredictionWithLesserFeatures("cpu_predictor");
  Context ctx;
  TestPredictionWithLesserFeatures(&ctx);
}

TEST(CpuPredictor, LesserFeaturesColumnSplit) {
  TestPredictionWithLesserFeaturesColumnSplit("cpu_predictor");
  Context ctx;
  TestPredictionWithLesserFeaturesColumnSplit(&ctx);
}

TEST(CpuPredictor, Sparse) {
  TestSparsePrediction(0.2, "cpu_predictor");
  TestSparsePrediction(0.8, "cpu_predictor");
  Context ctx;
  TestSparsePrediction(&ctx, 0.2);
  TestSparsePrediction(&ctx, 0.8);
}

TEST(CpuPredictor, SparseColumnSplit) {
  TestSparsePredictionColumnSplit(0.2, "cpu_predictor");
  TestSparsePredictionColumnSplit(0.8, "cpu_predictor");
  Context ctx;
  TestSparsePredictionColumnSplit(&ctx, 0.2);
  TestSparsePredictionColumnSplit(&ctx, 0.8);
}

TEST(CpuPredictor, Multi) {

@ -266,4 +277,6 @@ TEST(CpuPredictor, Multi) {
  ctx.nthread = 1;
  TestVectorLeafPrediction(&ctx);
}

TEST(CpuPredictor, Access) { TestPredictionDeviceAccess(); }
}  // namespace xgboost
@ -15,8 +15,7 @@
#include "../helpers.h"
#include "test_predictor.h"

namespace xgboost {
namespace predictor {
namespace xgboost::predictor {

TEST(GPUPredictor, Basic) {
  auto cpu_lparam = MakeCUDACtx(-1);

@ -120,13 +119,14 @@ TEST(GPUPredictor, MGPUBasicColumnSplit) {
}

TEST(GPUPredictor, EllpackBasic) {
  size_t constexpr kCols {8};
  size_t constexpr kCols{8};
  auto ctx = MakeCUDACtx(0);
  for (size_t bins = 2; bins < 258; bins += 16) {
    size_t rows = bins * 16;
    auto p_m = RandomDataGenerator{rows, kCols, 0.0}.Bins(bins).Device(0).GenerateDeviceDMatrix();
    ASSERT_FALSE(p_m->PageExists<SparsePage>());
    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, kCols, p_m);
    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, kCols, p_m);
    TestPredictionFromGradientIndex<EllpackPage>(&ctx, rows, kCols, p_m);
    TestPredictionFromGradientIndex<EllpackPage>(&ctx, bins, kCols, p_m);
  }
}

@ -181,29 +181,32 @@ TEST(GPUPredictor, ExternalMemoryTest) {
}

TEST(GPUPredictor, InplacePredictCupy) {
  auto ctx = MakeCUDACtx(0);
  size_t constexpr kRows{128}, kCols{64};
  RandomDataGenerator gen(kRows, kCols, 0.5);
  gen.Device(0);
  gen.Device(ctx.gpu_id);
  HostDeviceVector<float> data;
  std::string interface_str = gen.GenerateArrayInterface(&data);
  std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
  dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
  TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
  TestInplacePrediction(&ctx, p_fmat, kRows, kCols);
}

TEST(GPUPredictor, InplacePredictCuDF) {
  auto ctx = MakeCUDACtx(0);
  size_t constexpr kRows{128}, kCols{64};
  RandomDataGenerator gen(kRows, kCols, 0.5);
  gen.Device(0);
  gen.Device(ctx.gpu_id);
  std::vector<HostDeviceVector<float>> storage(kCols);
  auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
  std::shared_ptr<DMatrix> p_fmat{new data::DMatrixProxy};
  dynamic_cast<data::DMatrixProxy*>(p_fmat.get())->SetCUDAArray(interface_str.c_str());
  TestInplacePrediction(p_fmat, "gpu_predictor", kRows, kCols, 0);
  TestInplacePrediction(&ctx, p_fmat, kRows, kCols);
}

TEST(GpuPredictor, LesserFeatures) {
  TestPredictionWithLesserFeatures("gpu_predictor");
  auto ctx = MakeCUDACtx(0);
  TestPredictionWithLesserFeatures(&ctx);
}

// Very basic test of empty model

@ -268,15 +271,18 @@ TEST(GPUPredictor, Shap) {
}

TEST(GPUPredictor, IterationRange) {
  TestIterationRange("gpu_predictor");
  auto ctx = MakeCUDACtx(0);
  TestIterationRange(&ctx);
}

TEST(GPUPredictor, CategoricalPrediction) {
  TestCategoricalPrediction("gpu_predictor");
  auto ctx = MakeCUDACtx(0);
  TestCategoricalPrediction(&ctx, false);
}

TEST(GPUPredictor, CategoricalPredictLeaf) {
  TestCategoricalPredictLeaf(StringView{"gpu_predictor"});
  auto ctx = MakeCUDACtx(0);
  TestCategoricalPredictLeaf(&ctx, false);
}

TEST(GPUPredictor, PredictLeafBasic) {

@ -300,8 +306,8 @@ TEST(GPUPredictor, PredictLeafBasic) {
}

TEST(GPUPredictor, Sparse) {
  TestSparsePrediction(0.2, "gpu_predictor");
  TestSparsePrediction(0.8, "gpu_predictor");
  auto ctx = MakeCUDACtx(0);
  TestSparsePrediction(&ctx, 0.2);
  TestSparsePrediction(&ctx, 0.8);
}
}  // namespace predictor
}  // namespace xgboost
}  // namespace xgboost::predictor
@ -8,9 +8,11 @@
#include <xgboost/data.h>                // for DMatrix, BatchIterator, BatchSet, MetaInfo
#include <xgboost/host_device_vector.h>  // for HostDeviceVector
#include <xgboost/predictor.h>           // for PredictionCacheEntry, Predictor, Predic...
#include <xgboost/string_view.h>         // for StringView

#include <algorithm>      // for max
#include <limits>         // for numeric_limits
#include <memory>         // for shared_ptr
#include <unordered_map>  // for unordered_map

#include "../../../src/common/bitfield.h"  // for LBitField32

@ -51,7 +53,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
  size_t constexpr kIters = 3;

  std::unique_ptr<Learner> learner;
  auto train = [&](std::string predictor) {
  auto train = [&](Context const& ctx) {
    p_hist->Info().labels.Reshape(rows, 1);
    auto &h_label = p_hist->Info().labels.Data()->HostVector();

@ -65,7 +67,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,
    learner->SetParam("num_feature", std::to_string(kCols));
    learner->SetParam("num_class", std::to_string(kClasses));
    learner->SetParam("max_bin", std::to_string(bins));
    learner->SetParam("predictor", predictor);
    ConfigLearnerByCtx(&ctx, learner.get());
    learner->Configure();

    for (size_t i = 0; i < kIters; ++i) {

@ -77,7 +79,7 @@ void TestTrainingPrediction(size_t rows, size_t bins,

    learner.reset(Learner::Create({}));
    learner->LoadModel(model);
    learner->SetParam("predictor", predictor);
    ConfigLearnerByCtx(&ctx, learner.get());
    learner->Configure();

    HostDeviceVector<float> from_full;

@ -93,16 +95,16 @@ void TestTrainingPrediction(size_t rows, size_t bins,
  };

  if (tree_method == "gpu_hist") {
    train("gpu_predictor");
    train(MakeCUDACtx(0));
  } else {
    train("cpu_predictor");
    train(Context{});
  }
}
void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
                           bst_feature_t cols, int32_t device) {
  size_t constexpr kClasses { 4 };
  auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(device);
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
                           bst_feature_t cols) {
  std::size_t constexpr kClasses { 4 };
  auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->gpu_id);
  std::shared_ptr<DMatrix> m = gen.GenerateDMatrix(true, false, kClasses);

  std::unique_ptr<Learner> learner {

@ -113,12 +115,14 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
  learner->SetParam("num_class", std::to_string(kClasses));
  learner->SetParam("seed", "0");
  learner->SetParam("subsample", "0.5");
  learner->SetParam("gpu_id", std::to_string(device));
  learner->SetParam("predictor", predictor);
  learner->SetParam("tree_method", "hist");
  for (int32_t it = 0; it < 4; ++it) {
    learner->UpdateOneIter(it, m);
  }

  learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
  learner->Configure();

  HostDeviceVector<float> *p_out_predictions_0{nullptr};
  learner->InplacePredict(x, PredictionType::kMargin, std::numeric_limits<float>::quiet_NaN(),
                          &p_out_predictions_0, 0, 2);
@ -154,40 +158,79 @@ void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bs
}

namespace {
std::unique_ptr<Learner> LearnerForTest(std::shared_ptr<DMatrix> dmat, size_t iters,
                                        size_t forest = 1) {
std::unique_ptr<Learner> LearnerForTest(Context const *ctx, std::shared_ptr<DMatrix> dmat,
                                        size_t iters, size_t forest = 1) {
  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
  learner->SetParams(Args{{"num_parallel_tree", std::to_string(forest)}});
  for (size_t i = 0; i < iters; ++i) {
    learner->UpdateOneIter(i, dmat);
  }

  ConfigLearnerByCtx(ctx, learner.get());
  return learner;
}

void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &predictor_name,
                                        size_t rows, std::shared_ptr<DMatrix> const &m_test,
                                        std::shared_ptr<DMatrix> const &m_invalid) {
void VerifyPredictionWithLesserFeatures(Learner *learner, bst_row_t kRows,
                                        std::shared_ptr<DMatrix> m_test,
                                        std::shared_ptr<DMatrix> m_invalid) {
  HostDeviceVector<float> prediction;
  learner->SetParam("predictor", predictor_name);
  learner->Configure();
  Json config{Object()};
  learner->SaveConfig(&config);
  ASSERT_EQ(get<String>(config["learner"]["gradient_booster"]["gbtree_train_param"]["predictor"]),
            predictor_name);

  learner->Predict(m_test, false, &prediction, 0, 0);
  ASSERT_EQ(prediction.Size(), rows);
  ASSERT_EQ(prediction.Size(), kRows);

  ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
}

void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner, size_t rows,
                                                   std::shared_ptr<DMatrix> m_test,
                                                   std::shared_ptr<DMatrix> m_invalid) {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
  std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};

  VerifyPredictionWithLesserFeatures(learner, rows, sliced_test, sliced_invalid);
}
}  // anonymous namespace

void TestPredictionWithLesserFeatures(Context const *ctx) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto learner = LearnerForTest(ctx, m_train, kIters);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
  VerifyPredictionWithLesserFeatures(learner.get(), kRows, m_test, m_invalid);
}

void TestPredictionDeviceAccess() {
  Context ctx;
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto learner = LearnerForTest(&ctx, m_train, kIters);

  HostDeviceVector<float> from_cpu;
  {
    ASSERT_EQ(from_cpu.DeviceIdx(), Context::kCpuId);
    Context cpu_ctx;
    ConfigLearnerByCtx(&cpu_ctx, learner.get());
    learner->Predict(m_test, false, &from_cpu, 0, 0);
    ASSERT_TRUE(from_cpu.HostCanWrite());
    ASSERT_FALSE(from_cpu.DeviceCanRead());
  }

#if defined(XGBOOST_USE_CUDA)
  HostDeviceVector<float> from_cpu;
  learner->SetParam("predictor", "cpu_predictor");
  learner->Predict(m_test, false, &from_cpu, 0, 0);

  HostDeviceVector<float> from_cuda;
  learner->SetParam("predictor", "gpu_predictor");
  learner->Predict(m_test, false, &from_cuda, 0, 0);
  {
    Context cuda_ctx = MakeCUDACtx(0);
    ConfigLearnerByCtx(&cuda_ctx, learner.get());
    learner->Predict(m_test, false, &from_cuda, 0, 0);
    ASSERT_EQ(from_cuda.DeviceIdx(), 0);
    ASSERT_TRUE(from_cuda.DeviceCanWrite());
    ASSERT_FALSE(from_cuda.HostCanRead());
  }

  auto const &h_cpu = from_cpu.ConstHostVector();
  auto const &h_gpu = from_cuda.ConstHostVector();

@ -196,41 +239,17 @@ void VerifyPredictionWithLesserFeatures(Learner *learner, std::string const &pre
  }
#endif  // defined(XGBOOST_USE_CUDA)
}
}  // anonymous namespace

void TestPredictionWithLesserFeatures(std::string predictor_name) {
void TestPredictionWithLesserFeaturesColumnSplit(Context const *ctx) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto learner = LearnerForTest(m_train, kIters);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
  VerifyPredictionWithLesserFeatures(learner.get(), predictor_name, kRows, m_test, m_invalid);
}

namespace {
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner,
                                                   std::string const &predictor_name, size_t rows,
                                                   std::shared_ptr<DMatrix> m_test,
                                                   std::shared_ptr<DMatrix> m_invalid) {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
  std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};

  VerifyPredictionWithLesserFeatures(learner, predictor_name, rows, sliced_test, sliced_invalid);
}
}  // anonymous namespace

void TestPredictionWithLesserFeaturesColumnSplit(std::string predictor_name) {
  size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
  auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
  auto learner = LearnerForTest(m_train, kIters);
  auto learner = LearnerForTest(ctx, m_train, kIters);
  auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
  auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);

  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
                              learner.get(), predictor_name, kRows, m_test, m_invalid);
                              learner.get(), kRows, m_test, m_invalid);
}
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,

@ -252,7 +271,7 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
  model->CommitModelGroup(std::move(trees), 0);
}

void TestCategoricalPrediction(std::string name, bool is_column_split) {
void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
  size_t constexpr kCols = 10;
  PredictionCacheEntry out_predictions;

@ -262,13 +281,10 @@ void TestCategoricalPrediction(std::string name, bool is_column_split) {
  float left_weight = 1.3f;
  float right_weight = 1.7f;

  Context ctx;
  ctx.UpdateAllowUnknown(Args{});
  gbm::GBTreeModel model(&mparam, &ctx);
  gbm::GBTreeModel model(&mparam, ctx);
  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

  ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
  std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
  std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};

  std::vector<float> row(kCols);
  row[split_ind] = split_cat;

@ -298,12 +314,12 @@ void TestCategoricalPrediction(std::string name, bool is_column_split) {
  ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
}

void TestCategoricalPredictionColumnSplit(std::string name) {
void TestCategoricalPredictionColumnSplit(Context const *ctx) {
  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, name, true);
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, ctx, true);
}

void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
  size_t constexpr kCols = 10;
  PredictionCacheEntry out_predictions;

@ -314,14 +330,10 @@ void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
  float left_weight = 1.3f;
  float right_weight = 1.7f;

  Context ctx;
  ctx.UpdateAllowUnknown(Args{});

  gbm::GBTreeModel model(&mparam, &ctx);
  gbm::GBTreeModel model(&mparam, ctx);
  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

  ctx.gpu_id = 0;
  std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};
  std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};

  std::vector<float> row(kCols);
  row[split_ind] = split_cat;

@ -346,19 +358,21 @@ void TestCategoricalPredictLeaf(StringView name, bool is_column_split) {
  ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}

void TestCategoricalPredictLeafColumnSplit(StringView name) {
void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
  auto constexpr kWorldSize = 2;
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, name, true);
  RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, ctx, true);
}

void TestIterationRange(std::string name) {
void TestIterationRange(Context const* ctx) {
  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
  auto learner = LearnerForTest(dmat, kIters, kForest);
  learner->SetParams(Args{{"predictor", name}});
  auto dmat = RandomDataGenerator(kRows, kCols, 0)
                  .Device(ctx->gpu_id)
                  .GenerateDMatrix(true, true, kClasses);
  auto learner = LearnerForTest(ctx, dmat, kIters, kForest);

  bool bound = false;
  std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
  bst_layer_t lend{3};
  std::unique_ptr<Learner> sliced{learner->Slice(0, lend, 1, &bound)};
  ASSERT_FALSE(bound);

  HostDeviceVector<float> out_predt_sliced;

@ -366,11 +380,8 @@ void TestIterationRange(std::string name) {

  // margin
  {
    sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
                    false, false);

    learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
                     false, false);
    sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
    learner->Predict(dmat, true, &out_predt_ranged, 0, lend, false, false, false, false, false);

    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();

@ -380,11 +391,8 @@ void TestIterationRange(std::string name) {

  // SHAP
  {
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
                    true, false, false);

    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
                     false, false);
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, true, false, false);
    learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, true, false, false);

    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();

@ -394,10 +402,8 @@ void TestIterationRange(std::string name) {

  // SHAP interaction
  {
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
                    false, false, true);
    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
                     false, true);
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false, false, false, true);
    learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, false, false, false, true);
    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();
    ASSERT_EQ(h_sliced.size(), h_range.size());

@ -406,10 +412,8 @@ void TestIterationRange(std::string name) {

  // Leaf
  {
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
                    false, false, false);
    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
                     false, false);
    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
    learner->Predict(dmat, false, &out_predt_ranged, 0, lend, false, true, false, false, false);
    auto const &h_sliced = out_predt_sliced.HostVector();
    auto const &h_range = out_predt_ranged.HostVector();
    ASSERT_EQ(h_sliced.size(), h_range.size());

@ -456,11 +460,16 @@ void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *s
}
}  // anonymous namespace

void TestIterationRangeColumnSplit(std::string name) {
void TestIterationRangeColumnSplit(Context const* ctx) {
  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
  auto learner = LearnerForTest(dmat, kIters, kForest);
  learner->SetParams(Args{{"predictor", name}});
  auto learner = LearnerForTest(ctx, dmat, kIters, kForest);

  if (ctx->IsCPU()) {
    learner->SetParams(Args{{"gpu_id", std::to_string(-1)}});
  } else {
    learner->SetParams(Args{{"gpu_id", std::to_string(0)}});
  }

  bool bound = false;
  std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};

@ -488,10 +497,10 @@ void TestIterationRangeColumnSplit(std::string name) {
                              leaf_ranged, leaf_sliced);
}

void TestSparsePrediction(float sparsity, std::string predictor) {
void TestSparsePrediction(Context const *ctx, float sparsity) {
  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
  auto learner = LearnerForTest(Xy, kIters);
  auto learner = LearnerForTest(ctx, Xy, kIters);

  HostDeviceVector<float> sparse_predt;

@ -501,11 +510,14 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
  learner.reset(Learner::Create({Xy}));
  learner->LoadModel(model);

  learner->SetParam("predictor", predictor);
  if (ctx->IsCUDA()) {
    learner->SetParam("tree_method", "gpu_hist");
    learner->SetParam("gpu_id", std::to_string(ctx->gpu_id));
  }
  learner->Predict(Xy, false, &sparse_predt, 0, 0);

  HostDeviceVector<float> with_nan(kRows * kCols, std::numeric_limits<float>::quiet_NaN());
  auto& h_with_nan = with_nan.HostVector();
  auto &h_with_nan = with_nan.HostVector();
  for (auto const &page : Xy->GetBatches<SparsePage>()) {
    auto batch = page.GetView();
    for (size_t i = 0; i < batch.Size(); ++i) {

@ -516,7 +528,8 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
    }
  }

  learner->SetParam("predictor", "cpu_predictor");
  learner->SetParam("tree_method", "hist");
  learner->SetParam("gpu_id", "-1");
  // Xcode_12.4 doesn't compile with `std::make_shared`.
  auto dense = std::shared_ptr<DMatrix>(new data::DMatrixProxy{});
  auto array_interface = GetArrayInterface(&with_nan, kRows, kCols);

@ -527,8 +540,8 @@ void TestSparsePrediction(float sparsity, std::string predictor) {
  learner->InplacePredict(dense, PredictionType::kValue, std::numeric_limits<float>::quiet_NaN(),
                          &p_dense_predt, 0, 0);

  auto const& dense_predt = *p_dense_predt;
  if (predictor == "cpu_predictor") {
  auto const &dense_predt = *p_dense_predt;
  if (ctx->IsCPU()) {
    ASSERT_EQ(dense_predt.HostVector(), sparse_predt.HostVector());
  } else {
    auto const &h_dense = dense_predt.HostVector();

@ -556,10 +569,10 @@ void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
}
}  // anonymous namespace

void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity) {
  size_t constexpr kRows = 512, kCols = 128, kIters = 4;
  auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
  auto learner = LearnerForTest(Xy, kIters);
  auto learner = LearnerForTest(ctx, Xy, kIters);

  HostDeviceVector<float> sparse_predt;

@ -569,7 +582,7 @@ void TestSparsePredictionColumnSplit(float sparsity, std::string predictor) {
  learner.reset(Learner::Create({Xy}));
  learner->LoadModel(model);

  learner->SetParam("predictor", predictor);
  ConfigLearnerByCtx(ctx, learner.get());
  learner->Predict(Xy, false, &sparse_predt, 0, 0);

  auto constexpr kWorldSize = 2;
@ -31,8 +31,17 @@ inline gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, Context
  return model;
}

inline auto CreatePredictorForTest(Context const* ctx) {
  if (ctx->IsCPU()) {
    return Predictor::Create("cpu_predictor", ctx);
  } else {
    return Predictor::Create("gpu_predictor", ctx);
  }
}

// fixme: cpu test
template <typename Page>
void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
void TestPredictionFromGradientIndex(Context const* ctx, size_t rows, size_t cols,
                                     std::shared_ptr<DMatrix> p_hist) {
  constexpr size_t kClasses { 3 };

@ -40,12 +49,10 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
  auto cuda_ctx = MakeCUDACtx(0);

  std::unique_ptr<Predictor> predictor =
      std::unique_ptr<Predictor>(Predictor::Create(name, &cuda_ctx));
      std::unique_ptr<Predictor>(CreatePredictorForTest(&cuda_ctx));
  predictor->Configure({});

  Context ctx;
  ctx.UpdateAllowUnknown(Args{});
  gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, kClasses);
  gbm::GBTreeModel model = CreateTestModel(&mparam, ctx, kClasses);

  {
    auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();

@ -81,28 +88,30 @@ void TestTrainingPrediction(size_t rows, size_t bins, std::string tree_method,
                            std::shared_ptr<DMatrix> p_full,
                            std::shared_ptr<DMatrix> p_hist);

void TestInplacePrediction(std::shared_ptr<DMatrix> x, std::string predictor, bst_row_t rows,
                           bst_feature_t cols, int32_t device = -1);
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
                           bst_feature_t cols);

void TestPredictionWithLesserFeatures(std::string preditor_name);
void TestPredictionWithLesserFeatures(Context const* ctx);

void TestPredictionWithLesserFeaturesColumnSplit(std::string preditor_name);
void TestPredictionDeviceAccess();

void TestCategoricalPrediction(std::string name, bool is_column_split = false);
void TestCategoricalPrediction(Context const* ctx, bool is_column_split);

void TestCategoricalPredictionColumnSplit(std::string name);
void TestCategoricalPredictionColumnSplit(Context const* ctx);

void TestCategoricalPredictLeaf(StringView name, bool is_column_split = false);
void TestPredictionWithLesserFeaturesColumnSplit(Context const* ctx);

void TestCategoricalPredictLeafColumnSplit(StringView name);
void TestCategoricalPredictLeaf(Context const* ctx, bool is_column_split);

void TestIterationRange(std::string name);
void TestCategoricalPredictLeafColumnSplit(Context const* ctx);

void TestIterationRangeColumnSplit(std::string name);
void TestIterationRange(Context const* ctx);

void TestSparsePrediction(float sparsity, std::string predictor);
void TestIterationRangeColumnSplit(Context const* ctx);

void TestSparsePredictionColumnSplit(float sparsity, std::string predictor);
void TestSparsePrediction(Context const* ctx, float sparsity);

void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity);

void TestVectorLeafPrediction(Context const* ctx);
}  // namespace xgboost
@ -342,16 +342,6 @@ TEST(Learner, GPUConfiguration) {
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->Ctx()->gpu_id, 0);
}
{
// With CPU algorithm but GPU Predictor, this is to simulate when
// XGBoost is only used for prediction, so tree method is not
// specified.
std::unique_ptr<Learner> learner {Learner::Create(mat)};
learner->SetParams({Arg{"tree_method", "hist"},
Arg{"predictor", "gpu_predictor"}});
learner->UpdateOneIter(0, p_dmat);
ASSERT_EQ(learner->Ctx()->gpu_id, 0);
}
}
#endif // defined(XGBOOST_USE_CUDA)
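The deleted block above trained with the CPU `hist` method while forcing `gpu_predictor`, simulating a prediction-only GPU workload. After this change the same workflow is expressed through `gpu_id` alone. A minimal sketch under that assumption (any small dataset works; a CUDA-enabled build is required):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(128, 8)
    y = np.random.randn(128)
    dtrain = xgb.DMatrix(X, label=y)

    # Train on the CPU, then move the booster to GPU 0 for prediction only;
    # gpu_id takes over the role of the removed predictor parameter.
    booster = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=4)
    booster.set_param({"gpu_id": 0, "tree_method": "gpu_hist"})
    predt = booster.predict(xgb.DMatrix(X))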
@ -698,10 +698,6 @@ TEST_F(MultiClassesSerializationTest, GpuHist) {
{"seed", "0"},
{"nthread", "1"},
{"max_depth", std::to_string(kClasses)},
// Somehow rebuilding the cache can generate slightly
// different result (1e-7) with CPU predictor for some
// entries.
{"predictor", "gpu_predictor"},
// Mitigate the difference caused by hardware fused multiply
// add to tree weight during update prediction cache.
{"learning_rate", "1.0"},

@ -1,5 +1,5 @@
'''Loading a pickled model generated by test_pickling.py, only used by
`test_gpu_with_dask.py`'''
"""Loading a pickled model generated by test_pickling.py, only used by
`test_gpu_with_dask.py`"""
import json
import os

@ -12,9 +12,9 @@ from xgboost import testing as tm


class TestLoadPickle:
def test_load_pkl(self):
'''Test whether prediction is correct.'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_load_pkl(self) -> None:
"""Test whether prediction is correct."""
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
bst = load_pickle(model_path)
x, y = build_dataset()
if isinstance(bst, xgb.Booster):
@ -28,46 +28,42 @@ class TestLoadPickle:

assert len(res) == 10

def test_predictor_type_is_auto(self):
'''Under invalid CUDA_VISIBLE_DEVICES, predictor should be set to
auto'''
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_context_is_removed(self) -> None:
"""Under invalid CUDA_VISIBLE_DEVICES, context should reset"""
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'auto'
assert config["learner"]["generic_param"]["gpu_id"] == "-1"

def test_predictor_type_is_gpu(self):
'''When CUDA_VISIBLE_DEVICES is not specified, keep using
`gpu_predictor`'''
assert 'CUDA_VISIBLE_DEVICES' not in os.environ.keys()
def test_context_is_preserved(self) -> None:
"""Test the device context is preserved after pickling."""
assert "CUDA_VISIBLE_DEVICES" not in os.environ.keys()
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
assert config["learner"]["generic_param"]["gpu_id"] == "0"

def test_wrap_gpu_id(self):
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
def test_wrap_gpu_id(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['generic_param']['gpu_id'] == '0'
assert config["learner"]["generic_param"]["gpu_id"] == "0"

x, y = build_dataset()
test_x = xgb.DMatrix(x)
res = bst.predict(test_x)
assert len(res) == 10

def test_training_on_cpu_only_env(self):
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
def test_training_on_cpu_only_env(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "-1"
rng = np.random.RandomState(1994)
X = rng.randn(10, 10)
y = rng.randn(10)
with tm.captured_output() as (out, err):
# Test no thrust exception is thrown
with pytest.raises(xgb.core.XGBoostError):
xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(X, y))
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))

assert out.getvalue().find('No visible GPU is found') != -1
assert out.getvalue().find("No visible GPU is found") != -1
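The rewritten assertions above read the device off the saved JSON config rather than off a predictor entry. A short, runnable sketch of that inspection, using the exact JSON path the tests assert on:

    import json

    import numpy as np
    import xgboost as xgb

    dtrain = xgb.DMatrix(np.random.randn(64, 4), label=np.random.randn(64))
    bst = xgb.train({"tree_method": "hist"}, dtrain, num_boost_round=2)

    # The device ordinal lives under learner/generic_param; the old
    # gbtree_train_param "predictor" entry is gone from the config.
    config = json.loads(bst.save_config())
    print(config["learner"]["generic_param"]["gpu_id"])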
@ -203,7 +203,7 @@ class TestQuantileDMatrix:
np.testing.assert_equal(h_ret.indices, d_ret.indices)

booster = xgb.train(
{"tree_method": "gpu_hist", "predictor": "gpu_predictor"}, dtrain=d_m
{"tree_method": "gpu_hist", "gpu_id": "0"}, dtrain=d_m
)

np.testing.assert_allclose(

@ -221,9 +221,10 @@ Arrow specification.'''
def test_specified_device(self):
import cupy as cp
cp.cuda.runtime.setDevice(0)
dtrain = dmatrix_from_cupy(
np.float32, xgb.QuantileDMatrix, np.nan)
with pytest.raises(xgb.core.XGBoostError):
dtrain = dmatrix_from_cupy(np.float32, xgb.QuantileDMatrix, np.nan)
with pytest.raises(
xgb.core.XGBoostError, match="Data is resided on a different device"
):
xgb.train(
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
)
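The updated `test_specified_device` above turns a silent device hop into a hard error. A sketch of the failure it pins down, assuming at least two visible GPUs (the `match` string is the one from the test):

    import cupy as cp
    import xgboost as xgb

    cp.cuda.runtime.setDevice(0)  # the training data lives on device 0
    X = cp.random.randn(256, 16).astype(cp.float32)
    y = cp.random.randn(256).astype(cp.float32)
    dtrain = xgb.QuantileDMatrix(X, label=y)

    # Pinning training to device 1 now raises
    # "Data is resided on a different device" instead of copying the data.
    xgb.train({"tree_method": "gpu_hist", "gpu_id": 1}, dtrain, num_boost_round=10)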
@ -1,5 +1,4 @@
'''Test model IO with pickle.'''
import json
"""Test model IO with pickle."""
import os
import pickle
import subprocess
@ -11,49 +10,48 @@ import xgboost as xgb
from xgboost import XGBClassifier
from xgboost import testing as tm

model_path = './model.pkl'
model_path = "./model.pkl"

pytestmark = tm.timeout(30)


def build_dataset():
N = 10
x = np.linspace(0, N*N, N*N)
x = np.linspace(0, N * N, N * N)
x = x.reshape((N, N))
y = np.linspace(0, N, N)
return x, y


def save_pickle(bst, path):
with open(path, 'wb') as fd:
with open(path, "wb") as fd:
pickle.dump(bst, fd)


def load_pickle(path):
with open(path, 'rb') as fd:
with open(path, "rb") as fd:
bst = pickle.load(fd)
return bst


class TestPickling:
args_template = [
"pytest",
"--verbose",
"-s",
"--fulltrace"]
args_template = ["pytest", "--verbose", "-s", "--fulltrace"]

def run_pickling(self, bst) -> None:
save_pickle(bst, model_path)
args = [
"pytest", "--verbose", "-s", "--fulltrace",
"./tests/python-gpu/load_pickle.py::TestLoadPickle::test_load_pkl"
"pytest",
"--verbose",
"-s",
"--fulltrace",
"./tests/python-gpu/load_pickle.py::TestLoadPickle::test_load_pkl",
]
command = ''
command = ""
for arg in args:
command += arg
command += ' '
command += " "

cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
env = os.environ.copy()
# Passing new_environment directly to `env' argument results
# in failure on Windows:
@ -72,7 +70,7 @@ class TestPickling:
x, y = build_dataset()
train_x = xgb.DMatrix(x, label=y)

param = {'tree_method': 'gpu_hist', "gpu_id": 0}
param = {"tree_method": "gpu_hist", "gpu_id": 0}
bst = xgb.train(param, train_x)
self.run_pickling(bst)

@ -91,43 +89,46 @@ class TestPickling:
X, y = build_dataset()
dtrain = xgb.DMatrix(X, y)

bst = xgb.train({'tree_method': 'gpu_hist',
'gpu_id': 1},
dtrain, num_boost_round=6)
bst = xgb.train(
{"tree_method": "gpu_hist", "gpu_id": 1}, dtrain, num_boost_round=6
)

model_path = 'model.pkl'
model_path = "model.pkl"
save_pickle(bst, model_path)
cuda_environment = {'CUDA_VISIBLE_DEVICES': '0'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "0"}
env = os.environ.copy()
env.update(cuda_environment)
args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
"./tests/python-gpu/" "load_pickle.py::TestLoadPickle::test_wrap_gpu_id"
)
status = subprocess.call(args, env=env)
assert status == 0
os.remove(model_path)

def test_pickled_predictor(self):
x, y = build_dataset()
def test_pickled_context(self):
x, y = tm.make_sparse_regression(10, 10, sparsity=0.8, as_dense=True)
train_x = xgb.DMatrix(x, label=y)

param = {'tree_method': 'gpu_hist',
'verbosity': 1, 'predictor': 'gpu_predictor'}
param = {"tree_method": "gpu_hist", "verbosity": 1}
bst = xgb.train(param, train_x)
config = json.loads(bst.save_config())
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'

with tm.captured_output() as (out, err):
bst.inplace_predict(x)

# The warning is redirected to Python callback, so it's printed in stdout
# instead of stderr.
stdout = out.getvalue()
assert stdout.find("mismatched devices") != -1

save_pickle(bst, model_path)

args = self.args_template.copy()
args.append(
"./tests/python-gpu/"
"load_pickle.py::TestLoadPickle::test_predictor_type_is_auto")
root = tm.project_root(__file__)
path = os.path.join(root, "tests", "python-gpu", "load_pickle.py")
args.append(path + "::TestLoadPickle::test_context_is_removed")

cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
env = os.environ.copy()
env.update(cuda_environment)
|
||||
args = self.args_template.copy()
|
||||
args.append(
|
||||
"./tests/python-gpu/"
|
||||
"load_pickle.py::TestLoadPickle::test_predictor_type_is_gpu")
|
||||
"load_pickle.py::TestLoadPickle::test_context_is_preserved"
|
||||
)
|
||||
|
||||
# Load in environment that has GPU.
|
||||
env = os.environ.copy()
|
||||
assert 'CUDA_VISIBLE_DEVICES' not in env.keys()
|
||||
assert "CUDA_VISIBLE_DEVICES" not in env.keys()
|
||||
status = subprocess.call(args, env=env)
|
||||
assert status == 0
|
||||
|
||||
os.remove(model_path)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_predict_sklearn_pickle(self):
|
||||
def test_predict_sklearn_pickle(self) -> None:
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
x, y = load_digits(return_X_y=True)
|
||||
|
||||
kwargs = {'tree_method': 'gpu_hist',
|
||||
'predictor': 'gpu_predictor',
|
||||
'objective': 'binary:logistic',
|
||||
'n_estimators': 10}
|
||||
kwargs = {
|
||||
"tree_method": "gpu_hist",
|
||||
"objective": "binary:logistic",
|
||||
"gpu_id": 0,
|
||||
"n_estimators": 10,
|
||||
}
|
||||
|
||||
model = XGBClassifier(**kwargs)
|
||||
model.fit(x, y)
|
||||
@ -165,24 +170,25 @@ class TestPickling:
|
||||
del model
|
||||
|
||||
# load model
|
||||
model: xgb.XGBClassifier = load_pickle("model.pkl")
|
||||
model = load_pickle("model.pkl")
|
||||
os.remove("model.pkl")
|
||||
|
||||
gpu_pred = model.predict(x, output_margin=True)
|
||||
|
||||
# Switch to CPU predictor
|
||||
bst = model.get_booster()
|
||||
bst.set_param({'predictor': 'cpu_predictor'})
|
||||
tm.set_ordinal(-1, bst)
|
||||
cpu_pred = model.predict(x, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
|
||||
def test_training_on_cpu_only_env(self):
|
||||
cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
|
||||
cuda_environment = {"CUDA_VISIBLE_DEVICES": "-1"}
|
||||
env = os.environ.copy()
|
||||
env.update(cuda_environment)
|
||||
args = self.args_template.copy()
|
||||
args.append(
|
||||
"./tests/python-gpu/"
|
||||
"load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env")
|
||||
"load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env"
|
||||
)
|
||||
status = subprocess.call(args, env=env)
|
||||
assert status == 0
|
||||
|
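The tests in this diff switch devices through a `tm.set_ordinal` helper from `xgboost.testing` whose definition is not shown here. A plausible sketch of what it does for a `Booster`, assuming it only rewrites the device-related parameters the same way `bst.set_param({"gpu_id": -1, "tree_method": "hist"})` is used later in this diff (the real helper also appears to accept sklearn estimators, which is omitted here):

    def set_ordinal(ordinal: int, booster):
        """Hypothetical stand-in for tm.set_ordinal: a negative ordinal
        selects the CPU, a non-negative one selects that GPU."""
        if ordinal < 0:
            booster.set_param({"gpu_id": -1, "tree_method": "hist"})
        else:
            booster.set_param({"gpu_id": ordinal, "tree_method": "gpu_hist"})
        return booster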
@ -1,4 +1,5 @@
import sys
from copy import copy

import numpy as np
import pytest
@ -11,8 +12,10 @@ from xgboost.compat import PANDAS_INSTALLED
if PANDAS_INSTALLED:
from hypothesis.extra.pandas import column, data_frames, range_indexes
else:

def noop(*args, **kwargs):
pass

column, data_frames, range_indexes = noop, noop, noop

sys.path.append("tests/python")
@ -21,16 +24,20 @@ from test_predict import run_threaded_predict # noqa

rng = np.random.RandomState(1994)

shap_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(1, 11),
'max_leaves': strategies.integers(0, 256),
'num_parallel_tree': strategies.sampled_from([1, 10]),
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
shap_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 11),
"max_leaves": strategies.integers(0, 256),
"num_parallel_tree": strategies.sampled_from([1, 10]),
}
).filter(lambda x: x["max_depth"] > 0 or x["max_leaves"] > 0)

predict_parameter_strategy = strategies.fixed_dictionaries({
'max_depth': strategies.integers(1, 8),
'num_parallel_tree': strategies.sampled_from([1, 4]),
})
predict_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 8),
"num_parallel_tree": strategies.sampled_from([1, 4]),
}
)

pytestmark = tm.timeout(20)

@ -47,43 +54,45 @@ class TestGPUPredict:
# with 5000 rows is 0.04.
for num_rows in test_num_rows:
for num_cols in test_num_cols:
dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dval = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2))
watchlist = [(dtrain, 'train'), (dval, 'validation')]
dtrain = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
dval = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
dtest = xgb.DMatrix(
np.random.randn(num_rows, num_cols),
label=[0, 1] * int(num_rows / 2),
)
watchlist = [(dtrain, "train"), (dval, "validation")]
res = {}
param = {
"objective": "binary:logistic",
"predictor": "gpu_predictor",
'eval_metric': 'logloss',
'tree_method': 'gpu_hist',
'max_depth': 1
"eval_metric": "logloss",
"tree_method": "gpu_hist",
"gpu_id": 0,
"max_depth": 1,
}
bst = xgb.train(param, dtrain, iterations, evals=watchlist,
evals_result=res)
assert self.non_increasing(res["train"]["logloss"])
bst = xgb.train(
param, dtrain, iterations, evals=watchlist, evals_result=res
)
assert tm.non_increasing(res["train"]["logloss"], tolerance=0.001)

gpu_pred_train = bst.predict(dtrain, output_margin=True)
gpu_pred_test = bst.predict(dtest, output_margin=True)
gpu_pred_val = bst.predict(dval, output_margin=True)

param["predictor"] = "cpu_predictor"
bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
bst.set_param({"gpu_id": -1, "tree_method": "hist"})
bst_cpu = copy(bst)
cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
cpu_pred_val = bst_cpu.predict(dval, output_margin=True)

np.testing.assert_allclose(cpu_pred_train, gpu_pred_train,
rtol=1e-6)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val,
rtol=1e-6)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test,
rtol=1e-6)

def non_increasing(self, L):
return all((y - x) < 0.001 for x, y in zip(L, L[1:]))
np.testing.assert_allclose(cpu_pred_train, gpu_pred_train, rtol=1e-6)
np.testing.assert_allclose(cpu_pred_val, gpu_pred_val, rtol=1e-6)
np.testing.assert_allclose(cpu_pred_test, gpu_pred_test, rtol=1e-6)

# Test case for a bug where multiple batch predictions made on a
# test set produce incorrect results
@ -94,26 +103,22 @@ class TestGPUPredict:

n = 1000
X, y = make_regression(n, random_state=rng)
X_train, X_test, y_train, y_test = train_test_split(X, y,
random_state=123)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=123)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test)

params = {}
params["tree_method"] = "gpu_hist"
bst = xgb.train(params, dtrain)

params['predictor'] = "gpu_predictor"
bst_gpu_predict = xgb.train(params, dtrain)
tm.set_ordinal(0, bst)
# Don't reuse the DMatrix for prediction, otherwise the result is cached.
predict_gpu_0 = bst.predict(xgb.DMatrix(X_test))
predict_gpu_1 = bst.predict(xgb.DMatrix(X_test))
tm.set_ordinal(-1, bst)
predict_cpu = bst.predict(xgb.DMatrix(X_test))

params['predictor'] = "cpu_predictor"
bst_cpu_predict = xgb.train(params, dtrain)

predict0 = bst_gpu_predict.predict(dtest)
predict1 = bst_gpu_predict.predict(dtest)
cpu_predict = bst_cpu_predict.predict(dtest)

assert np.allclose(predict0, predict1)
assert np.allclose(predict0, cpu_predict)
assert np.allclose(predict_gpu_0, predict_gpu_1)
assert np.allclose(predict_gpu_0, predict_cpu)
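The comment in the test above — predictions are cached when the same `DMatrix` is reused — is why each call constructs a fresh `DMatrix`. A sketch of the pitfall, with names following that test:

    # After moving the booster to another device, a DMatrix that was
    # already used for prediction would be served from the prediction
    # cache; wrapping X_test again forces a recomputation.
    predict_gpu = bst.predict(xgb.DMatrix(X_test))
    tm.set_ordinal(-1, bst)
    predict_cpu = bst.predict(xgb.DMatrix(X_test))
    assert np.allclose(predict_gpu, predict_cpu)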
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn(self):
@ -121,30 +126,31 @@ class TestGPUPredict:
tr_size = 2500
X = np.random.rand(m, n)
y = 200 * np.matmul(X, np.arange(-3, -3 + n))
y = y.reshape(y.size)
X_train, y_train = X[:tr_size, :], y[:tr_size]
X_test, y_test = X[tr_size:, :], y[tr_size:]

# First with cpu_predictor
params = {'tree_method': 'gpu_hist',
'predictor': 'cpu_predictor',
'n_jobs': -1,
'seed': 123}
m = xgb.XGBRegressor(**params).fit(X_train, y_train)
cpu_train_score = m.score(X_train, y_train)
cpu_test_score = m.score(X_test, y_test)

# Now with gpu_predictor
params['predictor'] = 'gpu_predictor'

params = {
"tree_method": "gpu_hist",
"gpu_id": "0",
"n_jobs": -1,
"seed": 123,
}
m = xgb.XGBRegressor(**params).fit(X_train, y_train)
gpu_train_score = m.score(X_train, y_train)
gpu_test_score = m.score(X_test, y_test)

# Now with cpu
m = tm.set_ordinal(-1, m)
cpu_train_score = m.score(X_train, y_train)
cpu_test_score = m.score(X_test, y_test)

assert np.allclose(cpu_train_score, gpu_train_score)
assert np.allclose(cpu_test_score, gpu_test_score)

def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
import cupy as cp

dtrain.set_info(base_margin=base_margin)
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
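`run_inplace_base_margin` above feeds the same base margin to both prediction paths: attached to the `DMatrix` for the copied path, passed explicitly for the in-place path. The closing comparison falls outside the hunk; a hedged sketch of what it presumably asserts:

    # Presumed closing assertion of the helper above: with the same
    # base_margin on both paths, the outputs should agree.
    cp.testing.assert_allclose(from_inplace, from_dmatrix, atol=1e-6)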
@ -152,10 +158,11 @@ class TestGPUPredict:

def run_inplace_predict_cupy(self, device: int) -> None:
import cupy as cp

cp.cuda.runtime.setDevice(device)
rows = 1000
cols = 10
missing = 11 # set to integer for testing
missing = 11  # set to integer for testing

cp_rng = cp.random.RandomState(1994)
cp.random.set_random_state(cp_rng)
@ -168,7 +175,7 @@ class TestGPUPredict:
dtrain = xgb.DMatrix(X, y)

booster = xgb.train(
{'tree_method': 'gpu_hist', "gpu_id": device}, dtrain, num_boost_round=10
{"tree_method": "gpu_hist", "gpu_id": device}, dtrain, num_boost_round=10
)

test = xgb.DMatrix(X[:10, ...], missing=missing)
@ -186,7 +193,7 @@ class TestGPUPredict:
# Don't do this on Windows, see issue #5793
if sys.platform.startswith("win"):
pytest.skip(
'Multi-threaded in-place prediction with cuPy is not working on Windows'
"Multi-threaded in-place prediction with cuPy is not working on Windows"
)
for i in range(10):
run_threaded_predict(X, rows, predict_dense)
@ -205,9 +212,10 @@ class TestGPUPredict:
)
reg.fit(X, y)

reg = tm.set_ordinal(device, reg)
gpu_predt = reg.predict(X)
reg.set_params(predictor="cpu_predictor")
cpu_predt = reg.predict(X)
reg = tm.set_ordinal(-1, reg)
cpu_predt = reg.predict(cp.asnumpy(X))
np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6)
cp.cuda.runtime.setDevice(0)

@ -215,11 +223,11 @@ class TestGPUPredict:
def test_inplace_predict_cupy(self):
self.run_inplace_predict_cupy(0)

@pytest.mark.xfail
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_inplace_predict_cupy_specified_device(self):
import cupy as cp

n_devices = cp.cuda.runtime.getDeviceCount()
for d in range(n_devices):
self.run_inplace_predict_cupy(d)
@ -230,6 +238,7 @@ class TestGPUPredict:
import cudf
import cupy as cp
import pandas as pd

rows = 1000
cols = 10
rng = np.random.RandomState(1994)
@ -241,8 +250,7 @@ class TestGPUPredict:

dtrain = xgb.DMatrix(X, y)

booster = xgb.train({'tree_method': 'gpu_hist'},
dtrain, num_boost_round=10)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=10)
test = xgb.DMatrix(X)
predt_from_array = booster.inplace_predict(X)
predt_from_dmatrix = booster.predict(test)
@ -272,11 +280,12 @@ class TestGPUPredict:
def test_shap(self, num_rounds, dataset, param):
if dataset.name.endswith("-l1"):  # not supported by the exact tree method
return
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param.update({"tree_method": "gpu_hist", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst = tm.set_ordinal(0, bst)
shap = bst.predict(test_dmat, pred_contribs=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
@ -289,31 +298,35 @@ class TestGPUPredict:
def test_shap_interactions(self, num_rounds, dataset, param):
if dataset.name.endswith("-l1"):  # not supported by the exact tree method
return
param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param.update({"tree_method": "hist", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
bst = xgb.train(param, dmat, num_rounds)
test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
bst = tm.set_ordinal(0, bst)
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3, 1e-3)
assert np.allclose(
np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3,
1e-3,
)

def test_shap_categorical(self):
X, y = tm.make_categorical(100, 20, 7, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)

booster.set_param({"predictor": "gpu_predictor"})
booster = tm.set_ordinal(0, booster)
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
)

booster.set_param({"predictor": "cpu_predictor"})
booster = tm.set_ordinal(-1, booster)
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
@ -321,18 +334,20 @@ class TestGPUPredict:
)

def test_predict_leaf_basic(self):
gpu_leaf = run_predict_leaf('gpu_predictor')
cpu_leaf = run_predict_leaf('cpu_predictor')
gpu_leaf = run_predict_leaf(0)
cpu_leaf = run_predict_leaf(-1)
np.testing.assert_equal(gpu_leaf, cpu_leaf)

def run_predict_leaf_booster(self, param, num_rounds, dataset):
param = dataset.set_params(param)
m = dataset.get_dmat()
booster = xgb.train(param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds)
booster.set_param({'predictor': 'cpu_predictor'})
booster = xgb.train(
param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds
)
booster = tm.set_ordinal(-1, booster)
cpu_leaf = booster.predict(m, pred_leaf=True)

booster.set_param({'predictor': 'gpu_predictor'})
booster = tm.set_ordinal(0, booster)
gpu_leaf = booster.predict(m, pred_leaf=True)

np.testing.assert_equal(cpu_leaf, gpu_leaf)
@ -344,8 +359,8 @@ class TestGPUPredict:
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
return

param['booster'] = 'gbtree'
param['tree_method'] = 'gpu_hist'
param["booster"] = "gbtree"
param["tree_method"] = "gpu_hist"
self.run_predict_leaf_booster(param, 10, dataset)

@given(predict_parameter_strategy, tm.make_dataset_strategy())
@ -355,42 +370,61 @@ class TestGPUPredict:
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
return

param['booster'] = 'dart'
param['tree_method'] = 'gpu_hist'
param["booster"] = "dart"
param["tree_method"] = "gpu_hist"
self.run_predict_leaf_booster(param, 10, dataset)

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
@given(df=data_frames([column('x0', elements=strategies.integers(min_value=0, max_value=3)),
column('x1', elements=strategies.integers(min_value=0, max_value=5))],
index=range_indexes(min_size=20, max_size=50)))
@given(
df=data_frames(
[
column("x0", elements=strategies.integers(min_value=0, max_value=3)),
column("x1", elements=strategies.integers(min_value=0, max_value=5)),
],
index=range_indexes(min_size=20, max_size=50),
)
)
@settings(deadline=None, max_examples=20, print_blob=True)
def test_predict_categorical_split(self, df):
from sklearn.metrics import mean_squared_error

df = df.astype('category')
x0, x1 = df['x0'].to_numpy(), df['x1'].to_numpy()
df = df.astype("category")
x0, x1 = df["x0"].to_numpy(), df["x1"].to_numpy()
y = (x0 * 10 - 20) + (x1 - 2)
dtrain = xgb.DMatrix(df, label=y, enable_categorical=True)

params = {
'tree_method': 'gpu_hist', 'predictor': 'gpu_predictor',
'max_depth': 3, 'learning_rate': 1.0, 'base_score': 0.0, 'eval_metric': 'rmse'
"tree_method": "gpu_hist",
"max_depth": 3,
"learning_rate": 1.0,
"base_score": 0.0,
"eval_metric": "rmse",
"gpu_id": "0",
}

eval_history = {}
bst = xgb.train(params, dtrain, num_boost_round=5, evals=[(dtrain, 'train')],
verbose_eval=False, evals_result=eval_history)

bst = xgb.train(
params,
dtrain,
num_boost_round=5,
evals=[(dtrain, "train")],
verbose_eval=False,
evals_result=eval_history,
)
bst = tm.set_ordinal(0, bst)
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
np.testing.assert_almost_equal(
rmse, eval_history["train"]["rmse"][-1], decimal=5
)

@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.parametrize("n_classes", [2, 3])
def test_predict_dart(self, n_classes):
import cupy as cp
from sklearn.datasets import make_classification

n_samples = 1000
X_, y_ = make_classification(
n_samples=n_samples, n_informative=5, n_classes=n_classes
@ -403,7 +437,7 @@ class TestGPUPredict:
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
"objective": "binary:logistic"
"objective": "binary:logistic",
}
else:
params = {
@ -411,15 +445,18 @@ class TestGPUPredict:
"booster": "dart",
"rate_drop": 0.5,
"objective": "multi:softprob",
"num_class": n_classes
"num_class": n_classes,
}

booster = xgb.train(params, Xy, num_boost_round=32)
# predictor=auto

# auto (GPU)
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)

# CPU
booster = tm.set_ordinal(-1, booster)
cpu_inplace = booster.inplace_predict(X_)
booster.set_param({"predictor": "cpu_predictor"})
cpu_copied = booster.predict(Xy)

copied = cp.array(copied)
@ -427,7 +464,8 @@ class TestGPUPredict:
cp.testing.assert_allclose(cpu_copied, copied, atol=1e-6)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)

booster.set_param({"predictor": "gpu_predictor"})
# GPU
booster = tm.set_ordinal(0, booster)
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)

@ -437,12 +475,11 @@ class TestGPUPredict:
@pytest.mark.skipif(**tm.no_cupy())
def test_dtypes(self):
import cupy as cp

rows = 1000
cols = 10
rng = cp.random.RandomState(1994)
orig = rng.randint(low=0, high=127, size=rows * cols).reshape(
rows, cols
)
orig = rng.randint(low=0, high=127, size=rows * cols).reshape(rows, cols)
y = rng.randint(low=0, high=127, size=rows)
dtrain = xgb.DMatrix(orig, label=y)
booster = xgb.train({"tree_method": "gpu_hist"}, dtrain)
@ -450,19 +487,16 @@ class TestGPUPredict:
predt_orig = booster.inplace_predict(orig)
# all primitive types in numpy
for dtype in [
cp.signedinteger,
cp.byte,
cp.short,
cp.intc,
cp.int_,
cp.longlong,
cp.unsignedinteger,
cp.ubyte,
cp.ushort,
cp.uintc,
cp.uint,
cp.ulonglong,
cp.floating,
cp.half,
cp.single,
cp.double,
@ -472,9 +506,7 @@ class TestGPUPredict:
cp.testing.assert_allclose(predt, predt_orig)

# boolean
orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape(
rows, cols
)
orig = cp.random.binomial(1, 0.5, size=rows * cols).reshape(rows, cols)
predt_orig = booster.inplace_predict(orig)
for dtype in [cp.bool8, cp.bool_]:
X = cp.array(orig, dtype=dtype)
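`test_dtypes` above asserts one invariant: in-place prediction does not depend on the input's primitive storage type. One round of that loop, spelled out with the same names:

    # Casting the original integer matrix to another supported dtype
    # must leave the predictions unchanged.
    X = cp.array(orig, dtype=cp.half)
    predt = booster.inplace_predict(X)
    cp.testing.assert_allclose(predt, predt_orig)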
@ -29,7 +29,6 @@ def comp_training_with_rank_objective(
"booster": "gbtree",
"tree_method": "gpu_hist",
"gpu_id": 0,
"predictor": "gpu_predictor",
}

num_trees = 100
@ -54,7 +53,6 @@ def comp_training_with_rank_objective(
"booster": "gbtree",
"tree_method": "hist",
"gpu_id": -1,
"predictor": "cpu_predictor",
}
cpu_params["objective"] = rank_objective
cpu_params["eval_metric"] = metric_name

@ -260,7 +260,6 @@ class TestGPUUpdaters:
"seed": 66,
"subsample": 0.5,
"gamma": 0.2,
"predictor": "auto",
"eval_metric": "auc",
},
num_boost_round=150,

@ -28,7 +28,7 @@ def run_threaded_predict(X, rows, predict_func):
assert f.result()


def run_predict_leaf(predictor):
def run_predict_leaf(gpu_id: int) -> np.ndarray:
rows = 100
cols = 4
classes = 5
@ -42,13 +42,13 @@ def run_predict_leaf(predictor):
{
"num_parallel_tree": num_parallel_tree,
"num_class": classes,
"predictor": predictor,
"tree_method": "hist",
},
m,
num_boost_round=num_boost_round,
)

booster = tm.set_ordinal(gpu_id, booster)
empty = xgb.DMatrix(np.ones(shape=(0, cols)))
empty_leaf = booster.predict(empty, pred_leaf=True)
assert empty_leaf.shape[0] == 0
@ -74,13 +74,14 @@ def run_predict_leaf(predictor):

# When there's only 1 tree, the output is a 1 dim vector
booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m)
booster = tm.set_ordinal(gpu_id, booster)
assert booster.predict(m, pred_leaf=True).shape == (rows,)

return leaf


def test_predict_leaf():
run_predict_leaf("cpu_predictor")
def test_predict_leaf() -> None:
run_predict_leaf(-1)


def test_predict_shape():

@ -274,7 +274,7 @@ class TestTreeMethod:
) -> None:
parameters: Dict[str, Any] = {"tree_method": tree_method}
cat, label = tm.make_categorical(
n_samples=rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
rows, n_features=cols, n_categories=cats, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)

@ -294,7 +294,9 @@ class TestTreeMethod:
y_predt = booster.predict(Xy)

rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
np.testing.assert_allclose(
rmse, evals_result["Train"]["rmse"][-1], rtol=2e-5
)

# Test with OHE split
run(self.USE_ONEHOT)
@ -311,10 +313,8 @@ class TestTreeMethod:
by_etl_results: Dict[str, Dict[str, List[float]]] = {}
by_builtin_results: Dict[str, Dict[str, List[float]]] = {}

predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
parameters: Dict[str, Any] = {
"tree_method": tree_method,
"predictor": predictor,
# Use one-hot exclusively
"max_cat_to_onehot": self.USE_ONEHOT
}

@ -1418,23 +1418,6 @@ def test_categorical():
np.testing.assert_allclose(predt_cat, predt_enc)


def test_prediction_config():
reg = xgb.XGBRegressor()
assert reg._can_use_inplace_predict() is True

reg.set_params(predictor="cpu_predictor")
assert reg._can_use_inplace_predict() is False

reg.set_params(predictor="auto")
assert reg._can_use_inplace_predict() is True

reg.set_params(predictor=None)
assert reg._can_use_inplace_predict() is True

reg.set_params(booster="gblinear")
assert reg._can_use_inplace_predict() is False
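With `test_prediction_config` deleted above, availability of in-place prediction in the sklearn wrapper no longer depends on a predictor setting. What remains of the contract, sketched from the two assertions that stay meaningful:

    import xgboost as xgb

    reg = xgb.XGBRegressor()
    # The default gbtree booster supports in-place prediction...
    assert reg._can_use_inplace_predict() is True
    # ...while gblinear still has to go through a DMatrix.
    reg.set_params(booster="gblinear")
    assert reg._can_use_inplace_predict() is False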
def test_evaluation_metric():
from sklearn.datasets import load_diabetes, load_digits
from sklearn.metrics import mean_absolute_error