Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf. * Implement categorical prediction for CPU prediction. * Implement categorical prediction for GPU predict leaf. * Refactor the prediction functions to have a unified get next node function. Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
parent
72f9daf9b6
commit
f79cc4a7a4
@ -445,6 +445,10 @@ class RegTree : public Model {
|
||||
bst_float right_leaf_weight, bst_float loss_change,
|
||||
float sum_hess, float left_sum, float right_sum);
|
||||
|
||||
bool HasCategoricalSplit() const {
|
||||
return !split_categories_.empty();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
@ -537,13 +541,6 @@ class RegTree : public Model {
|
||||
std::vector<Entry> data_;
|
||||
bool has_missing_;
|
||||
};
|
||||
/*!
|
||||
* \brief get the leaf index
|
||||
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
template <bool has_missing = true>
|
||||
int GetLeafIndex(const FVec& feat) const;
|
||||
|
||||
/*!
|
||||
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
|
||||
@ -582,14 +579,6 @@ class RegTree : public Model {
|
||||
*/
|
||||
void CalculateContributionsApprox(const RegTree::FVec& feat,
|
||||
bst_float* out_contribs) const;
|
||||
/*!
|
||||
* \brief get next position of the tree given current pid
|
||||
* \param pid Current node id.
|
||||
* \param fvalue feature value if not missing.
|
||||
* \param is_unknown Whether current required feature is missing.
|
||||
*/
|
||||
template <bool has_missing = true>
|
||||
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
|
||||
/*!
|
||||
* \brief dump the model in the requested format as a text string
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
@ -627,6 +616,20 @@ class RegTree : public Model {
|
||||
size_t size {0};
|
||||
};
|
||||
|
||||
struct CategoricalSplitMatrix {
|
||||
common::Span<FeatureType const> split_type;
|
||||
common::Span<uint32_t const> categories;
|
||||
common::Span<Segment const> node_ptr;
|
||||
};
|
||||
|
||||
CategoricalSplitMatrix GetCategoriesMatrix() const {
|
||||
CategoricalSplitMatrix view;
|
||||
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
|
||||
view.categories = this->GetSplitCategories();
|
||||
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
|
||||
return view;
|
||||
}
|
||||
|
||||
private:
|
||||
void LoadCategoricalSplit(Json const& in);
|
||||
void SaveCategoricalSplit(Json* p_out) const;
|
||||
@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
|
||||
inline bool RegTree::FVec::HasMissing() const {
|
||||
return has_missing_;
|
||||
}
|
||||
|
||||
template <bool has_missing>
|
||||
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
|
||||
has_missing && feat.IsMissing(split_index));
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
template <bool has_missing>
|
||||
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
|
||||
if (has_missing) {
|
||||
if (is_unknown) {
|
||||
return (*this)[pid].DefaultChild();
|
||||
} else {
|
||||
if (fvalue < (*this)[pid].SplitCond()) {
|
||||
return (*this)[pid].LeftChild();
|
||||
} else {
|
||||
return (*this)[pid].RightChild();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 35% speed up due to reduced miss branch predictions
|
||||
// The following expression returns the left child if (fvalue < split_cond);
|
||||
// the right child otherwise.
|
||||
return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_MODEL_H_
|
||||
|
||||
@ -16,9 +16,11 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "predict_fn.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -26,6 +28,19 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
template <bool has_missing, bool has_categorical>
|
||||
bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
bst_node_t nid = 0;
|
||||
while (!tree[nid].IsLeaf()) {
|
||||
unsigned split_index = tree[nid].SplitIndex();
|
||||
auto fvalue = feat.GetFvalue(split_index);
|
||||
nid = GetNextNode<has_missing, has_categorical>(
|
||||
tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
bst_float PredValue(const SparsePage::Inst &inst,
|
||||
const std::vector<std::unique_ptr<RegTree>> &trees,
|
||||
const std::vector<int> &tree_info, int bst_group,
|
||||
@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
auto const &tree = *trees[i];
|
||||
bool has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
|
||||
auto split_types = tree.GetSplitTypes();
|
||||
auto categories_ptr =
|
||||
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
|
||||
auto cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t nidx = -1;
|
||||
if (has_categorical) {
|
||||
nidx = GetLeafIndex<true, true>(tree, *p_feats, cats);
|
||||
} else {
|
||||
nidx = GetLeafIndex<true, false>(tree, *p_feats, cats);
|
||||
}
|
||||
psum += (*trees[i])[nidx].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
|
||||
const std::unique_ptr<RegTree>& tree) {
|
||||
const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
|
||||
tree->GetLeafIndex<false>(p_feats); // 35% speed up
|
||||
return (*tree)[lid].LeafValue();
|
||||
template <bool has_categorical>
|
||||
bst_float
|
||||
PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
const bst_node_t leaf = p_feats.HasMissing() ?
|
||||
GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
|
||||
GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
return tree[leaf].LeafValue();
|
||||
}
|
||||
|
||||
inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, std::vector<bst_float>* out_preds,
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, std::vector<bst_float> *out_preds,
|
||||
const size_t predict_offset, const size_t num_group,
|
||||
const std::vector<RegTree::FVec> &thread_temp,
|
||||
const size_t offset, const size_t block_size) {
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
const size_t gid = model.tree_info[tree_id];
|
||||
auto const &tree = *model.trees[tree_id];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
auto has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
if (has_categorical) {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
|
||||
model.trees[tree_id]);
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
|
||||
feats.Fill(inst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataView>
|
||||
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
|
||||
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
|
||||
@ -145,11 +188,11 @@ class AdapterView {
|
||||
};
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end,
|
||||
void PredictBatchByBlockOfRowsKernel(
|
||||
DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
auto& thread_temp = *p_thread_temp;
|
||||
auto &thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param->num_output_group;
|
||||
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
@ -157,16 +200,20 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
|
||||
common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) {
|
||||
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
|
||||
|
||||
common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
|
||||
const size_t batch_offset = block_id * block_of_rows_size;
|
||||
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t block_size =
|
||||
std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
|
||||
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
|
||||
p_thread_temp);
|
||||
// process block of rows through all trees to keep cache locality
|
||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
|
||||
num_group, thread_temp, fvec_offset, block_size);
|
||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
|
||||
batch_offset + batch.base_rowid, num_group, thread_temp,
|
||||
fvec_offset, block_size);
|
||||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
|
||||
});
|
||||
}
|
||||
@ -344,7 +391,9 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
feats.Fill(page[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = model.trees[j]->GetLeafIndex(feats);
|
||||
auto const& tree = *model.trees[j];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
|
||||
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
|
||||
}
|
||||
feats.Drop(page[i]);
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "predict_fn.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
#include "../data/device_adapter.cuh"
|
||||
@ -27,6 +28,42 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
|
||||
struct TreeView {
|
||||
RegTree::CategoricalSplitMatrix cats;
|
||||
common::Span<RegTree::Node const> d_tree;
|
||||
|
||||
XGBOOST_DEVICE
|
||||
TreeView(size_t tree_begin, size_t tree_idx,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories) {
|
||||
auto begin = d_tree_segments[tree_idx - tree_begin];
|
||||
auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin];
|
||||
|
||||
d_tree = d_nodes.subspan(begin, n_nodes);
|
||||
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(begin, n_nodes);
|
||||
auto tree_split_types = d_tree_split_types.subspan(begin, n_nodes);
|
||||
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
|
||||
cats.split_type = tree_split_types;
|
||||
cats.categories = tree_categories;
|
||||
cats.node_ptr = tree_cat_ptrs;
|
||||
}
|
||||
|
||||
__device__ bool HasCategoricalSplit() const {
|
||||
return !cats.categories.empty();
|
||||
}
|
||||
};
|
||||
|
||||
struct SparsePageView {
|
||||
common::Span<const Entry> d_data;
|
||||
common::Span<const bst_row_t> d_row_ptr;
|
||||
@ -178,68 +215,44 @@ struct DeviceAdapterLoader {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Loader>
|
||||
__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
|
||||
common::Span<FeatureType const> split_types,
|
||||
common::Span<RegTree::Segment const> d_cat_ptrs,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
Loader* loader) {
|
||||
template <bool has_missing, bool has_categorical, typename Loader>
|
||||
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
|
||||
Loader *loader) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree[nidx];
|
||||
RegTree::Node n = tree.d_tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader->GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (common::CheckNAN(fvalue)) {
|
||||
nidx = n.DefaultChild();
|
||||
} else {
|
||||
bool go_left = true;
|
||||
if (common::IsCat(split_types, nidx)) {
|
||||
auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
|
||||
d_cat_ptrs[nidx].size);
|
||||
go_left = Decision(categories, common::AsCat(fvalue));
|
||||
} else {
|
||||
go_left = fvalue < n.SplitCond();
|
||||
}
|
||||
if (go_left) {
|
||||
nidx = n.LeftChild();
|
||||
} else {
|
||||
nidx = n.RightChild();
|
||||
}
|
||||
}
|
||||
n = tree[nidx];
|
||||
}
|
||||
return tree[nidx].LeafValue();
|
||||
}
|
||||
|
||||
template <typename Loader>
|
||||
__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
|
||||
Loader const& loader) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader.GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (common::CheckNAN(fvalue)) {
|
||||
nidx = n.DefaultChild();
|
||||
n = tree[nidx];
|
||||
} else {
|
||||
if (fvalue < n.SplitCond()) {
|
||||
nidx = n.LeftChild();
|
||||
n = tree[nidx];
|
||||
} else {
|
||||
nidx = n.RightChild();
|
||||
n = tree[nidx];
|
||||
}
|
||||
}
|
||||
bool is_missing = common::CheckNAN(fvalue);
|
||||
nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
|
||||
is_missing, tree.cats);
|
||||
n = tree.d_tree[nidx];
|
||||
}
|
||||
return nidx;
|
||||
}
|
||||
|
||||
template <bool has_missing, typename Loader>
|
||||
__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree,
|
||||
Loader *loader) {
|
||||
bst_node_t nidx = -1;
|
||||
if (tree.HasCategoricalSplit()) {
|
||||
nidx = GetLeafIndex<has_missing, true>(ridx, tree, loader);
|
||||
} else {
|
||||
nidx = GetLeafIndex<has_missing, false>(ridx, tree, loader);
|
||||
}
|
||||
return tree.d_tree[nidx].LeafValue();
|
||||
}
|
||||
|
||||
template <typename Loader, typename Data>
|
||||
__global__ void PredictLeafKernel(Data data,
|
||||
common::Span<const RegTree::Node> d_nodes,
|
||||
__global__ void
|
||||
PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
common::Span<size_t const> d_tree_segments,
|
||||
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||
common::Span<uint32_t const> d_categories,
|
||||
|
||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||
size_t num_rows, size_t entry_start, bool use_shared,
|
||||
float missing) {
|
||||
@ -248,14 +261,23 @@ __global__ void PredictLeafKernel(Data data,
|
||||
return;
|
||||
}
|
||||
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
auto leaf = GetLeafIndex(ridx, d_tree, loader);
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
|
||||
bst_node_t leaf = -1;
|
||||
if (d_tree.HasCategoricalSplit()) {
|
||||
leaf = GetLeafIndex<true, true>(ridx, d_tree, &loader);
|
||||
} else {
|
||||
leaf = GetLeafIndex<true, false>(ridx, d_tree, &loader);
|
||||
}
|
||||
d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Loader, typename Data>
|
||||
template <typename Loader, typename Data, bool has_missing = true>
|
||||
__global__ void
|
||||
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
common::Span<float> d_out_predictions,
|
||||
@ -272,47 +294,25 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
||||
if (global_idx >= num_rows) return;
|
||||
if (num_group == 1) {
|
||||
float sum = 0;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_split_types =
|
||||
d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
|
||||
sum += leaf;
|
||||
}
|
||||
d_out_predictions[global_idx] += sum;
|
||||
} else {
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
int tree_group = d_tree_group[tree_idx];
|
||||
const RegTree::Node* d_tree =
|
||||
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||
TreeView d_tree{
|
||||
tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
|
||||
auto tree_cat_ptrs = d_cat_node_segments.subspan(
|
||||
d_tree_segments[tree_idx - tree_begin],
|
||||
d_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_tree_segments[tree_idx - tree_begin]);
|
||||
auto tree_categories =
|
||||
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
|
||||
d_cat_tree_segments[tree_idx - tree_begin + 1] -
|
||||
d_cat_tree_segments[tree_idx - tree_begin]);
|
||||
d_out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(global_idx, d_tree, d_tree_split_types,
|
||||
tree_cat_ptrs,
|
||||
tree_categories,
|
||||
&loader);
|
||||
GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -515,7 +515,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DeviceModel const& model,
|
||||
size_t num_features,
|
||||
HostDeviceVector<bst_float>* predictions,
|
||||
size_t batch_offset) const {
|
||||
size_t batch_offset, bool is_dense) const {
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
const uint32_t BLOCK_THREADS = 128;
|
||||
@ -529,16 +529,24 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t entry_start = 0;
|
||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
num_features);
|
||||
auto const kernel = [&](auto predict_fn) {
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>, data,
|
||||
model.nodes.ConstDeviceSpan(),
|
||||
predict_fn, data, model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
|
||||
model.tree_segments.ConstDeviceSpan(),
|
||||
model.tree_group.ConstDeviceSpan(),
|
||||
model.split_types.ConstDeviceSpan(),
|
||||
model.categories_tree_segments.ConstDeviceSpan(),
|
||||
model.categories_node_segments.ConstDeviceSpan(),
|
||||
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
||||
num_features, num_rows, entry_start, use_shared, model.num_group, nan(""));
|
||||
num_features, num_rows, entry_start, use_shared, model.num_group,
|
||||
nan(""));
|
||||
};
|
||||
if (is_dense) {
|
||||
kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
|
||||
} else {
|
||||
kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
|
||||
}
|
||||
}
|
||||
void PredictInternal(EllpackDeviceAccessor const& batch,
|
||||
DeviceModel const& model,
|
||||
@ -578,7 +586,7 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
|
||||
out_preds, batch_offset);
|
||||
out_preds, batch_offset, dmat->IsDense());
|
||||
batch_offset += batch.Size() * model.learner_model_param->num_output_group;
|
||||
}
|
||||
} else {
|
||||
@ -846,6 +854,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
d_model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
d_model.tree_segments.ConstDeviceSpan(),
|
||||
|
||||
d_model.split_types.ConstDeviceSpan(),
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(),
|
||||
|
||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, nan(""));
|
||||
batch_offset += batch.Size();
|
||||
@ -862,6 +876,12 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
d_model.nodes.ConstDeviceSpan(),
|
||||
predictions->DeviceSpan().subspan(batch_offset),
|
||||
d_model.tree_segments.ConstDeviceSpan(),
|
||||
|
||||
d_model.split_types.ConstDeviceSpan(),
|
||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||
d_model.categories.ConstDeviceSpan(),
|
||||
|
||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||
entry_start, use_shared, nan(""));
|
||||
batch_offset += batch.Size();
|
||||
|
||||
31
src/predictor/predict_fn.h
Normal file
31
src/predictor/predict_fn.h
Normal file
@ -0,0 +1,31 @@
|
||||
/*!
|
||||
* Copyright 2021 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#define XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
#include "../common/categorical.h"
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace predictor {
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t
|
||||
GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
|
||||
bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
|
||||
if (has_missing && is_missing) {
|
||||
return node.DefaultChild();
|
||||
} else {
|
||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
|
||||
cats.node_ptr[nid].size);
|
||||
return Decision(node_categories, common::AsCat(fvalue))
|
||||
? node.LeftChild()
|
||||
: node.RightChild();
|
||||
} else {
|
||||
return node.LeftChild() + !(fvalue < node.SplitCond());
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace predictor
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_
|
||||
@ -19,6 +19,7 @@
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
// nothing to do anymore
|
||||
return;
|
||||
}
|
||||
|
||||
bst_node_t nid = 0;
|
||||
auto cats = this->GetCategoriesMatrix();
|
||||
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
|
||||
feat.GetFvalue(split_index),
|
||||
feat.IsMissing(split_index), cats);
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "./param.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater {
|
||||
// start from groups that belongs to current data
|
||||
auto pid = 0;
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
// traverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
unsigned split_index = tree[pid].SplitIndex();
|
||||
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
pid = predictor::GetNextNode<true, true>(
|
||||
tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
|
||||
cats);
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -229,6 +229,14 @@ void TestUpdatePredictionCache(bool use_subsampling) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CPUPredictor, CategoricalPrediction) {
|
||||
TestCategoricalPrediction("cpu_predictor");
|
||||
}
|
||||
|
||||
TEST(CPUPredictor, CategoricalPredictLeaf) {
|
||||
TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, UpdatePredictionCache) {
|
||||
TestUpdatePredictionCache(false);
|
||||
TestUpdatePredictionCache(true);
|
||||
|
||||
@ -228,6 +228,10 @@ TEST(GPUPredictor, CategoricalPrediction) {
|
||||
TestCategoricalPrediction("gpu_predictor");
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, CategoricalPredictLeaf) {
|
||||
TestCategoricalPredictLeaf(StringView{"gpu_predictor"});
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, PredictLeafBasic) {
|
||||
size_t constexpr kRows = 5, kCols = 5;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();
|
||||
|
||||
@ -180,6 +180,25 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
||||
bst_cat_t split_cat, float left_weight,
|
||||
float right_weight) {
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||
auto& p_tree = trees.front();
|
||||
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
|
||||
LBitField32 cats_bits(split_cats);
|
||||
cats_bits.Set(split_cat);
|
||||
|
||||
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
|
||||
left_weight, right_weight,
|
||||
3.0f, 2.2f, 7.0f, 9.0f);
|
||||
model->CommitModel(std::move(trees), 0);
|
||||
}
|
||||
|
||||
void TestCategoricalPrediction(std::string name) {
|
||||
size_t constexpr kCols = 10;
|
||||
PredictionCacheEntry out_predictions;
|
||||
@ -189,25 +208,13 @@ void TestCategoricalPrediction(std::string name) {
|
||||
param.num_output_group = 1;
|
||||
param.base_score = 0.5;
|
||||
|
||||
gbm::GBTreeModel model(¶m);
|
||||
|
||||
std::vector<std::unique_ptr<RegTree>> trees;
|
||||
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
|
||||
auto& p_tree = trees.front();
|
||||
|
||||
uint32_t split_ind = 3;
|
||||
bst_cat_t split_cat = 4;
|
||||
float left_weight = 1.3f;
|
||||
float right_weight = 1.7f;
|
||||
|
||||
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
|
||||
LBitField32 cats_bits(split_cats);
|
||||
cats_bits.Set(split_cat);
|
||||
|
||||
p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
|
||||
left_weight, right_weight,
|
||||
3.0f, 2.2f, 7.0f, 9.0f);
|
||||
model.CommitModel(std::move(trees), 0);
|
||||
gbm::GBTreeModel model(¶m);
|
||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||
|
||||
GenericParameter runtime;
|
||||
runtime.gpu_id = 0;
|
||||
@ -232,4 +239,43 @@ void TestCategoricalPrediction(std::string name) {
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||
left_weight + param.base_score);
|
||||
}
|
||||
|
||||
void TestCategoricalPredictLeaf(StringView name) {
|
||||
size_t constexpr kCols = 10;
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
param.num_output_group = 1;
|
||||
param.base_score = 0.5;
|
||||
|
||||
uint32_t split_ind = 3;
|
||||
bst_cat_t split_cat = 4;
|
||||
float left_weight = 1.3f;
|
||||
float right_weight = 1.7f;
|
||||
|
||||
gbm::GBTreeModel model(¶m);
|
||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||
|
||||
GenericParameter runtime;
|
||||
runtime.gpu_id = 0;
|
||||
std::unique_ptr<Predictor> predictor{
|
||||
Predictor::Create(name.c_str(), &runtime)};
|
||||
|
||||
std::vector<float> row(kCols);
|
||||
row[split_ind] = split_cat;
|
||||
auto m = GetDMatrixFromData(row, 1, kCols);
|
||||
|
||||
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
||||
CHECK_EQ(out_predictions.predictions.Size(), 1);
|
||||
// go to left if it doesn't match the category, otherwise right.
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 2);
|
||||
|
||||
row[split_ind] = split_cat + 1;
|
||||
m = GetDMatrixFromData(row, 1, kCols);
|
||||
out_predictions.version = 0;
|
||||
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
|
||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
@ -66,6 +66,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
|
||||
void TestPredictionWithLesserFeatures(std::string preditor_name);
|
||||
|
||||
void TestCategoricalPrediction(std::string name);
|
||||
|
||||
void TestCategoricalPredictLeaf(StringView name);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user