Implement categorical prediction for CPU and GPU predict leaf. (#7001)

* Categorical prediction with CPU predictor and GPU predict leaf.

* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to have a unified get next node function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
Jiaming Yuan 2021-06-11 10:11:45 +08:00 committed by GitHub
parent 72f9daf9b6
commit f79cc4a7a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 340 additions and 200 deletions

View File

@ -445,6 +445,10 @@ class RegTree : public Model {
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);
bool HasCategoricalSplit() const {
return !split_categories_.empty();
}
/*!
* \brief get current depth
* \param nid node id
@ -537,13 +541,6 @@ class RegTree : public Model {
std::vector<Entry> data_;
bool has_missing_;
};
/*!
* \brief get the leaf index
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \return the leaf index of the given feature
*/
template <bool has_missing = true>
int GetLeafIndex(const FVec& feat) const;
/*!
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
@ -582,14 +579,6 @@ class RegTree : public Model {
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
bst_float* out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
* \param fvalue feature value if not missing.
* \param is_unknown Whether current required feature is missing.
*/
template <bool has_missing = true>
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
/*!
* \brief dump the model in the requested format as a text string
* \param fmap feature map that may help give interpretations of feature
@ -627,6 +616,20 @@ class RegTree : public Model {
size_t size {0};
};
struct CategoricalSplitMatrix {
common::Span<FeatureType const> split_type;
common::Span<uint32_t const> categories;
common::Span<Segment const> node_ptr;
};
CategoricalSplitMatrix GetCategoriesMatrix() const {
CategoricalSplitMatrix view;
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
view.categories = this->GetSplitCategories();
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
return view;
}
private:
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}
template <bool has_missing>
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
has_missing && feat.IsMissing(split_index));
}
return nid;
}
/*! \brief get next position of the tree given current pid */
template <bool has_missing>
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
if (has_missing) {
if (is_unknown) {
return (*this)[pid].DefaultChild();
} else {
if (fvalue < (*this)[pid].SplitCond()) {
return (*this)[pid].LeftChild();
} else {
return (*this)[pid].RightChild();
}
}
} else {
// 35% speed up due to reduced miss branch predictions
// The following expression returns the left child if (fvalue < split_cond);
// the right child otherwise.
return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
}
}
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_

View File

@ -16,9 +16,11 @@
#include "xgboost/logging.h"
#include "xgboost/host_device_vector.h"
#include "predict_fn.h"
#include "../data/adapter.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
#include "../common/categorical.h"
#include "../gbm/gbtree_model.h"
namespace xgboost {
@ -26,6 +28,19 @@ namespace predictor {
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
template <bool has_missing, bool has_categorical>
bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
RegTree::CategoricalSplitMatrix const& cats) {
bst_node_t nid = 0;
while (!tree[nid].IsLeaf()) {
unsigned split_index = tree[nid].SplitIndex();
auto fvalue = feat.GetFvalue(split_index);
nid = GetNextNode<has_missing, has_categorical>(
tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
}
return nid;
}
bst_float PredValue(const SparsePage::Inst &inst,
const std::vector<std::unique_ptr<RegTree>> &trees,
const std::vector<int> &tree_info, int bst_group,
@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst,
p_feats->Fill(inst);
for (size_t i = tree_begin; i < tree_end; ++i) {
if (tree_info[i] == bst_group) {
int tid = trees[i]->GetLeafIndex(*p_feats);
psum += (*trees[i])[tid].LeafValue();
auto const &tree = *trees[i];
bool has_categorical = tree.HasCategoricalSplit();
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
auto split_types = tree.GetSplitTypes();
auto categories_ptr =
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
auto cats = tree.GetCategoriesMatrix();
bst_node_t nidx = -1;
if (has_categorical) {
nidx = GetLeafIndex<true, true>(tree, *p_feats, cats);
} else {
nidx = GetLeafIndex<true, false>(tree, *p_feats, cats);
}
psum += (*trees[i])[nidx].LeafValue();
}
}
p_feats->Drop(inst);
return psum;
}
inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
const std::unique_ptr<RegTree>& tree) {
const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
tree->GetLeafIndex<false>(p_feats); // 35% speed up
return (*tree)[lid].LeafValue();
template <bool has_categorical>
bst_float
PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
RegTree::CategoricalSplitMatrix const& cats) {
const bst_node_t leaf = p_feats.HasMissing() ?
GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
return tree[leaf].LeafValue();
}
inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, std::vector<bst_float>* out_preds,
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, std::vector<bst_float> *out_preds,
const size_t predict_offset, const size_t num_group,
const std::vector<RegTree::FVec> &thread_temp,
const size_t offset, const size_t block_size) {
std::vector<bst_float> &preds = *out_preds;
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
const size_t gid = model.tree_info[tree_id];
auto const &tree = *model.trees[tree_id];
auto const& cats = tree.GetCategoriesMatrix();
auto has_categorical = tree.HasCategoricalSplit();
if (has_categorical) {
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
model.trees[tree_id]);
preds[(predict_offset + i) * num_group + gid] +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (size_t i = 0; i < block_size; ++i) {
preds[(predict_offset + i) * num_group + gid] +=
PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
}
}
}
}
@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
feats.Fill(inst);
}
}
template <typename DataView>
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
@ -145,11 +188,11 @@ class AdapterView {
};
template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end,
void PredictBatchByBlockOfRowsKernel(
DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
auto& thread_temp = *p_thread_temp;
auto &thread_temp = *p_thread_temp;
int32_t const num_group = model.learner_model_param->num_output_group;
CHECK_EQ(model.param.size_leaf_vector, 0)
@ -157,16 +200,20 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
const int num_feature = model.learner_model_param->num_feature;
const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) {
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
const size_t batch_offset = block_id * block_of_rows_size;
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
const size_t block_size =
std::min(nsize - batch_offset, block_of_rows_size);
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
p_thread_temp);
// process block of rows through all trees to keep cache locality
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
num_group, thread_temp, fvec_offset, block_size);
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
batch_offset + batch.base_rowid, num_group, thread_temp,
fvec_offset, block_size);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
});
}
@ -344,7 +391,9 @@ class CPUPredictor : public Predictor {
}
feats.Fill(page[i]);
for (unsigned j = 0; j < ntree_limit; ++j) {
int tid = model.trees[j]->GetLeafIndex(feats);
auto const& tree = *model.trees[j];
auto const& cats = tree.GetCategoriesMatrix();
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
}
feats.Drop(page[i]);

