Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf. * Implement categorical prediction for CPU prediction. * Implement categorical prediction for GPU predict leaf. * Refactor the prediction functions to have a unified get next node function. Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
@@ -16,9 +16,11 @@
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
|
||||
#include "predict_fn.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../common/math.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../gbm/gbtree_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -26,6 +28,19 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(cpu_predictor);
|
||||
|
||||
template <bool has_missing, bool has_categorical>
|
||||
bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
bst_node_t nid = 0;
|
||||
while (!tree[nid].IsLeaf()) {
|
||||
unsigned split_index = tree[nid].SplitIndex();
|
||||
auto fvalue = feat.GetFvalue(split_index);
|
||||
nid = GetNextNode<has_missing, has_categorical>(
|
||||
tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
bst_float PredValue(const SparsePage::Inst &inst,
|
||||
const std::vector<std::unique_ptr<RegTree>> &trees,
|
||||
const std::vector<int> &tree_info, int bst_group,
|
||||
@@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst,
|
||||
p_feats->Fill(inst);
|
||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||
if (tree_info[i] == bst_group) {
|
||||
int tid = trees[i]->GetLeafIndex(*p_feats);
|
||||
psum += (*trees[i])[tid].LeafValue();
|
||||
auto const &tree = *trees[i];
|
||||
bool has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
|
||||
auto split_types = tree.GetSplitTypes();
|
||||
auto categories_ptr =
|
||||
common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
|
||||
auto cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t nidx = -1;
|
||||
if (has_categorical) {
|
||||
nidx = GetLeafIndex<true, true>(tree, *p_feats, cats);
|
||||
} else {
|
||||
nidx = GetLeafIndex<true, false>(tree, *p_feats, cats);
|
||||
}
|
||||
psum += (*trees[i])[nidx].LeafValue();
|
||||
}
|
||||
}
|
||||
p_feats->Drop(inst);
|
||||
return psum;
|
||||
}
|
||||
|
||||
inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
|
||||
const std::unique_ptr<RegTree>& tree) {
|
||||
const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
|
||||
tree->GetLeafIndex<false>(p_feats); // 35% speed up
|
||||
return (*tree)[lid].LeafValue();
|
||||
template <bool has_categorical>
|
||||
bst_float
|
||||
PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
||||
RegTree::CategoricalSplitMatrix const& cats) {
|
||||
const bst_node_t leaf = p_feats.HasMissing() ?
|
||||
GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
|
||||
GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||
return tree[leaf].LeafValue();
|
||||
}
|
||||
|
||||
inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, std::vector<bst_float>* out_preds,
|
||||
const size_t predict_offset, const size_t num_group,
|
||||
const std::vector<RegTree::FVec> &thread_temp,
|
||||
const size_t offset, const size_t block_size) {
|
||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
||||
const size_t tree_end, std::vector<bst_float> *out_preds,
|
||||
const size_t predict_offset, const size_t num_group,
|
||||
const std::vector<RegTree::FVec> &thread_temp,
|
||||
const size_t offset, const size_t block_size) {
|
||||
std::vector<bst_float> &preds = *out_preds;
|
||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
||||
const size_t gid = model.tree_info[tree_id];
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
|
||||
model.trees[tree_id]);
|
||||
auto const &tree = *model.trees[tree_id];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
auto has_categorical = tree.HasCategoricalSplit();
|
||||
|
||||
if (has_categorical) {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < block_size; ++i) {
|
||||
preds[(predict_offset + i) * num_group + gid] +=
|
||||
PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
|
||||
feats.Fill(inst);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataView>
|
||||
void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
|
||||
const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
|
||||
@@ -145,11 +188,11 @@ class AdapterView {
|
||||
};
|
||||
|
||||
template <typename DataView, size_t block_of_rows_size>
|
||||
void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin,
|
||||
int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
auto& thread_temp = *p_thread_temp;
|
||||
void PredictBatchByBlockOfRowsKernel(
|
||||
DataView batch, std::vector<bst_float> *out_preds,
|
||||
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
|
||||
std::vector<RegTree::FVec> *p_thread_temp) {
|
||||
auto &thread_temp = *p_thread_temp;
|
||||
int32_t const num_group = model.learner_model_param->num_output_group;
|
||||
|
||||
CHECK_EQ(model.param.size_leaf_vector, 0)
|
||||
@@ -157,16 +200,20 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out
|
||||
// parallel over local batch
|
||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
||||
const int num_feature = model.learner_model_param->num_feature;
|
||||
const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
|
||||
common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) {
|
||||
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
|
||||
|
||||
common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
|
||||
const size_t batch_offset = block_id * block_of_rows_size;
|
||||
const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t block_size =
|
||||
std::min(nsize - batch_offset, block_of_rows_size);
|
||||
const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;
|
||||
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
|
||||
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
|
||||
p_thread_temp);
|
||||
// process block of rows through all trees to keep cache locality
|
||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
|
||||
num_group, thread_temp, fvec_offset, block_size);
|
||||
PredictByAllTrees(model, tree_begin, tree_end, out_preds,
|
||||
batch_offset + batch.base_rowid, num_group, thread_temp,
|
||||
fvec_offset, block_size);
|
||||
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
|
||||
});
|
||||
}
|
||||
@@ -344,7 +391,9 @@ class CPUPredictor : public Predictor {
|
||||
}
|
||||
feats.Fill(page[i]);
|
||||
for (unsigned j = 0; j < ntree_limit; ++j) {
|
||||
int tid = model.trees[j]->GetLeafIndex(feats);
|
||||
auto const& tree = *model.trees[j];
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
|
||||
preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
|
||||
}
|
||||
feats.Drop(page[i]);
|
||||
|
||||
Reference in New Issue
Block a user