Add categorical data support to GPU predictor. (#6165)
This commit is contained in:
parent
7622b8cdb8
commit
798af22ff4
@ -10,6 +10,7 @@
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -176,6 +177,7 @@ template class HostDeviceVector<FeatureType>;
|
||||
template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Segment>;
|
||||
|
||||
#if defined(__APPLE__)
|
||||
/*
|
||||
|
||||
@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Node>;
|
||||
template class HostDeviceVector<RegTree::Segment>;
|
||||
template class HostDeviceVector<RTreeNodeStat>;
|
||||
|
||||
#if defined(__APPLE__)
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../common/common.h"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
@ -169,33 +171,49 @@ struct DeviceAdapterLoader {
|
||||
|
||||
template <typename Loader>
|
||||
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
||||
common::Span<FeatureType const> split_types,
|
||||
common::Span<RegTree::Segment const> d_cat_ptrs,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
Loader* loader) {
|
||||
RegTree::Node n = tree[0];
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader->GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(fvalue)) {
|
||||
n = tree[n.DefaultChild()];
|
||||
if (common::CheckNAN(fvalue)) {
|
||||
nidx = n.DefaultChild();
|
||||
} else {
|
||||
if (fvalue < n.SplitCond()) {
|
||||
n = tree[n.LeftChild()];
|
||||
bool go_left = true;
|
||||
if (common::IsCat(split_types, nidx)) {
|
||||
auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
|
||||
d_cat_ptrs[nidx].size);
|
||||
go_left = Decision(categories, common::AsCat(fvalue));
|
||||
} else {
|
||||
n = tree[n.RightChild()];
|
||||
go_left = fvalue < n.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
nidx = n.LeftChild();
|
||||
} else {
|
||||
nidx = n.RightChild();
|
||||
}
|
||||
}
|
||||
n = tree[nidx];
|
||||
}
|
||||
return n.LeafValue();
|
||||
return tree[nidx].LeafValue();
|
||||
}
|
||||
|
||||
template <typename Loader, typename Data>
|
||||
__global__ void PredictKernel(Data data,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t> d_tree_segments,
|
||||
common::Span<int> d_tree_group,
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start,
|
||||
bool use_shared, int num_group) {
|
||||
__global__ void
|
||||
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
common::Span<int const> d_tree_group,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories, size_t tree_begin,
|
||||
size_t tree_end, size_t num_features, size_t num_rows,
|
||||
size_t entry_start, bool use_shared, int num_group) {
|
||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
Loader loader(data, use_shared, num_features, num_rows, entry_start);
|
||||
if (global_idx >= num_rows) return;
|
||||
@ -204,7 +222,18 @@ __global__ void PredictKernel(Data data,
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
float leaf = GetLeafWeight(global_idx, d_tree, &loader);
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
sum += leaf;
|
||||
}
|
||||
d_out_predictions[global_idx] += sum;
|
||||
@ -214,8 +243,19 @@ __global__ void PredictKernel(Data data,
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
d_out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(global_idx, d_tree, &loader);
|
||||
GetLeafWeight(global_idx, d_tree, d_tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -223,10 +263,18 @@ __global__ void PredictKernel(Data data,
|
||||
class DeviceModel {
|
||||
public:
|
||||
// Need to lazily construct the vectors because GPU id is only known at runtime
|
||||
HostDeviceVector<RegTree::Node> nodes;
|
||||
HostDeviceVector<RTreeNodeStat> stats;
|
||||
HostDeviceVector<size_t> tree_segments;
|
||||
HostDeviceVector<RegTree::Node> nodes;
|
||||
HostDeviceVector<int> tree_group;
|
||||
HostDeviceVector<FeatureType> split_types;
|
||||
|
||||
// Pointer to each tree, segmenting the node array.
|
||||
HostDeviceVector<uint32_t> categories_tree_segments;
|
||||
// Pointer to each node, segmenting categories array.
|
||||
HostDeviceVector<RegTree::Segment> categories_node_segments;
|
||||
HostDeviceVector<uint32_t> categories;
|
||||
|
||||
size_t tree_beg_; // NOLINT
|
||||
size_t tree_end_; // NOLINT
|
||||
int num_group;
|
||||
@ -264,10 +312,43 @@ class DeviceModel {
|
||||
}
|
||||
|
||||
tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
|
||||
auto d_tree_group = tree_group.DevicePointer();
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyDefault));
|
||||
auto& h_tree_group = tree_group.HostVector();
|
||||
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
|
||||
|
||||
// Initialize categorical splits.
|
||||
split_types.SetDevice(gpu_id);
|
||||
std::vector<FeatureType>& h_split_types = split_types.HostVector();
|
||||
h_split_types.resize(h_tree_segments.back());
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
auto const& src_st = model.trees.at(tree_idx)->GetSplitTypes();
|
||||
std::copy(src_st.cbegin(), src_st.cend(),
|
||||
h_split_types.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
|
||||
categories = HostDeviceVector<uint32_t>({}, gpu_id);
|
||||
categories_tree_segments = HostDeviceVector<uint32_t>(1, 0, gpu_id);
|
||||
std::vector<uint32_t> &h_categories = categories.HostVector();
|
||||
std::vector<uint32_t> &h_split_cat_segments = categories_tree_segments.HostVector();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories();
|
||||
size_t orig_size = h_categories.size();
|
||||
h_categories.resize(orig_size + src_cats.size());
|
||||
std::copy(src_cats.cbegin(), src_cats.cend(),
|
||||
h_categories.begin() + orig_size);
|
||||
h_split_cat_segments.push_back(h_categories.size());
|
||||
}
|
||||
|
||||
categories_node_segments =
|
||||
HostDeviceVector<RegTree::Segment>(h_tree_segments.back(), {}, gpu_id);
|
||||
std::vector<RegTree::Segment> &h_categories_node_segments =
|
||||
categories_node_segments.HostVector();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr();
|
||||
std::copy(src_cats_ptr.cbegin(), src_cats_ptr.cend(),
|
||||
h_categories_node_segments.begin() +
|
||||
h_tree_segments[tree_idx - tree_begin]);
|
||||
}
|
||||
|
||||
this->tree_beg_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group = model.learner_model_param->num_output_group;
|
||||
@ -360,7 +441,8 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
void PredictInternal(const SparsePage& batch, size_t num_features,
|
||||
void PredictInternal(const SparsePage& batch,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) {
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
@ -380,14 +462,18 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
num_features);
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>,
|
||||
data,
|
||||
model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
|
||||
model_.tree_beg_, model_.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
PredictKernel<SparsePageLoader, SparsePageView>, data,
|
||||
model_.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
|
||||
model_.split_types.ConstDeviceSpan(),
|
||||
model_.categories_tree_segments.ConstDeviceSpan(),
|
||||
model_.categories_node_segments.ConstDeviceSpan(),
|
||||
model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
|
||||
num_features, num_rows, entry_start, use_shared, model_.num_group);
|
||||
}
|
||||
void PredictInternal(EllpackDeviceAccessor const& batch, HostDeviceVector<bst_float>* out_preds,
|
||||
void PredictInternal(EllpackDeviceAccessor const& batch,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
size_t batch_offset) {
|
||||
const uint32_t BLOCK_THREADS = 256;
|
||||
size_t num_rows = batch.n_rows;
|
||||
@ -396,12 +482,15 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
bool use_shared = false;
|
||||
size_t entry_start = 0;
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
|
||||
PredictKernel<EllpackLoader, EllpackDeviceAccessor>,
|
||||
batch,
|
||||
model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(),
|
||||
model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows,
|
||||
entry_start, use_shared, model_.num_group);
|
||||
PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
|
||||
model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
|
||||
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
|
||||
model_.split_types.ConstDeviceSpan(),
|
||||
model_.categories_tree_segments.ConstDeviceSpan(),
|
||||
model_.categories_node_segments.ConstDeviceSpan(),
|
||||
model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
|
||||
batch.NumFeatures(), num_rows, entry_start, use_shared,
|
||||
model_.num_group);
|
||||
}
|
||||
|
||||
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||
@ -413,6 +502,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
|
||||
out_preds->SetDevice(generic_param_->gpu_id);
|
||||
auto const& info = dmat->Info();
|
||||
|
||||
if (dmat->PageExists<SparsePage>()) {
|
||||
size_t batch_offset = 0;
|
||||
@ -425,7 +515,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t batch_offset = 0;
|
||||
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
|
||||
this->PredictInternal(
|
||||
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), out_preds,
|
||||
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
|
||||
out_preds,
|
||||
batch_offset);
|
||||
batch_offset += page.Impl()->n_rows;
|
||||
}
|
||||
@ -528,12 +619,14 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t entry_start = 0;
|
||||
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<Loader, typename Loader::BatchT>,
|
||||
m->Value(),
|
||||
d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(),
|
||||
d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(),
|
||||
tree_begin, tree_end, m->NumColumns(), info.num_row_,
|
||||
entry_start, use_shared, output_groups);
|
||||
PredictKernel<Loader, typename Loader::BatchT>, m->Value(),
|
||||
d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(),
|
||||
d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(),
|
||||
d_model.split_types.ConstDeviceSpan(),
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
|
||||
info.num_row_, entry_start, use_shared, output_groups);
|
||||
}
|
||||
|
||||
void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
|
||||
|
||||
@ -221,5 +221,8 @@ TEST(GPUPredictor, Shap) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, CategoricalPrediction) {
|
||||
TestCategoricalPrediction("gpu_predictor");
|
||||
}
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
|
||||
@ -12,6 +12,8 @@
|
||||
|
||||
#include "../helpers.h"
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/common/categorical.h"
|
||||
#include "../../../src/common/bitfield.h"
|
||||
|
||||
namespace xgboost {
|
||||
TEST(Predictor, PredictionCache) {
|
||||
@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) {
|
||||
};
|
||||
|
||||
add_cache();
|
||||
ASSERT_EQ(container.Container().size(), 0);
|
||||
ASSERT_EQ(container.Container().size(), 0ul);
|
||||
add_cache();
|
||||
EXPECT_ANY_THROW(container.Entry(m));
|
||||
}
|
||||
@ -174,4 +176,55 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
void TestCategoricalPrediction(std::string name) {
|
||||
size_t constexpr kCols = 10;
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_output_group = 1;
|
||||
param.base_score = 0.5;
|
||||
|
||||
gbm::GBTreeModel model(¶m);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||
auto& p_tree = trees.front();
|
||||
|
||||
uint32_t split_ind = 3;
|
||||
bst_cat_t split_cat = 4;
|
||||
float left_weight = 1.3f;
|
||||
float right_weight = 1.7f;
|
||||
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
|
||||
LBitField32 cats_bits(split_cats);
|
||||
cats_bits.Set(split_cat);
|
||||
|
||||
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
|
||||
left_weight, right_weight,
|
||||
3.0f, 2.2f, 7.0f, 9.0f);
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
|
||||
GenericParameter runtime;
|
||||
runtime.gpu_id = 0;
|
||||
std::unique_ptr<Predictor> predictor{
|
||||
Predictor::Create(name.c_str(), &runtime)};
|
||||
|
||||
std::vector<float> row(kCols);
|
||||
row[split_ind] = split_cat;
|
||||
auto m = GetDMatrixFromData(row, 1, kCols);
|
||||
|
||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
right_weight + param.base_score); // go to right for matching cat
|
||||
|
||||
row[split_ind] = split_cat + 1;
|
||||
m = GetDMatrixFromData(row, 1, kCols);
|
||||
out_predictions.version = 0;
|
||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
left_weight + param.base_score);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
int32_t device = -1);
|
||||
|
||||
void TestPredictionWithLesserFeatures(std::string preditor_name);
|
||||
|
||||
void TestCategoricalPrediction(std::string name);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user