[EM] Support ExtMemQdm in the GPU predictor. (#10694)
parent 43704549a2
commit 2ecc85ffad
@@ -494,7 +494,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
* - missing: Which value to represent missing value
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram.
* \param out The created Device Quantile DMatrix
* \param out The created Quantile DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
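The options documented above are passed to the callback-based constructors as a JSON-encoded config string. A minimal sketch of such a call follows; the iterator callbacks and the concrete option values are placeholders, and which constructor accepts which field should be checked against its own doc comment:

#include <xgboost/c_api.h>

/* Placeholder iterator callbacks: a real implementation feeds one batch into the
   proxy DMatrix per Next() call and returns 0 once the data is exhausted. */
void MyReset(DataIterHandle self) { (void)self; }
int MyNext(DataIterHandle self) { (void)self; return 0; }

int MakeExtMemDMatrix(DataIterHandle iter, DMatrixHandle proxy, DMatrixHandle* out) {
  /* Field names follow the doc comment above; the values are illustrative only. */
  char const* config = "{\"missing\": NaN, \"nthread\": 16, \"max_bin\": 256}";
  return XGDMatrixCreateFromCallback(iter, proxy, MyReset, MyNext, config, out);
}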
@@ -72,7 +72,7 @@ class MetaInfo {
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from.
*/
linalg::Tensor<float, 2> base_margin_; // NOLINT
linalg::Matrix<float> base_margin_; // NOLINT
/*!
* \brief lower bound of the label, to be used for survival analysis (censored regression)
*/
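The member keeps its rank-2 layout; linalg::Matrix is presumably just the library's shorthand for a two-dimensional tensor, roughly as sketched below (an assumption stated for orientation, not a quote of the linalg header):

namespace xgboost::linalg {
// Assumed alias: a Matrix is a Tensor fixed to two dimensions, so the rename
// above changes the spelling of the type but not the storage of base_margin_.
template <typename T>
using Matrix = Tensor<T, 2>;
}  // namespace xgboost::linalg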
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by Contributors
* Copyright 2017-2024, XGBoost Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
@@ -15,7 +15,6 @@
#include <functional> // for function
#include <memory> // for shared_ptr
#include <string>
#include <utility> // for make_pair
#include <vector>

// Forward declarations
@@ -60,20 +60,26 @@ struct EllpackDeviceAccessor {
min_fvalue = cuts->min_vals_.ConstHostSpan();
}
}
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
[[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
/**
* @brief Given a row index and a feature index, returns the corresponding cut value.
*
* Uses binary search for look up. Returns NaN if missing.
*
* @tparam global_ridx Whether the row index is global to all ellpack batches or it's
* local to the current batch.
*/
template <bool global_ridx = true>
[[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
if (global_ridx) {
ridx -= base_rowid;
}
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
bst_bin_t gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx = common::BinarySearchBin(row_begin,
row_end,
gidx_iter,
feature_segments[fidx],
gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
return gidx;
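The new template parameter only changes how the row index is interpreted; a short usage sketch (the accessor and indices are placeholders):

// Sketch only: acc is one batch's EllpackDeviceAccessor; ridx/fidx are placeholder indices.
__device__ void BinIndexExample(EllpackDeviceAccessor const& acc, size_t ridx, size_t fidx) {
  // Row index counted across all ellpack batches: the accessor subtracts its base_rowid first.
  bst_bin_t from_global = acc.GetBinIndex<true>(ridx, fidx);
  // Row index already local to this batch, as the GPU predictor's EllpackLoader uses it below.
  bst_bin_t from_local = acc.GetBinIndex<false>(ridx, fidx);
  (void)from_global;
  (void)from_local;
}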
@@ -3,10 +3,8 @@
*/
#include <GPUTreeShap/gpu_treeshap.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <thrust/host_vector.h>

#include <any> // for any, any_cast
#include <memory>
@@ -102,7 +100,7 @@ struct SparsePageView {
}
}
// Value is missing
return nanf("");
return std::numeric_limits<float>::quiet_NaN();
}
[[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
[[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; }
@@ -114,22 +112,21 @@ struct SparsePageLoader {
float* smem;

__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
bst_idx_t num_rows, size_t entry_start, float)
: use_shared(use_shared),
data(data) {
bst_idx_t num_rows, float)
: use_shared(use_shared), data(data) {
extern __shared__ float _smem[];
smem = _smem;
// Copy instances
if (use_shared) {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
int shared_elements = blockDim.x * data.num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
dh::BlockFill(smem, shared_elements, std::numeric_limits<float>::quiet_NaN());
__syncthreads();
if (global_idx < num_rows) {
bst_uint elem_begin = data.d_row_ptr[global_idx];
bst_uint elem_end = data.d_row_ptr[global_idx + 1];
for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
Entry elem = data.d_data[elem_idx - entry_start];
Entry elem = data.d_data[elem_idx];
smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
}
}
@@ -148,12 +145,12 @@ struct SparsePageLoader {
struct EllpackLoader {
EllpackDeviceAccessor const& matrix;
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t,
size_t, float)
float)
: matrix{m} {}
[[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
auto gidx = matrix.GetBinIndex(ridx, fidx);
[[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
auto gidx = matrix.GetBinIndex<false>(ridx, fidx);
if (gidx == -1) {
return nan("");
return std::numeric_limits<float>::quiet_NaN();
}
if (common::IsCat(matrix.feature_types, fidx)) {
return matrix.gidx_fvalue_map[gidx];
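After this change every loader in the file shares the same surface: construct with (data, use_shared, num_features, num_rows, missing) and answer GetElement(ridx, fidx) with either the stored value or a quiet NaN. A hypothetical in-memory loader sketching that contract (not part of the patch; the name and layout are made up):

// Hypothetical dense, row-major loader following the same contract as
// SparsePageLoader, EllpackLoader and DeviceAdapterLoader in this file.
struct DenseArrayLoader {
  float const* values;       // n_rows * n_features values, row-major
  bst_feature_t n_features;
  XGBOOST_DEVICE DenseArrayLoader(float const* data, bool /*use_shared*/, bst_feature_t n_features,
                                  bst_idx_t /*n_rows*/, float /*missing*/)
      : values{data}, n_features{n_features} {}
  [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
    float v = values[ridx * n_features + fidx];
    // Missing entries are reported as a quiet NaN, matching the loaders above.
    return common::CheckNAN(v) ? std::numeric_limits<float>::quiet_NaN() : v;
  }
};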
@@ -179,14 +176,14 @@ struct DeviceAdapterLoader {

XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
bst_feature_t num_features, bst_idx_t num_rows,
size_t entry_start, float missing)
float missing)
: batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} {
extern __shared__ float _smem[];
smem = _smem;
if (use_shared) {
uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
size_t shared_elements = blockDim.x * num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
dh::BlockFill(smem, shared_elements, std::numeric_limits<float>::quiet_NaN());
__syncthreads();
if (global_idx < num_rows) {
auto beg = global_idx * columns;
@@ -210,21 +207,19 @@ struct DeviceAdapterLoader {
if (is_valid(value)) {
return value;
} else {
return nan("");
return std::numeric_limits<float>::quiet_NaN();
}
}
};

template <bool has_missing, bool has_categorical, typename Loader>
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const &tree,
Loader *loader) {
__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader* loader) {
bst_node_t nidx = 0;
RegTree::Node n = tree.d_tree[nidx];
while (!n.IsLeaf()) {
float fvalue = loader->GetElement(ridx, n.SplitIndex());
bool is_missing = common::CheckNAN(fvalue);
nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
is_missing, tree.cats);
nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue, is_missing, tree.cats);
n = tree.d_tree[nidx];
}
return nidx;
@@ -253,14 +248,14 @@ PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories,
size_t tree_begin, size_t tree_end, size_t num_features,
size_t num_rows, size_t entry_start, bool use_shared,
size_t tree_begin, size_t tree_end, bst_feature_t num_features,
size_t num_rows, bool use_shared,
float missing) {
bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
if (ridx >= num_rows) {
return;
}
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
Loader loader{data, use_shared, num_features, num_rows, missing};
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
TreeView d_tree{
tree_begin, tree_idx, d_nodes,
@@ -288,10 +283,11 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories, size_t tree_begin,
size_t tree_end, size_t num_features, size_t num_rows,
size_t entry_start, bool use_shared, int num_group, float missing) {
bool use_shared, int num_group, float missing) {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
Loader loader(data, use_shared, num_features, num_rows, missing);
if (global_idx >= num_rows) return;

if (num_group == 1) {
float sum = 0;
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
@@ -627,10 +623,10 @@ __global__ void MaskBitVectorKernel(
common::Span<std::uint32_t const> d_cat_tree_segments,
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
std::size_t tree_begin, std::size_t tree_end, bst_feature_t num_features, std::size_t num_rows,
std::size_t num_nodes, bool use_shared, float missing) {
// This needs to be always instantiated since the data is loaded cooperatively by all threads.
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
SparsePageLoader loader{data, use_shared, num_features, num_rows, missing};
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (row_idx >= num_rows) {
return;
@@ -789,17 +785,16 @@ class ColumnSplitHelper {

batch.offset.SetDevice(ctx_->Device());
batch.data.SetDevice(ctx_->Device());
std::size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features);

auto const grid = static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()} (
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()}(
MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(),
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows,
entry_start, num_nodes, use_shared, nan(""));
num_nodes, use_shared, std::numeric_limits<float>::quiet_NaN());

AllReduceBitVectors(&decision_storage, &missing_storage);
@@ -852,36 +847,30 @@ class ColumnSplitHelper {

class GPUPredictor : public xgboost::Predictor {
private:
void PredictInternal(const SparsePage& batch,
DeviceModel const& model,
size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset, bool is_dense) const {
void PredictInternal(const SparsePage& batch, DeviceModel const& model, size_t num_features,
HostDeviceVector<bst_float>* predictions, size_t batch_offset,
bool is_dense) const {
batch.offset.SetDevice(ctx_->Device());
batch.data.SetDevice(ctx_->Device());
const uint32_t BLOCK_THREADS = 128;
size_t num_rows = batch.Size();
bst_idx_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device());
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;

size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
auto const kernel = [&](auto predict_fn) {
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
predict_fn, data, model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
num_features, num_rows, entry_start, use_shared, model.num_group,
nan(""));
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
model.tree_beg_, model.tree_end_, num_features, num_rows, use_shared, model.num_group,
std::numeric_limits<float>::quiet_NaN());
};
if (is_dense) {
kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
@@ -889,27 +878,23 @@ class GPUPredictor : public xgboost::Predictor {
kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
}
}
void PredictInternal(EllpackDeviceAccessor const& batch,
DeviceModel const& model,
HostDeviceVector<bst_float>* out_preds,
size_t batch_offset) const {
void PredictInternal(EllpackDeviceAccessor const& batch, DeviceModel const& model,
HostDeviceVector<bst_float>* out_preds, bst_idx_t batch_offset) const {
const uint32_t BLOCK_THREADS = 256;
size_t num_rows = batch.n_rows;
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
DeviceModel d_model;

bool use_shared = false;
size_t entry_start = 0;
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS}(
PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch, model.nodes.ConstDeviceSpan(),
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
batch.NumFeatures(), num_rows, entry_start, use_shared,
model.num_group, nan(""));
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
model.tree_beg_, model.tree_end_, batch.NumFeatures(), num_rows, use_shared,
model.num_group, std::numeric_limits<float>::quiet_NaN());
}

void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
@@ -928,24 +913,22 @@ class GPUPredictor : public xgboost::Predictor {
return;
}

CHECK_LE(dmat->Info().num_col_, model.learner_model_param->num_feature);
if (dmat->PageExists<SparsePage>()) {
size_t batch_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
out_preds, batch_offset, dmat->IsDense());
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
bst_idx_t batch_offset = 0;
for (auto& batch : dmat->GetBatches<SparsePage>()) {
this->PredictInternal(batch, d_model, model.learner_model_param->num_feature, out_preds,
batch_offset, dmat->IsDense());
batch_offset += batch.Size() * model.learner_model_param->OutputLength();
}
} else {
size_t batch_offset = 0;
bst_idx_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
dmat->Info().feature_types.SetDevice(ctx_->Device());
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
this->PredictInternal(
page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
d_model,
out_preds,
batch_offset);
batch_offset += page.Impl()->n_rows;
this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types),
d_model, out_preds, batch_offset);
batch_offset += page.Size() * model.learner_model_param->OutputLength();
}
}
}
@@ -1004,17 +987,14 @@ class GPUPredictor : public xgboost::Predictor {
d_model.Init(model, tree_begin, tree_end, m->Device());

bool use_shared = shared_memory_bytes != 0;
size_t entry_start = 0;

dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<Loader, typename Loader::BatchT>, m->Value(),
d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(),
d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
PredictKernel<Loader, typename Loader::BatchT>, m->Value(), d_model.nodes.ConstDeviceSpan(),
out_preds->predictions.DeviceSpan(), d_model.tree_segments.ConstDeviceSpan(),
d_model.tree_group.ConstDeviceSpan(), d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
m->NumRows(), entry_start, use_shared, output_groups, missing);
d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(),
tree_begin, tree_end, m->NumColumns(), m->NumRows(), use_shared, output_groups, missing);
}

bool InplacePredict(std::shared_ptr<DMatrix> p_m, const gbm::GBTreeModel& model, float missing,
@@ -1043,8 +1023,8 @@ class GPUPredictor : public xgboost::Predictor {
std::vector<bst_float> const* tree_weights,
bool approximate, int,
unsigned) const override {
std::string not_implemented{"contribution is not implemented in GPU "
"predictor, use `cpu_predictor` instead."};
std::string not_implemented{
"contribution is not implemented in the GPU predictor, use CPU instead."};
if (approximate) {
LOG(FATAL) << "Approximated " << not_implemented;
}
@@ -1199,7 +1179,6 @@ class GPUPredictor : public xgboost::Predictor {
info.num_col_, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
bst_feature_t num_features = info.num_col_;
size_t entry_start = 0;

if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
@@ -1223,7 +1202,7 @@ class GPUPredictor : public xgboost::Predictor {
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
entry_start, use_shared, nan(""));
use_shared, std::numeric_limits<float>::quiet_NaN());
batch_offset += batch.Size();
}
} else {
@@ -1245,16 +1224,12 @@ class GPUPredictor : public xgboost::Predictor {
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
entry_start, use_shared, nan(""));
use_shared, std::numeric_limits<float>::quiet_NaN());
batch_offset += batch.Size();
}
}
}

void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
Predictor::Configure(cfg);
}

private:
/*! \brief Reconfigure the device when GPU is changed. */
static size_t ConfigureDevice(DeviceOrd device) {
@@ -147,39 +147,54 @@ TEST(GPUPredictor, EllpackTraining) {
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_ellpack);
}

TEST(GPUPredictor, ExternalMemoryTest) {
auto lparam = MakeCUDACtx(0);
namespace {
template <typename Create>
void TestDecisionStumpExternalMemory(Context const* ctx, bst_feature_t n_features,
Create create_fn) {
std::int32_t n_classes = 3;
LearnerModelParam mparam{MakeMP(n_features, .5, n_classes, ctx->Device())};
auto model = CreateTestModel(&mparam, ctx, n_classes);
std::unique_ptr<Predictor> gpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", ctx));
gpu_predictor->Configure({});

const int n_classes = 3;
Context ctx = MakeCUDACtx(0);
LearnerModelParam mparam{MakeMP(5, .5, n_classes, ctx.Device())};

gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, n_classes);
std::vector<std::unique_ptr<DMatrix>> dmats;

dmats.push_back(CreateSparsePageDMatrix(400));
dmats.push_back(CreateSparsePageDMatrix(800));
dmats.push_back(CreateSparsePageDMatrix(8000));

for (const auto& dmat: dmats) {
dmat->Info().base_margin_ = decltype(dmat->Info().base_margin_){
{dmat->Info().num_row_, static_cast<size_t>(n_classes)}, DeviceOrd::CUDA(0)};
dmat->Info().base_margin_.Data()->Fill(0.5);
for (auto p_fmat : {create_fn(400), create_fn(800), create_fn(2048)}) {
p_fmat->Info().base_margin_ = linalg::Constant(ctx, 0.5f, p_fmat->Info().num_row_, n_classes);
PredictionCacheEntry out_predictions;
gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes);
const std::vector<float> &host_vector = out_predictions.predictions.ConstHostVector();
for (size_t i = 0; i < host_vector.size() / n_classes; i++) {
ASSERT_EQ(host_vector[i * n_classes], 2.0);
ASSERT_EQ(host_vector[i * n_classes + 1], 0.5);
ASSERT_EQ(host_vector[i * n_classes + 2], 0.5);
gpu_predictor->InitOutPredictions(p_fmat->Info(), &out_predictions.predictions, model);
gpu_predictor->PredictBatch(p_fmat.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.Size(), p_fmat->Info().num_row_ * n_classes);
auto const& h_predt = out_predictions.predictions.ConstHostVector();
for (size_t i = 0; i < h_predt.size() / n_classes; i++) {
ASSERT_EQ(h_predt[i * n_classes], 2.0);
ASSERT_EQ(h_predt[i * n_classes + 1], 0.5);
ASSERT_EQ(h_predt[i * n_classes + 2], 0.5);
}
}
}
} // namespace

TEST(GPUPredictor, ExternalMemory) {
auto ctx = MakeCUDACtx(0);

bst_bin_t max_bin = 128;
bst_feature_t n_features = 32;

TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) {
return RandomDataGenerator{n_samples, n_features, 0.0f}
.Batches(4)
.Device(ctx.Device())
.Bins(max_bin)
.GenerateSparsePageDMatrix("temp", false);
});
TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) {
return RandomDataGenerator{n_samples, n_features, 0.0f}
.Batches(4)
.Device(ctx.Device())
.Bins(max_bin)
.GenerateExtMemQuantileDMatrix("temp", false);
});
}

TEST(GPUPredictor, InplacePredictCupy) {
auto ctx = MakeCUDACtx(0);