Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf.
* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to have a unified get next node function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "predict_fn.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/device_adapter.cuh"
|
||||
@@ -27,6 +28,42 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
/**
 * \brief Per-tree view over the flattened model buffers.
 *
 * Bundles the node span of a single tree together with the matching slices of
 * the split-type, categorical node-segment and category buffers, so that
 * traversal code can take one object instead of several spans.
 */
struct TreeView {
  // Split types, category storage and per-node category segments of this tree.
  RegTree::CategoricalSplitMatrix cats;
  // Nodes belonging to this tree only.
  common::Span<RegTree::Node const> d_tree;

  /**
   * \brief Slice the flattened buffers down to the tree at `tree_idx`.
   *
   * \param tree_begin          Index of the first tree covered by the segment
   *                            arrays; segment lookups use `tree_idx - tree_begin`.
   * \param tree_idx            Index of the tree to view.
   * \param d_nodes             Nodes of all trees, concatenated.
   * \param d_tree_segments     Node offsets; entries `i` and `i + 1` delimit the
   *                            nodes of the i-th tree.
   * \param d_tree_split_types  Split types, sliced with the same offsets as
   *                            `d_nodes`.
   * \param d_cat_tree_segments Category offsets; entries `i` and `i + 1` delimit
   *                            the categories of the i-th tree.
   * \param d_cat_node_segments Per-node category segments, sliced with the same
   *                            offsets as `d_nodes`.
   * \param d_categories        Category storage of all trees, concatenated.
   */
  XGBOOST_DEVICE
  TreeView(size_t tree_begin, size_t tree_idx,
           common::Span<const RegTree::Node> d_nodes,
           common::Span<size_t const> d_tree_segments,
           common::Span<FeatureType const> d_tree_split_types,
           common::Span<uint32_t const> d_cat_tree_segments,
           common::Span<RegTree::Segment const> d_cat_node_segments,
           common::Span<uint32_t const> d_categories) {
    // Position of this tree inside the segment arrays.
    auto const local_idx = tree_idx - tree_begin;

    // Node range [node_beg, node_beg + node_cnt) of this tree.
    auto const node_beg = d_tree_segments[local_idx];
    auto const node_cnt = d_tree_segments[local_idx + 1] - node_beg;
    d_tree = d_nodes.subspan(node_beg, node_cnt);

    // Node-aligned buffers share the node range.
    auto const node_cat_segments = d_cat_node_segments.subspan(node_beg, node_cnt);
    auto const node_split_types = d_tree_split_types.subspan(node_beg, node_cnt);

    // Category storage has its own per-tree offsets.
    auto const cat_beg = d_cat_tree_segments[local_idx];
    auto const cat_cnt = d_cat_tree_segments[local_idx + 1] - cat_beg;
    auto const tree_cats = d_categories.subspan(cat_beg, cat_cnt);

    cats.split_type = node_split_types;
    cats.categories = tree_cats;
    cats.node_ptr = node_cat_segments;
  }

  /**
   * \brief Whether this tree has any category storage, i.e. the sliced
   *        category span is non-empty.
   */
  __device__ bool HasCategoricalSplit() const {
    return !cats.categories.empty();
  }
};
|
||||
|
||||
struct SparsePageView {
|
||||
common::Span<const Entry> d_data;
|
||||
common::Span<const bst_row_t> d_row_ptr;
|
||||
@@ -178,84 +215,69 @@ struct DeviceAdapterLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Loader>
|
||||
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
|
||||
common::Span<FeatureType const> split_types,
|
||||
common::Span<RegTree::Segment const> d_cat_ptrs,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
Loader* loader) {
|
||||
template <bool has_missing, bool has_categorical, typename Loader>
|
||||
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
|
||||
Loader *loader) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree[nidx];
|
||||
RegTree::Node n = tree.d_tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader->GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (common::CheckNAN(fvalue)) {
|
||||
nidx = n.DefaultChild();
|
||||
} else {
|
||||
bool go_left = true;
|
||||
if (common::IsCat(split_types, nidx)) {
|
||||
auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
|
||||
d_cat_ptrs[nidx].size);
|
||||
go_left = Decision(categories, common::AsCat(fvalue));
|
||||
} else {
|
||||
go_left = fvalue < n.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
nidx = n.LeftChild();
|
||||
} else {
|
||||
nidx = n.RightChild();
|
||||
}
|
||||
}
|
||||
n = tree[nidx];
|
||||
}
|
||||
return tree[nidx].LeafValue();
|
||||
}
|
||||
|
||||
template <typename Loader>
|
||||
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
|
||||
Loader const& loader) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader.GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (common::CheckNAN(fvalue)) {
|
||||
nidx = n.DefaultChild();
|
||||
n = tree[nidx];
|
||||
} else {
|
||||
if (fvalue < n.SplitCond()) {
|
||||
nidx = n.LeftChild();
|
||||
n = tree[nidx];
|
||||
} else {
|
||||
nidx = n.RightChild();
|
||||
n = tree[nidx];
|
||||
}
|
||||
}
|
||||
bool is_missing = common::CheckNAN(fvalue);
|
||||
nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
|
||||
is_missing, tree.cats);
|
||||
n = tree.d_tree[nidx];
|
||||
}
|
||||
return nidx;
|
||||
}
|
||||
|
||||
/**
 * \brief Walk one tree for one row and return the value of the leaf reached.
 *
 * Picks the `GetLeafIndex` instantiation according to whether the tree carries
 * categorical splits, then reads the leaf value from the tree's node span.
 *
 * \tparam has_missing Forwarded unchanged to `GetLeafIndex`.
 * \tparam Loader      Feature loader type, passed through to `GetLeafIndex`.
 *
 * \param ridx   Row index handed to the loader.
 * \param tree   View of the tree to traverse.
 * \param loader Feature loader used during traversal.
 *
 * \return Leaf value of the node returned by `GetLeafIndex`.
 */