View File

@ -14,6 +14,7 @@
#include "xgboost/tree_updater.h"
#include "xgboost/host_device_vector.h"
#include "predict_fn.h"
#include "../gbm/gbtree_model.h"
#include "../data/ellpack_page.cuh"
#include "../data/device_adapter.cuh"
@ -27,6 +28,42 @@ namespace predictor {
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
struct TreeView {
RegTree::CategoricalSplitMatrix cats;
common::Span<RegTree::Node const> d_tree;
XGBOOST_DEVICE
TreeView(size_t tree_begin, size_t tree_idx,
common::Span<const RegTree::Node> d_nodes,
common::Span<size_t const> d_tree_segments,
common::Span<FeatureType const> d_tree_split_types,
common::Span<uint32_t const> d_cat_tree_segments,
common::Span<RegTree::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories) {
auto begin = d_tree_segments[tree_idx - tree_begin];
auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin];
d_tree = d_nodes.subspan(begin, n_nodes);
auto tree_cat_ptrs = d_cat_node_segments.subspan(begin, n_nodes);
auto tree_split_types = d_tree_split_types.subspan(begin, n_nodes);
auto tree_categories =
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
cats.split_type = tree_split_types;
cats.categories = tree_categories;
cats.node_ptr = tree_cat_ptrs;
}
__device__ bool HasCategoricalSplit() const {
return !cats.categories.empty();
}
};
struct SparsePageView {
common::Span<const Entry> d_data;
common::Span<const bst_row_t> d_row_ptr;
@ -178,68 +215,44 @@ struct DeviceAdapterLoader {
}
};
template <typename Loader>
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
common::Span<FeatureType const> split_types,
common::Span<RegTree::Segment const> d_cat_ptrs,
common::Span<uint32_t const> d_categories,
Loader* loader) {
template <bool has_missing, bool has_categorical, typename Loader>
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
Loader *loader) {
bst_node_t nidx = 0;
RegTree::Node n = tree[nidx];
RegTree::Node n = tree.d_tree[nidx];
while (!n.IsLeaf()) {
float fvalue = loader->GetElement(ridx, n.SplitIndex());
// Missing value
if (common::CheckNAN(fvalue)) {
nidx = n.DefaultChild();
} else {
bool go_left = true;
if (common::IsCat(split_types, nidx)) {
auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
d_cat_ptrs[nidx].size);
go_left = Decision(categories, common::AsCat(fvalue));
} else {
go_left = fvalue < n.SplitCond();
}
if (go_left) {
nidx = n.LeftChild();
} else {
nidx = n.RightChild();
}
}
n = tree[nidx];
}
return tree[nidx].LeafValue();
}
template <typename Loader>
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
Loader const& loader) {
bst_node_t nidx = 0;
RegTree::Node n = tree[nidx];
while (!n.IsLeaf()) {
float fvalue = loader.GetElement(ridx, n.SplitIndex());
// Missing value
if (common::CheckNAN(fvalue)) {
nidx = n.DefaultChild();
n = tree[nidx];
} else {
if (fvalue < n.SplitCond()) {
nidx = n.LeftChild();
n = tree[nidx];
} else {
nidx = n.RightChild();
n = tree[nidx];
}
}
bool is_missing = common::CheckNAN(fvalue);
nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
is_missing, tree.cats);
n = tree.d_tree[nidx];
}
return nidx;
}
template <bool has_missing, typename Loader>
__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree,
Loader *loader) {
bst_node_t nidx = -1;
if (tree.HasCategoricalSplit()) {
nidx = GetLeafIndex<has_missing, true>(ridx, tree, loader);
} else {
nidx = GetLeafIndex<has_missing, false>(ridx, tree, loader);
}
return tree.d_tree[nidx].LeafValue();
}
template <typename Loader, typename Data>
__global__ void PredictLeafKernel(Data data,
common::Span<const RegTree::Node> d_nodes,
__global__ void
PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<float> d_out_predictions,
common::Span<size_t const> d_tree_segments,
common::Span<FeatureType const> d_tree_split_types,
common::Span<uint32_t const> d_cat_tree_segments,
common::Span<RegTree::Segment const> d_cat_node_segments,
common::Span<uint32_t const> d_categories,
size_t tree_begin, size_t tree_end, size_t num_features,
size_t num_rows, size_t entry_start, bool use_shared,
float missing) {
@ -248,14 +261,23 @@ __global__ void PredictLeafKernel(Data data,
return;
}
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
auto leaf = GetLeafIndex(ridx, d_tree, loader);
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
TreeView d_tree{
tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
bst_node_t leaf = -1;
if (d_tree.HasCategoricalSplit()) {
leaf = GetLeafIndex<true, true>(ridx, d_tree, &loader);
} else {
leaf = GetLeafIndex<true, false>(ridx, d_tree, &loader);
}
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
}
}
template <typename Loader, typename Data>
template <typename Loader, typename Data, bool has_missing = true>
__global__ void
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
common::Span<float> d_out_predictions,
@ -272,47 +294,25 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
if (global_idx >= num_rows) return;
if (num_group == 1) {
float sum = 0;
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
auto tree_cat_ptrs = d_cat_node_segments.subspan(
d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
auto tree_categories =
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
auto tree_split_types =
d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types,
tree_cat_ptrs,
tree_categories,
&loader);
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
TreeView d_tree{
tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
sum += leaf;
}
d_out_predictions[global_idx] += sum;
} else {
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
int tree_group = d_tree_group[tree_idx];
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
TreeView d_tree{
tree_begin, tree_idx, d_nodes,
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
d_cat_node_segments, d_categories};
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
auto tree_cat_ptrs = d_cat_node_segments.subspan(
d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
auto tree_categories =
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
d_out_predictions[out_prediction_idx] +=
GetLeafWeight(global_idx, d_tree, d_tree_split_types,
tree_cat_ptrs,
tree_categories,
&loader);
GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
}
}
}
@ -515,7 +515,7 @@ class GPUPredictor : public xgboost::Predictor {
DeviceModel const& model,
size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset) const {
size_t batch_offset, bool is_dense) const {
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(generic_param_->gpu_id);
const uint32_t BLOCK_THREADS = 128;
@ -529,16 +529,24 @@ class GPUPredictor : public xgboost::Predictor {
size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
auto const kernel = [&](auto predict_fn) {
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<SparsePageLoader, SparsePageView>, data,
model.nodes.ConstDeviceSpan(),
predict_fn, data, model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
num_features, num_rows, entry_start, use_shared, model.num_group, nan(""));
num_features, num_rows, entry_start, use_shared, model.num_group,
nan(""));
};
if (is_dense) {
kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
} else {
kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
}
}
void PredictInternal(EllpackDeviceAccessor const& batch,
DeviceModel const& model,
@ -578,7 +586,7 @@ class GPUPredictor : public xgboost::Predictor {
size_t batch_offset = 0;
for (auto &batch : dmat->GetBatches<SparsePage>()) {
this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
out_preds, batch_offset);
out_preds, batch_offset, dmat->IsDense());
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
}
} else {
@ -846,6 +854,12 @@ class GPUPredictor : public xgboost::Predictor {
d_model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
d_model.tree_segments.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
entry_start, use_shared, nan(""));
batch_offset += batch.Size();
@ -862,6 +876,12 @@ class GPUPredictor : public xgboost::Predictor {
d_model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
d_model.tree_segments.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(),
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
entry_start, use_shared, nan(""));
batch_offset += batch.Size();

View File

@ -0,0 +1,31 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
#define XGBOOST_PREDICTOR_PREDICT_FN_H_
#include "../common/categorical.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace predictor {
template <bool has_missing, bool has_categorical>
inline XGBOOST_DEVICE bst_node_t
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
if (has_missing && is_missing) {
return node.DefaultChild();
} else {
if (has_categorical && common::IsCat(cats.split_type, nid)) {
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
cats.node_ptr[nid].size);
return Decision(node_categories, common::AsCat(fvalue))
? node.LeftChild()
: node.RightChild();
} else {
return node.LeftChild() + !(fvalue < node.SplitCond());
}
}
}
} // namespace predictor
} // namespace xgboost
#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_

