[EM] Support ExtMemQdm in the GPU predictor. (#10694)

Jiaming Yuan 2024-08-13 12:21:11 +08:00 committed by GitHub
parent 43704549a2
commit 2ecc85ffad
6 changed files with 124 additions and 129 deletions


@@ -494,7 +494,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
  * - missing: Which value to represent missing value
  * - nthread (optional): Number of threads used for initializing DMatrix.
  * - max_bin (optional): Maximum number of bins for building histogram.
- * \param out The created Device Quantile DMatrix
+ * \param out The created Quantile DMatrix.
  *
  * \return 0 when success, -1 when failure happens
  */
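The `missing`, `nthread`, and `max_bin` entries documented above are fields of the JSON config string handed to the callback-based Quantile DMatrix constructor. A minimal sketch of such a config, assuming the caller assembles it by hand (values and the consuming C-API call are illustrative, not taken from this patch):

#include <string>

// Illustrative only: the keys mirror the doc comment above; the concrete
// values and how the string is passed to the C API are assumptions.
std::string MakeQuantileDMatrixConfig() {
  return R"({"missing": NaN, "nthread": 16, "max_bin": 256})";
}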


@@ -72,7 +72,7 @@ class MetaInfo {
    * if specified, xgboost will start from this init margin
    * can be used to specify initial prediction to boost from.
    */
-  linalg::Tensor<float, 2> base_margin_;  // NOLINT
+  linalg::Matrix<float> base_margin_;  // NOLINT
   /*!
    * \brief lower bound of the label, to be used for survival analysis (censored regression)
    */


@@ -1,5 +1,5 @@
 /**
- * Copyright 2017-2023 by Contributors
+ * Copyright 2017-2024, XGBoost Contributors
  * \file predictor.h
  * \brief Interface of predictor,
  *  performs predictions for a gradient booster.
@@ -15,7 +15,6 @@
 #include <functional>  // for function
 #include <memory>      // for shared_ptr
 #include <string>
-#include <utility>  // for make_pair
 #include <vector>
 
 // Forward declarations


@@ -60,20 +60,26 @@ struct EllpackDeviceAccessor {
       min_fvalue = cuts->min_vals_.ConstHostSpan();
     }
   }
-  // Get a matrix element, uses binary search for look up Return NaN if missing
-  // Given a row index and a feature index, returns the corresponding cut value
-  [[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
+  /**
+   * @brief Given a row index and a feature index, returns the corresponding cut value.
+   *
+   * Uses binary search for look up. Returns NaN if missing.
+   *
+   * @tparam global_ridx Whether the row index is global to all ellpack batches or it's
+   *   local to the current batch.
+   */
+  template <bool global_ridx = true>
+  [[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
+    if (global_ridx) {
       ridx -= base_rowid;
+    }
     auto row_begin = row_stride * ridx;
     auto row_end = row_begin + row_stride;
-    auto gidx = -1;
+    bst_bin_t gidx = -1;
     if (is_dense) {
       gidx = gidx_iter[row_begin + fidx];
     } else {
-      gidx = common::BinarySearchBin(row_begin,
-                                     row_end,
-                                     gidx_iter,
-                                     feature_segments[fidx],
+      gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx],
                                      feature_segments[fidx + 1]);
     }
     return gidx;
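For context, with external memory each EllpackDeviceAccessor covers one batch of rows starting at base_rowid, which is what the new global_ridx switch distinguishes. A hedged device-side sketch of the two instantiations (the helper name is made up; GetBinIndex, base_rowid, and gidx_fvalue_map come from the code above):

// Illustrative sketch, not part of the patch.
__device__ float LookupCutValue(EllpackDeviceAccessor const& acc, size_t global_ridx,
                                size_t fidx) {
  // Global row index: GetBinIndex subtracts acc.base_rowid internally (the default).
  bst_bin_t from_global = acc.GetBinIndex<true>(global_ridx, fidx);
  // Batch-local row index: the caller subtracts base_rowid itself, as the GPU
  // predictor's EllpackLoader now does when walking a single batch.
  bst_bin_t from_local = acc.GetBinIndex<false>(global_ridx - acc.base_rowid, fidx);
  // Both calls resolve to the same bin; -1 means the value is missing.
  return from_local == -1 ? std::numeric_limits<float>::quiet_NaN()
                          : acc.gidx_fvalue_map[from_global];
}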


@@ -3,10 +3,8 @@
  */
 #include <GPUTreeShap/gpu_treeshap.h>
 #include <thrust/copy.h>
-#include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/fill.h>
-#include <thrust/host_vector.h>
 
 #include <any>  // for any, any_cast
 #include <memory>
@@ -102,7 +100,7 @@ struct SparsePageView {
       }
     }
     // Value is missing
-    return nanf("");
+    return std::numeric_limits<float>::quiet_NaN();
   }
   [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
   [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; }
@@ -114,22 +112,21 @@ struct SparsePageLoader {
   float* smem;
   __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
-                              bst_idx_t num_rows, size_t entry_start, float)
-      : use_shared(use_shared),
-        data(data) {
+                              bst_idx_t num_rows, float)
+      : use_shared(use_shared), data(data) {
     extern __shared__ float _smem[];
     smem = _smem;
     // Copy instances
     if (use_shared) {
       bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
       int shared_elements = blockDim.x * data.num_features;
-      dh::BlockFill(smem, shared_elements, nanf(""));
+      dh::BlockFill(smem, shared_elements, std::numeric_limits<float>::quiet_NaN());
       __syncthreads();
       if (global_idx < num_rows) {
         bst_uint elem_begin = data.d_row_ptr[global_idx];
         bst_uint elem_end = data.d_row_ptr[global_idx + 1];
         for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
-          Entry elem = data.d_data[elem_idx - entry_start];
+          Entry elem = data.d_data[elem_idx];
           smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
         }
       }
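The constructor above stages one row per thread into shared memory; the matching lookup (a member outside this hunk, shown here as a hedged sketch) reads the staged values back by feature index:

// Hedged sketch of the read path that consumes the shared-memory staging above;
// the real member is not part of this hunk.
[[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const {
  if (use_shared) {
    // Missing entries were pre-filled with quiet NaN by dh::BlockFill.
    return smem[threadIdx.x * data.num_features + fidx];
  }
  // Otherwise fall back to looking the value up in the SparsePageView itself.
  return data.GetElement(ridx, fidx);
}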
@@ -148,12 +145,12 @@ struct SparsePageLoader {
 struct EllpackLoader {
   EllpackDeviceAccessor const& matrix;
   XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t,
-                               size_t, float)
+                               float)
       : matrix{m} {}
-  [[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
-    auto gidx = matrix.GetBinIndex(ridx, fidx);
+  [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
+    auto gidx = matrix.GetBinIndex<false>(ridx, fidx);
     if (gidx == -1) {
-      return nan("");
+      return std::numeric_limits<float>::quiet_NaN();
     }
     if (common::IsCat(matrix.feature_types, fidx)) {
       return matrix.gidx_fvalue_map[gidx];
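With entry_start gone, all three loaders (SparsePageLoader, EllpackLoader, DeviceAdapterLoader) share the same five-argument construction, so the kernels below can instantiate whichever Loader they are parameterized on. An illustrative helper showing the shape the kernels rely on (the function itself is not part of the patch):

// Illustrative only: the uniform loader contract used by the prediction kernels.
template <typename Loader, typename Data>
__device__ float LoadOne(Data data, bool use_shared, bst_feature_t num_features,
                         bst_idx_t num_rows, float missing, bst_idx_t ridx, size_t fidx) {
  Loader loader{data, use_shared, num_features, num_rows, missing};
  return loader.GetElement(ridx, fidx);  // quiet NaN when the value is missing
}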
@@ -179,14 +176,14 @@ struct DeviceAdapterLoader {
   XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
                                          bst_feature_t num_features, bst_idx_t num_rows,
-                                         size_t entry_start, float missing)
+                                         float missing)
       : batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} {
     extern __shared__ float _smem[];
     smem = _smem;
     if (use_shared) {
       uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
       size_t shared_elements = blockDim.x * num_features;
-      dh::BlockFill(smem, shared_elements, nanf(""));
+      dh::BlockFill(smem, shared_elements, std::numeric_limits<float>::quiet_NaN());
       __syncthreads();
       if (global_idx < num_rows) {
         auto beg = global_idx * columns;
@@ -210,21 +207,19 @@ struct DeviceAdapterLoader {
     if (is_valid(value)) {
       return value;
     } else {
-      return nan("");
+      return std::numeric_limits<float>::quiet_NaN();
     }
   }
 };
 
 template <bool has_missing, bool has_categorical, typename Loader>
-__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const &tree,
-                                   Loader *loader) {
+__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader* loader) {
   bst_node_t nidx = 0;
   RegTree::Node n = tree.d_tree[nidx];
   while (!n.IsLeaf()) {
     float fvalue = loader->GetElement(ridx, n.SplitIndex());
     bool is_missing = common::CheckNAN(fvalue);
-    nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
-                                                     is_missing, tree.cats);
+    nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue, is_missing, tree.cats);
     n = tree.d_tree[nidx];
   }
   return nidx;
@@ -253,14 +248,14 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
                   common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
                   common::Span<uint32_t const> d_categories,
-                  size_t tree_begin, size_t tree_end, size_t num_features,
-                  size_t num_rows, size_t entry_start, bool use_shared,
+                  size_t tree_begin, size_t tree_end, bst_feature_t num_features,
+                  size_t num_rows, bool use_shared,
                   float missing) {
   bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
   if (ridx >= num_rows) {
     return;
   }
-  Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
+  Loader loader{data, use_shared, num_features, num_rows, missing};
   for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
     TreeView d_tree{
         tree_begin, tree_idx, d_nodes,
@@ -288,10 +283,11 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
               common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
               common::Span<uint32_t const> d_categories, size_t tree_begin,
               size_t tree_end, size_t num_features, size_t num_rows,
-              size_t entry_start, bool use_shared, int num_group, float missing) {
+              bool use_shared, int num_group, float missing) {
   bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
+  Loader loader(data, use_shared, num_features, num_rows, missing);
   if (global_idx >= num_rows) return;
   if (num_group == 1) {
     float sum = 0;
     for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
@@ -627,10 +623,10 @@ __global__ void MaskBitVectorKernel(
     common::Span<std::uint32_t const> d_cat_tree_segments,
     common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
     common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
-    std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
-    std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
+    std::size_t tree_begin, std::size_t tree_end, bst_feature_t num_features, std::size_t num_rows,
+    std::size_t num_nodes, bool use_shared, float missing) {
   // This needs to be always instantiated since the data is loaded cooperatively by all threads.
-  SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
+  SparsePageLoader loader{data, use_shared, num_features, num_rows, missing};
   auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (row_idx >= num_rows) {
     return;
@@ -789,7 +785,6 @@ class ColumnSplitHelper {
     batch.offset.SetDevice(ctx_->Device());
     batch.data.SetDevice(ctx_->Device());
 
-    std::size_t entry_start = 0;
     SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features);
 
     auto const grid = static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
@@ -799,7 +794,7 @@
         model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(),
         model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
         decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows,
-        entry_start, num_nodes, use_shared, nan(""));
+        num_nodes, use_shared, std::numeric_limits<float>::quiet_NaN());
 
     AllReduceBitVectors(&decision_storage, &missing_storage);
@@ -852,36 +847,30 @@ class ColumnSplitHelper {
 class GPUPredictor : public xgboost::Predictor {
  private:
-  void PredictInternal(const SparsePage& batch,
-                       DeviceModel const& model,
-                       size_t num_features,
-                       HostDeviceVector<bst_float>* predictions,
-                       size_t batch_offset, bool is_dense) const {
+  void PredictInternal(const SparsePage& batch, DeviceModel const& model, size_t num_features,
+                       HostDeviceVector<bst_float>* predictions, size_t batch_offset,
+                       bool is_dense) const {
     batch.offset.SetDevice(ctx_->Device());
     batch.data.SetDevice(ctx_->Device());
     const uint32_t BLOCK_THREADS = 128;
-    size_t num_rows = batch.Size();
+    bst_idx_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
     auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device());
     size_t shared_memory_bytes =
         SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
-    size_t entry_start = 0;
     SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                         num_features);
     auto const kernel = [&](auto predict_fn) {
       dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
           predict_fn, data, model.nodes.ConstDeviceSpan(),
-          predictions->DeviceSpan().subspan(batch_offset),
-          model.tree_segments.ConstDeviceSpan(),
-          model.tree_group.ConstDeviceSpan(),
-          model.split_types.ConstDeviceSpan(),
+          predictions->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
+          model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
           model.categories_tree_segments.ConstDeviceSpan(),
-          model.categories_node_segments.ConstDeviceSpan(),
-          model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
-          num_features, num_rows, entry_start, use_shared, model.num_group,
-          nan(""));
+          model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
+          model.tree_beg_, model.tree_end_, num_features, num_rows, use_shared, model.num_group,
+          std::numeric_limits<float>::quiet_NaN());
     };
     if (is_dense) {
       kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
@@ -889,27 +878,23 @@ class GPUPredictor : public xgboost::Predictor {
       kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
     }
   }
-  void PredictInternal(EllpackDeviceAccessor const& batch,
-                       DeviceModel const& model,
-                       HostDeviceVector<bst_float>* out_preds,
-                       size_t batch_offset) const {
+  void PredictInternal(EllpackDeviceAccessor const& batch, DeviceModel const& model,
+                       HostDeviceVector<bst_float>* out_preds, bst_idx_t batch_offset) const {
     const uint32_t BLOCK_THREADS = 256;
     size_t num_rows = batch.n_rows;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
     DeviceModel d_model;
     bool use_shared = false;
-    size_t entry_start = 0;
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS}(
-        PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
-        model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
-        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
-        model.split_types.ConstDeviceSpan(),
+        PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch, model.nodes.ConstDeviceSpan(),
+        out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
+        model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
         model.categories_tree_segments.ConstDeviceSpan(),
-        model.categories_node_segments.ConstDeviceSpan(),
-        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
-        batch.NumFeatures(), num_rows, entry_start, use_shared,
-        model.num_group, nan(""));
+        model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
+        model.tree_beg_, model.tree_end_, batch.NumFeatures(), num_rows, use_shared,
+        model.num_group, std::numeric_limits<float>::quiet_NaN());
   }
 
   void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
@@ -928,24 +913,22 @@ class GPUPredictor : public xgboost::Predictor {
       return;
     }
 
+    CHECK_LE(dmat->Info().num_col_, model.learner_model_param->num_feature);
     if (dmat->PageExists<SparsePage>()) {
-      size_t batch_offset = 0;
+      bst_idx_t batch_offset = 0;
       for (auto& batch : dmat->GetBatches<SparsePage>()) {
-        this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
-                              out_preds, batch_offset, dmat->IsDense());
-        batch_offset += batch.Size() * model.learner_model_param->num_output_group;
+        this->PredictInternal(batch, d_model, model.learner_model_param->num_feature, out_preds,
+                              batch_offset, dmat->IsDense());
+        batch_offset += batch.Size() * model.learner_model_param->OutputLength();
       }
     } else {
-      size_t batch_offset = 0;
+      bst_idx_t batch_offset = 0;
       for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
         dmat->Info().feature_types.SetDevice(ctx_->Device());
         auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
-        this->PredictInternal(
-            page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
-            d_model,
-            out_preds,
-            batch_offset);
-        batch_offset += page.Impl()->n_rows;
+        this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
+                              d_model, out_preds, batch_offset);
+        batch_offset += page.Size() * model.learner_model_param->OutputLength();
       }
     }
   }
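The offset bookkeeping above is what keeps predictions from consecutive external-memory batches in the right slice of the output vector: each batch advances the write position by its row count times the number of output groups. A small worked sketch of the arithmetic (batch sizes and class count are made up):

// Illustrative arithmetic only: with a 3-class model (OutputLength() == 3),
// two batches of 400 and 800 rows write to disjoint slices of out_preds:
//   batch 0: [0, 400 * 3)            -> batch_offset becomes 1200
//   batch 1: [1200, 1200 + 800 * 3)  -> batch_offset becomes 3600
bst_idx_t batch_offset = 0;
for (bst_idx_t batch_rows : {400, 800}) {
  bst_idx_t n_groups = 3;  // model.learner_model_param->OutputLength() in the real code
  // The prediction kernel writes into out_preds->DeviceSpan().subspan(batch_offset) here.
  batch_offset += batch_rows * n_groups;
}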
@@ -1004,17 +987,14 @@ class GPUPredictor : public xgboost::Predictor {
     d_model.Init(model, tree_begin, tree_end, m->Device());
 
     bool use_shared = shared_memory_bytes != 0;
-    size_t entry_start = 0;
 
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
-        PredictKernel<Loader, typename Loader::BatchT>, m->Value(),
-        d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(),
-        d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(),
-        d_model.split_types.ConstDeviceSpan(),
+        PredictKernel<Loader, typename Loader::BatchT>, m->Value(), d_model.nodes.ConstDeviceSpan(),
+        out_preds->predictions.DeviceSpan(), d_model.tree_segments.ConstDeviceSpan(),
+        d_model.tree_group.ConstDeviceSpan(), d_model.split_types.ConstDeviceSpan(),
         d_model.categories_tree_segments.ConstDeviceSpan(),
-        d_model.categories_node_segments.ConstDeviceSpan(),
-        d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
-        m->NumRows(), entry_start, use_shared, output_groups, missing);
+        d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(),
+        tree_begin, tree_end, m->NumColumns(), m->NumRows(), use_shared, output_groups, missing);
   }
 
   bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel& model, float missing,
@@ -1043,8 +1023,8 @@ class GPUPredictor : public xgboost::Predictor {
                            std::vector<bst_float> const* tree_weights,
                            bool approximate, int,
                            unsigned) const override {
-    std::string not_implemented{"contribution is not implemented in GPU "
-                                "predictor, use `cpu_predictor` instead."};
+    std::string not_implemented{
+        "contribution is not implemented in the GPU predictor, use CPU instead."};
     if (approximate) {
       LOG(FATAL) << "Approximated " << not_implemented;
     }
@@ -1199,7 +1179,6 @@ class GPUPredictor : public xgboost::Predictor {
                                                 info.num_col_, max_shared_memory_bytes);
     bool use_shared = shared_memory_bytes != 0;
     bst_feature_t num_features = info.num_col_;
-    size_t entry_start = 0;
 
     if (p_fmat->PageExists<SparsePage>()) {
       for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
@@ -1223,7 +1202,7 @@ class GPUPredictor : public xgboost::Predictor {
             d_model.categories.ConstDeviceSpan(),
             d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
-            entry_start, use_shared, nan(""));
+            use_shared, std::numeric_limits<float>::quiet_NaN());
         batch_offset += batch.Size();
       }
     } else {
@@ -1245,16 +1224,12 @@ class GPUPredictor : public xgboost::Predictor {
             d_model.categories.ConstDeviceSpan(),
             d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
-            entry_start, use_shared, nan(""));
+            use_shared, std::numeric_limits<float>::quiet_NaN());
         batch_offset += batch.Size();
       }
     }
   }
 
-  void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
-    Predictor::Configure(cfg);
-  }
-
  private:
   /*! \brief Reconfigure the device when GPU is changed. */
   static size_t ConfigureDevice(DeviceOrd device) {


@@ -147,39 +147,54 @@ TEST(GPUPredictor, EllpackTraining) {
   TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_ellpack);
 }
 
-TEST(GPUPredictor, ExternalMemoryTest) {
-  auto lparam = MakeCUDACtx(0);
+namespace {
+template <typename Create>
+void TestDecisionStumpExternalMemory(Context const* ctx, bst_feature_t n_features,
+                                     Create create_fn) {
+  std::int32_t n_classes = 3;
+  LearnerModelParam mparam{MakeMP(n_features, .5, n_classes, ctx->Device())};
+  auto model = CreateTestModel(&mparam, ctx, n_classes);
+
   std::unique_ptr<Predictor> gpu_predictor =
-      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
+      std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", ctx));
   gpu_predictor->Configure({});
-  const int n_classes = 3;
-  Context ctx = MakeCUDACtx(0);
-  LearnerModelParam mparam{MakeMP(5, .5, n_classes, ctx.Device())};
-  gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, n_classes);
-  std::vector<std::unique_ptr<DMatrix>> dmats;
-  dmats.push_back(CreateSparsePageDMatrix(400));
-  dmats.push_back(CreateSparsePageDMatrix(800));
-  dmats.push_back(CreateSparsePageDMatrix(8000));
 
-  for (const auto& dmat: dmats) {
-    dmat->Info().base_margin_ = decltype(dmat->Info().base_margin_){
-        {dmat->Info().num_row_, static_cast<size_t>(n_classes)}, DeviceOrd::CUDA(0)};
-    dmat->Info().base_margin_.Data()->Fill(0.5);
+  for (auto p_fmat : {create_fn(400), create_fn(800), create_fn(2048)}) {
+    p_fmat->Info().base_margin_ = linalg::Constant(ctx, 0.5f, p_fmat->Info().num_row_, n_classes);
+
     PredictionCacheEntry out_predictions;
-    gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
-    gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
-    EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes);
-    const std::vector<float> &host_vector = out_predictions.predictions.ConstHostVector();
-    for (size_t i = 0; i < host_vector.size() / n_classes; i++) {
-      ASSERT_EQ(host_vector[i * n_classes], 2.0);
-      ASSERT_EQ(host_vector[i * n_classes + 1], 0.5);
-      ASSERT_EQ(host_vector[i * n_classes + 2], 0.5);
+    gpu_predictor->InitOutPredictions(p_fmat->Info(), &out_predictions.predictions, model);
+    gpu_predictor->PredictBatch(p_fmat.get(), &out_predictions, model, 0);
+    ASSERT_EQ(out_predictions.predictions.Size(), p_fmat->Info().num_row_ * n_classes);
+    auto const& h_predt = out_predictions.predictions.ConstHostVector();
+    for (size_t i = 0; i < h_predt.size() / n_classes; i++) {
+      ASSERT_EQ(h_predt[i * n_classes], 2.0);
+      ASSERT_EQ(h_predt[i * n_classes + 1], 0.5);
+      ASSERT_EQ(h_predt[i * n_classes + 2], 0.5);
     }
   }
 }
+}  // namespace
+
+TEST(GPUPredictor, ExternalMemory) {
+  auto ctx = MakeCUDACtx(0);
+  bst_bin_t max_bin = 128;
+  bst_feature_t n_features = 32;
+
+  TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) {
+    return RandomDataGenerator{n_samples, n_features, 0.0f}
+        .Batches(4)
+        .Device(ctx.Device())
+        .Bins(max_bin)
+        .GenerateSparsePageDMatrix("temp", false);
+  });
+
+  TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) {
+    return RandomDataGenerator{n_samples, n_features, 0.0f}
+        .Batches(4)
+        .Device(ctx.Device())
+        .Bins(max_bin)
+        .GenerateExtMemQuantileDMatrix("temp", false);
+  });
+}
 
 TEST(GPUPredictor, InplacePredictCupy) {
   auto ctx = MakeCUDACtx(0);