template <bool has_missing, typename Loader>
__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree,
                               Loader *loader) {
  bst_node_t const leaf_idx =
      tree.HasCategoricalSplit()
          ? GetLeafIndex<has_missing, true>(ridx, tree, loader)
          : GetLeafIndex<has_missing, false>(ridx, tree, loader);
  return tree.d_tree[leaf_idx].LeafValue();
}
|
||||
|
||||
template <typename Loader, typename Data>
|
||||
__global__ void PredictLeafKernel(Data data,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start, bool use_shared,
|
||||
float missing) {
|
||||
__global__ void
|
||||
PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start, bool use_shared,
|
||||
float missing) {
|
||||
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (ridx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
auto leaf = GetLeafIndex(ridx, d_tree, loader);
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
|
||||
bst_node_t leaf = -1;
|
||||
if (d_tree.HasCategoricalSplit()) {
|
||||
leaf = GetLeafIndex<true, true>(ridx, d_tree, &loader);
|
||||
} else {
|
||||
leaf = GetLeafIndex<true, false>(ridx, d_tree, &loader);
|
||||
}
|
||||
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Loader, typename Data>
|
||||
template <typename Loader, typename Data, bool has_missing = true>
|
||||
__global__ void
|
||||
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
@@ -272,47 +294,25 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
if (global_idx >= num_rows) return;
|
||||
if (num_group == 1) {
|
||||
float sum = 0;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_split_types =
|
||||
d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
|
||||
sum += leaf;
|
||||
}
|
||||
d_out_predictions[global_idx] += sum;
|
||||
} else {
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
int tree_group = d_tree_group[tree_idx];
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
d_out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(global_idx, d_tree, d_tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -515,7 +515,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DeviceModel const& model,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) const {
|
||||
size_t batch_offset, bool is_dense) const {
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
@@ -529,16 +529,24 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t entry_start = 0;
|
||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
num_features);
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>, data,
|
||||
model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
|
||||
model.split_types.ConstDeviceSpan(),
|
||||
model.categories_tree_segments.ConstDeviceSpan(),
|
||||
model.categories_node_segments.ConstDeviceSpan(),
|
||||
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
||||
num_features, num_rows, entry_start, use_shared, model.num_group, nan(""));
|
||||
auto const kernel = [&](auto predict_fn) {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
predict_fn, data, model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
model.tree_segments.ConstDeviceSpan(),
|
||||
model.tree_group.ConstDeviceSpan(),
|
||||
model.split_types.ConstDeviceSpan(),
|
||||
model.categories_tree_segments.ConstDeviceSpan(),
|
||||
model.categories_node_segments.ConstDeviceSpan(),
|
||||
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
||||
num_features, num_rows, entry_start, use_shared, model.num_group,
|
||||
nan(""));
|
||||
};
|
||||
if (is_dense) {
|
||||
kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
|
||||
} else {
|
||||
kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
|
||||
}
|
||||
}
|
||||
void PredictInternal(EllpackDeviceAccessor const& batch,
|
||||
DeviceModel const& model,
|
||||
@@ -578,7 +586,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
|
||||
out_preds, batch_offset);
|
||||
out_preds, batch_offset, dmat->IsDense());
|
||||
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
|
||||
}
|
||||
} else {
|
||||
@@ -846,6 +854,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
d_model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
d_model.tree_segments.ConstDeviceSpan(),
|
||||
|
||||
d_model.split_types.ConstDeviceSpan(),
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(),
|
||||
|
||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, nan(""));
|
||||
batch_offset += batch.Size();
|
||||
@@ -862,6 +876,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
d_model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
d_model.tree_segments.ConstDeviceSpan(),
|
||||
|
||||
d_model.split_types.ConstDeviceSpan(),
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(),
|
||||
|
||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, nan(""));
|
||||
batch_offset += batch.Size();
|
||||
|
||||
Reference in New Issue
Block a user