View File

@ -19,6 +19,7 @@
#include "param.h"
#include "../common/common.h"
#include "../common/categorical.h"
#include "../predictor/predict_fn.h"
namespace xgboost {
// register tree parameter
@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
// nothing to do anymore
return;
}
bst_node_t nid = 0;
auto cats = this->GetCategoriesMatrix();
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;

View File

@ -14,6 +14,7 @@
#include "./param.h"
#include "../common/io.h"
#include "../common/threading_utils.h"
#include "../predictor/predict_fn.h"
namespace xgboost {
namespace tree {
@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater {
// start from groups that belongs to current data
auto pid = 0;
gstats[pid].Add(gpair[ridx]);
auto const& cats = tree.GetCategoriesMatrix();
// traverse tree
while (!tree[pid].IsLeaf()) {
unsigned split_index = tree[pid].SplitIndex();
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
pid = predictor::GetNextNode<true, true>(
tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
cats);
gstats[pid].Add(gpair[ridx]);
}
}

View File

@ -229,6 +229,14 @@ void TestUpdatePredictionCache(bool use_subsampling) {
}
}
TEST(CPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("cpu_predictor");
}
TEST(CPUPredictor, CategoricalPredictLeaf) {
TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
}
TEST(CpuPredictor, UpdatePredictionCache) {
TestUpdatePredictionCache(false);
TestUpdatePredictionCache(true);

View File

@ -228,6 +228,10 @@ TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}
TEST(GPUPredictor, CategoricalPredictLeaf) {
TestCategoricalPredictLeaf(StringView{"gpu_predictor"});
}
TEST(GPUPredictor, PredictLeafBasic) {
size_t constexpr kRows = 5, kCols = 5;
auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();

View File

@ -180,6 +180,25 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
#endif // defined(XGBOOST_USE_CUDA)
}
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
bst_cat_t split_cat, float left_weight,
float right_weight) {
PredictionCacheEntry out_predictions;
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
auto& p_tree = trees.front();
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
LBitField32 cats_bits(split_cats);
cats_bits.Set(split_cat);
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
left_weight, right_weight,
3.0f, 2.2f, 7.0f, 9.0f);
model->CommitModel(std::move(trees), 0);
}
void TestCategoricalPrediction(std::string name) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;
@ -189,25 +208,13 @@ void TestCategoricalPrediction(std::string name) {
param.num_output_group = 1;
param.base_score = 0.5;
gbm::GBTreeModel model(&param);
std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
auto& p_tree = trees.front();
uint32_t split_ind = 3;
bst_cat_t split_cat = 4;
float left_weight = 1.3f;
float right_weight = 1.7f;
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
LBitField32 cats_bits(split_cats);
cats_bits.Set(split_cat);
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
left_weight, right_weight,
3.0f, 2.2f, 7.0f, 9.0f);
model.CommitModel(std::move(trees), 0);
gbm::GBTreeModel model(&param);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
GenericParameter runtime;
runtime.gpu_id = 0;
@ -232,4 +239,43 @@ void TestCategoricalPrediction(std::string name) {
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
left_weight + param.base_score);
}
void TestCategoricalPredictLeaf(StringView name) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;
LearnerModelParam param;
param.num_feature = kCols;
param.num_output_group = 1;
param.base_score = 0.5;
uint32_t split_ind = 3;
bst_cat_t split_cat = 4;
float left_weight = 1.3f;
float right_weight = 1.7f;
gbm::GBTreeModel model(&param);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
GenericParameter runtime;
runtime.gpu_id = 0;
std::unique_ptr<Predictor> predictor{
Predictor::Create(name.c_str(), &runtime)};
std::vector<float> row(kCols);
row[split_ind] = split_cat;
auto m = GetDMatrixFromData(row, 1, kCols);
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
CHECK_EQ(out_predictions.predictions.Size(), 1);
// go to left if it doesn't match the category, otherwise right.
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 2);
row[split_ind] = split_cat + 1;
m = GetDMatrixFromData(row, 1, kCols);
out_predictions.version = 0;
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}
} // namespace xgboost

View File

@ -66,6 +66,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
void TestPredictionWithLesserFeatures(std::string preditor_name);
void TestCategoricalPrediction(std::string name);
void TestCategoricalPredictLeaf(StringView name);
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_