Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf.
* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to have a unified get next node function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
parent 72f9daf9b6, commit f79cc4a7a4
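The core of the refactor is a single templated next-node helper (GetNextNode in src/predictor/predict_fn.h, added below) whose has_missing / has_categorical flags are compile-time booleans, so each instantiation keeps only the branches it can actually take, and callers select the instantiation once per tree via HasCategoricalSplit(). A standalone sketch of that dispatch pattern, using simplified stand-in types rather than the real xgboost classes:

// Stand-in types only; the real helper is GetNextNode<has_missing,
// has_categorical>() in src/predictor/predict_fn.h and works on RegTree::Node
// plus RegTree::CategoricalSplitMatrix.
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

struct Node {
  int left, right, default_child;  // child node ids
  float split_cond;                // threshold for numerical splits
};

// has_missing / has_categorical are compile-time flags, so each instantiation
// keeps only the branches it can actually take.
template <bool has_missing, bool has_categorical>
int NextNode(const Node &n, float fvalue, bool is_missing,
             const std::vector<bool> &categories /* per-node bitset stand-in */) {
  if (has_missing && is_missing) {
    return n.default_child;
  }
  if (has_categorical && !categories.empty()) {
    // Matching category goes right in this sketch, mirroring the PR's tests.
    std::size_t cat = static_cast<std::size_t>(fvalue);
    bool match = cat < categories.size() && categories[cat];
    return match ? n.right : n.left;
  }
  return fvalue < n.split_cond ? n.left : n.right;
}

int main() {
  Node root{1, 2, /*default_child=*/1, /*split_cond=*/0.5f};
  std::vector<bool> cats(8, false);
  cats[4] = true;  // only category 4 is in the split's category set
  std::cout << NextNode<true, true>(root, 4.0f, false, cats) << "\n";  // 2 (right)
  std::cout << NextNode<true, false>(root, 0.25f, false, {}) << "\n";  // 1 (left)
  std::cout << NextNode<true, true>(root, NAN, true, cats) << "\n";    // 1 (default)
  return 0;
}

The real helper routes categorical splits through Decision() over a per-node category bitset; the stand-in above only keeps the shape of the dispatch.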
include/xgboost/tree_model.h

@@ -445,6 +445,10 @@ class RegTree : public Model {
                   bst_float right_leaf_weight, bst_float loss_change,
                   float sum_hess, float left_sum, float right_sum);

+  bool HasCategoricalSplit() const {
+    return !split_categories_.empty();
+  }
+
   /*!
    * \brief get current depth
    * \param nid node id
@@ -537,13 +541,6 @@ class RegTree : public Model {
     std::vector<Entry> data_;
     bool has_missing_;
   };
-  /*!
-   * \brief get the leaf index
-   * \param feat dense feature vector, if the feature is missing the field is set to NaN
-   * \return the leaf index of the given feature
-   */
-  template <bool has_missing = true>
-  int GetLeafIndex(const FVec& feat) const;

   /*!
    * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
@@ -582,14 +579,6 @@ class RegTree : public Model {
   */
  void CalculateContributionsApprox(const RegTree::FVec& feat,
                                    bst_float* out_contribs) const;
-  /*!
-   * \brief get next position of the tree given current pid
-   * \param pid Current node id.
-   * \param fvalue feature value if not missing.
-   * \param is_unknown Whether current required feature is missing.
-   */
-  template <bool has_missing = true>
-  inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
  /*!
   * \brief dump the model in the requested format as a text string
   * \param fmap feature map that may help give interpretations of feature
@@ -627,6 +616,20 @@ class RegTree : public Model {
     size_t size {0};
   };

+  struct CategoricalSplitMatrix {
+    common::Span<FeatureType const> split_type;
+    common::Span<uint32_t const> categories;
+    common::Span<Segment const> node_ptr;
+  };
+
+  CategoricalSplitMatrix GetCategoriesMatrix() const {
+    CategoricalSplitMatrix view;
+    view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
+    view.categories = this->GetSplitCategories();
+    view.node_ptr = common::Span<Segment const>(split_categories_segments_);
+    return view;
+  }
+
 private:
  void LoadCategoricalSplit(Json const& in);
  void SaveCategoricalSplit(Json* p_out) const;
@@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
 inline bool RegTree::FVec::HasMissing() const {
   return has_missing_;
 }
-
-template <bool has_missing>
-inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
-  bst_node_t nid = 0;
-  while (!(*this)[nid].IsLeaf()) {
-    unsigned split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
-                                     has_missing && feat.IsMissing(split_index));
-  }
-  return nid;
-}
-
-/*! \brief get next position of the tree given current pid */
-template <bool has_missing>
-inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
-  if (has_missing) {
-    if (is_unknown) {
-      return (*this)[pid].DefaultChild();
-    } else {
-      if (fvalue < (*this)[pid].SplitCond()) {
-        return (*this)[pid].LeftChild();
-      } else {
-        return (*this)[pid].RightChild();
-      }
-    }
-  } else {
-    // 35% speed up due to reduced miss branch predictions
-    // The following expression returns the left child if (fvalue < split_cond);
-    // the right child otherwise.
-    return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
-  }
-}
-
 }  // namespace xgboost
 #endif  // XGBOOST_TREE_MODEL_H_
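For reference, a minimal host-side stand-in for how a traversal consumes the CategoricalSplitMatrix view added above: node_ptr[nid] addresses the slice of the flat categories bitset that belongs to node nid. The real view hands out common::Span objects and the bit test lives in src/common/categorical.h (Decision / LBitField32); the sketch only assumes 32-bit bitset words, which is what LBitField32 uses.

// Simplified stand-in for RegTree::CategoricalSplitMatrix: three flat,
// read-only views, where node_ptr[nid] addresses the slice of `categories`
// owned by node nid.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

struct Segment { std::size_t beg; std::size_t size; };

struct CategoricalSplitMatrixSketch {
  std::vector<std::uint8_t> split_type;   // 0 = numerical, 1 = categorical split
  std::vector<std::uint32_t> categories;  // all nodes' category bitsets, concatenated
  std::vector<Segment> node_ptr;          // per-node slice into `categories`
};

// Does node `nid` treat category `cat` as a matching category?
bool NodeHasCategory(const CategoricalSplitMatrixSketch &cats, int nid,
                     std::uint32_t cat) {
  if (cats.split_type[nid] == 0) { return false; }  // numerical split, no categories
  Segment seg = cats.node_ptr[nid];
  std::size_t word = cat / 32, bit = cat % 32;
  if (word >= seg.size) { return false; }            // category beyond the stored bitset
  return (cats.categories[seg.beg + word] >> bit) & 1u;
}

int main() {
  CategoricalSplitMatrixSketch cats;
  cats.split_type = {1, 0};                   // node 0 categorical, node 1 numerical
  cats.categories = {std::uint32_t{1} << 4};  // node 0 keeps only category 4
  cats.node_ptr = {{0, 1}, {0, 0}};
  std::cout << NodeHasCategory(cats, 0, 4) << " "    // 1
            << NodeHasCategory(cats, 0, 3) << "\n";  // 0
  return 0;
}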
src/predictor/cpu_predictor.cc

@@ -16,9 +16,11 @@
 #include "xgboost/logging.h"
 #include "xgboost/host_device_vector.h"

+#include "predict_fn.h"
 #include "../data/adapter.h"
 #include "../common/math.h"
 #include "../common/threading_utils.h"
+#include "../common/categorical.h"
 #include "../gbm/gbtree_model.h"

 namespace xgboost {
@@ -26,6 +28,19 @@ namespace predictor {

 DMLC_REGISTRY_FILE_TAG(cpu_predictor);

+template <bool has_missing, bool has_categorical>
+bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat,
+                        RegTree::CategoricalSplitMatrix const& cats) {
+  bst_node_t nid = 0;
+  while (!tree[nid].IsLeaf()) {
+    unsigned split_index = tree[nid].SplitIndex();
+    auto fvalue = feat.GetFvalue(split_index);
+    nid = GetNextNode<has_missing, has_categorical>(
+        tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats);
+  }
+  return nid;
+}
+
 bst_float PredValue(const SparsePage::Inst &inst,
                     const std::vector<std::unique_ptr<RegTree>> &trees,
                     const std::vector<int> &tree_info, int bst_group,
@@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst,
   p_feats->Fill(inst);
   for (size_t i = tree_begin; i < tree_end; ++i) {
     if (tree_info[i] == bst_group) {
-      int tid = trees[i]->GetLeafIndex(*p_feats);
-      psum += (*trees[i])[tid].LeafValue();
+      auto const &tree = *trees[i];
+      bool has_categorical = tree.HasCategoricalSplit();
+
+      auto categories = common::Span<uint32_t const>{tree.GetSplitCategories()};
+      auto split_types = tree.GetSplitTypes();
+      auto categories_ptr =
+          common::Span<RegTree::Segment const>{tree.GetSplitCategoriesPtr()};
+      auto cats = tree.GetCategoriesMatrix();
+      bst_node_t nidx = -1;
+      if (has_categorical) {
+        nidx = GetLeafIndex<true, true>(tree, *p_feats, cats);
+      } else {
+        nidx = GetLeafIndex<true, false>(tree, *p_feats, cats);
+      }
+      psum += (*trees[i])[nidx].LeafValue();
     }
   }
   p_feats->Drop(inst);
   return psum;
 }

-inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats,
-                                    const std::unique_ptr<RegTree>& tree) {
-  const int lid = p_feats.HasMissing() ? tree->GetLeafIndex<true>(p_feats) :
-                                         tree->GetLeafIndex<false>(p_feats);  // 35% speed up
-  return (*tree)[lid].LeafValue();
+template <bool has_categorical>
+bst_float
+PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
+                   RegTree::CategoricalSplitMatrix const& cats) {
+  const bst_node_t leaf = p_feats.HasMissing() ?
+      GetLeafIndex<true, has_categorical>(tree, p_feats, cats) :
+      GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
+  return tree[leaf].LeafValue();
 }

-inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
-                              const size_t tree_end, std::vector<bst_float>* out_preds,
+void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
+                       const size_t tree_end, std::vector<bst_float> *out_preds,
                        const size_t predict_offset, const size_t num_group,
                        const std::vector<RegTree::FVec> &thread_temp,
                        const size_t offset, const size_t block_size) {
   std::vector<bst_float> &preds = *out_preds;
   for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
     const size_t gid = model.tree_info[tree_id];
-    for (size_t i = 0; i < block_size; ++i) {
-      preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i],
-                                                                          model.trees[tree_id]);
+    auto const &tree = *model.trees[tree_id];
+    auto const& cats = tree.GetCategoriesMatrix();
+    auto has_categorical = tree.HasCategoricalSplit();
+
+    if (has_categorical) {
+      for (size_t i = 0; i < block_size; ++i) {
+        preds[(predict_offset + i) * num_group + gid] +=
+            PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
+      }
+    } else {
+      for (size_t i = 0; i < block_size; ++i) {
+        preds[(predict_offset + i) * num_group + gid] +=
+            PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
+      }
     }
   }
 }
@@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
       feats.Fill(inst);
     }
   }
+
 template <typename DataView>
 void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
               const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
@@ -145,11 +188,11 @@ class AdapterView {
 };

 template <typename DataView, size_t block_of_rows_size>
-void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out_preds,
-                                     gbm::GBTreeModel const &model, int32_t tree_begin,
-                                     int32_t tree_end,
+void PredictBatchByBlockOfRowsKernel(
+    DataView batch, std::vector<bst_float> *out_preds,
+    gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
     std::vector<RegTree::FVec> *p_thread_temp) {
-  auto& thread_temp = *p_thread_temp;
+  auto &thread_temp = *p_thread_temp;
   int32_t const num_group = model.learner_model_param->num_output_group;

   CHECK_EQ(model.param.size_leaf_vector, 0)
@@ -157,16 +200,20 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector<bst_float> *out
   // parallel over local batch
   const auto nsize = static_cast<bst_omp_uint>(batch.Size());
   const int num_feature = model.learner_model_param->num_feature;
-  const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size);
-  common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) {
+  omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);
+
+  common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
     const size_t batch_offset = block_id * block_of_rows_size;
-    const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size);
+    const size_t block_size =
+        std::min(nsize - batch_offset, block_of_rows_size);
     const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size;

-    FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
+    FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset,
+             p_thread_temp);
     // process block of rows through all trees to keep cache locality
-    PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid,
-                      num_group, thread_temp, fvec_offset, block_size);
+    PredictByAllTrees(model, tree_begin, tree_end, out_preds,
+                      batch_offset + batch.base_rowid, num_group, thread_temp,
+                      fvec_offset, block_size);
     FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
   });
 }
@@ -344,7 +391,9 @@ class CPUPredictor : public Predictor {
         }
         feats.Fill(page[i]);
         for (unsigned j = 0; j < ntree_limit; ++j) {
-          int tid = model.trees[j]->GetLeafIndex(feats);
+          auto const& tree = *model.trees[j];
+          auto const& cats = tree.GetCategoriesMatrix();
+          bst_node_t tid = GetLeafIndex<true, true>(tree, feats, cats);
           preds[ridx * ntree_limit + j] = static_cast<bst_float>(tid);
         }
         feats.Drop(page[i]);
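One small cleanup above replaces the hand-written block count `(nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size)` with common::DivRoundUp. Both forms are ceiling division; a standalone check of that equivalence (DivRoundUp re-implemented locally for the sketch, and the block size is only an illustrative value, not necessarily the one the CPU predictor uses):

// Sketch only: DivRoundUp re-implemented here; kBlock stands in for the
// kernel's block_of_rows_size template parameter.
#include <cassert>
#include <cstddef>
#include <initializer_list>

constexpr std::size_t DivRoundUp(std::size_t n, std::size_t b) {
  return n / b + !!(n % b);  // same as (n + b - 1) / b when n + b does not overflow
}

int main() {
  constexpr std::size_t kBlock = 64;
  for (std::size_t n : {std::size_t{0}, std::size_t{1}, std::size_t{63},
                        std::size_t{64}, std::size_t{65}, std::size_t{1000}}) {
    assert(DivRoundUp(n, kBlock) == (n + kBlock - 1) / kBlock);
  }
  return 0;
}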
src/predictor/gpu_predictor.cu

@@ -14,6 +14,7 @@
 #include "xgboost/tree_updater.h"
 #include "xgboost/host_device_vector.h"

+#include "predict_fn.h"
 #include "../gbm/gbtree_model.h"
 #include "../data/ellpack_page.cuh"
 #include "../data/device_adapter.cuh"
@@ -27,6 +28,42 @@ namespace predictor {

 DMLC_REGISTRY_FILE_TAG(gpu_predictor);

+struct TreeView {
+  RegTree::CategoricalSplitMatrix cats;
+  common::Span<RegTree::Node const> d_tree;
+
+  XGBOOST_DEVICE
+  TreeView(size_t tree_begin, size_t tree_idx,
+           common::Span<const RegTree::Node> d_nodes,
+           common::Span<size_t const> d_tree_segments,
+           common::Span<FeatureType const> d_tree_split_types,
+           common::Span<uint32_t const> d_cat_tree_segments,
+           common::Span<RegTree::Segment const> d_cat_node_segments,
+           common::Span<uint32_t const> d_categories) {
+    auto begin = d_tree_segments[tree_idx - tree_begin];
+    auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] -
+                   d_tree_segments[tree_idx - tree_begin];
+
+    d_tree = d_nodes.subspan(begin, n_nodes);
+
+    auto tree_cat_ptrs = d_cat_node_segments.subspan(begin, n_nodes);
+    auto tree_split_types = d_tree_split_types.subspan(begin, n_nodes);
+
+    auto tree_categories =
+        d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
+                             d_cat_tree_segments[tree_idx - tree_begin + 1] -
+                                 d_cat_tree_segments[tree_idx - tree_begin]);
+
+    cats.split_type = tree_split_types;
+    cats.categories = tree_categories;
+    cats.node_ptr = tree_cat_ptrs;
+  }
+
+  __device__ bool HasCategoricalSplit() const {
+    return !cats.categories.empty();
+  }
+};
+
 struct SparsePageView {
   common::Span<const Entry> d_data;
   common::Span<const bst_row_t> d_row_ptr;
@@ -178,84 +215,69 @@ struct DeviceAdapterLoader {
   }
 };

-template <typename Loader>
-__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
-                               common::Span<FeatureType const> split_types,
-                               common::Span<RegTree::Segment const> d_cat_ptrs,
-                               common::Span<uint32_t const> d_categories,
-                               Loader* loader) {
+template <bool has_missing, bool has_categorical, typename Loader>
+__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree,
+                                   Loader *loader) {
   bst_node_t nidx = 0;
-  RegTree::Node n = tree[nidx];
+  RegTree::Node n = tree.d_tree[nidx];
   while (!n.IsLeaf()) {
     float fvalue = loader->GetElement(ridx, n.SplitIndex());
-    // Missing value
-    if (common::CheckNAN(fvalue)) {
-      nidx = n.DefaultChild();
-    } else {
-      bool go_left = true;
-      if (common::IsCat(split_types, nidx)) {
-        auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
-                                               d_cat_ptrs[nidx].size);
-        go_left = Decision(categories, common::AsCat(fvalue));
-      } else {
-        go_left = fvalue < n.SplitCond();
-      }
-      if (go_left) {
-        nidx = n.LeftChild();
-      } else {
-        nidx = n.RightChild();
-      }
-    }
-    n = tree[nidx];
-  }
-  return tree[nidx].LeafValue();
-}
-
-template <typename Loader>
-__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
-                                   Loader const& loader) {
-  bst_node_t nidx = 0;
-  RegTree::Node n = tree[nidx];
-  while (!n.IsLeaf()) {
-    float fvalue = loader.GetElement(ridx, n.SplitIndex());
-    // Missing value
-    if (common::CheckNAN(fvalue)) {
-      nidx = n.DefaultChild();
-      n = tree[nidx];
-    } else {
-      if (fvalue < n.SplitCond()) {
-        nidx = n.LeftChild();
-        n = tree[nidx];
-      } else {
-        nidx = n.RightChild();
-        n = tree[nidx];
-      }
-    }
+    bool is_missing = common::CheckNAN(fvalue);
+    nidx = GetNextNode<has_missing, has_categorical>(n, nidx, fvalue,
+                                                     is_missing, tree.cats);
+    n = tree.d_tree[nidx];
   }
   return nidx;
 }

+template <bool has_missing, typename Loader>
+__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree,
+                               Loader *loader) {
+  bst_node_t nidx = -1;
+  if (tree.HasCategoricalSplit()) {
+    nidx = GetLeafIndex<has_missing, true>(ridx, tree, loader);
+  } else {
+    nidx = GetLeafIndex<has_missing, false>(ridx, tree, loader);
+  }
+  return tree.d_tree[nidx].LeafValue();
+}
+
 template <typename Loader, typename Data>
-__global__ void PredictLeafKernel(Data data,
-                                  common::Span<const RegTree::Node> d_nodes,
+__global__ void
+PredictLeafKernel(Data data, common::Span<const RegTree::Node> d_nodes,
                   common::Span<float> d_out_predictions,
                   common::Span<size_t const> d_tree_segments,
-                  size_t tree_begin, size_t tree_end, size_t num_features,
-                  size_t num_rows, size_t entry_start, bool use_shared,
-                  float missing) {
+                  common::Span<FeatureType const> d_tree_split_types,
+                  common::Span<uint32_t const> d_cat_tree_segments,
+                  common::Span<RegTree::Segment const> d_cat_node_segments,
+                  common::Span<uint32_t const> d_categories,
+
+                  size_t tree_begin, size_t tree_end, size_t num_features,
+                  size_t num_rows, size_t entry_start, bool use_shared,
+                  float missing) {
   bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
   if (ridx >= num_rows) {
     return;
   }
   Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
-  for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
-    const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
-    auto leaf = GetLeafIndex(ridx, d_tree, loader);
+  for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
+    TreeView d_tree{
+        tree_begin, tree_idx, d_nodes,
+        d_tree_segments, d_tree_split_types, d_cat_tree_segments,
+        d_cat_node_segments, d_categories};
+
+    bst_node_t leaf = -1;
+    if (d_tree.HasCategoricalSplit()) {
+      leaf = GetLeafIndex<true, true>(ridx, d_tree, &loader);
+    } else {
+      leaf = GetLeafIndex<true, false>(ridx, d_tree, &loader);
+    }
     d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
   }
 }

-template <typename Loader, typename Data>
+template <typename Loader, typename Data, bool has_missing = true>
 __global__ void
 PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
               common::Span<float> d_out_predictions,
@@ -272,47 +294,25 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
   if (global_idx >= num_rows) return;
   if (num_group == 1) {
     float sum = 0;
-    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      const RegTree::Node* d_tree =
-          &d_nodes[d_tree_segments[tree_idx - tree_begin]];
-      auto tree_cat_ptrs = d_cat_node_segments.subspan(
-          d_tree_segments[tree_idx - tree_begin],
-          d_tree_segments[tree_idx - tree_begin + 1] -
-              d_tree_segments[tree_idx - tree_begin]);
-      auto tree_categories =
-          d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
-                               d_cat_tree_segments[tree_idx - tree_begin + 1] -
-                                   d_cat_tree_segments[tree_idx - tree_begin]);
-      auto tree_split_types =
-          d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin],
-                                     d_tree_segments[tree_idx - tree_begin + 1] -
-                                         d_tree_segments[tree_idx - tree_begin]);
-      float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types,
-                                 tree_cat_ptrs,
-                                 tree_categories,
-                                 &loader);
+    for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+      TreeView d_tree{
+          tree_begin, tree_idx, d_nodes,
+          d_tree_segments, d_tree_split_types, d_cat_tree_segments,
+          d_cat_node_segments, d_categories};
+      float leaf = GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
       sum += leaf;
     }
     d_out_predictions[global_idx] += sum;
   } else {
-    for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
+    for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       int tree_group = d_tree_group[tree_idx];
-      const RegTree::Node* d_tree =
-          &d_nodes[d_tree_segments[tree_idx - tree_begin]];
+      TreeView d_tree{
+          tree_begin, tree_idx, d_nodes,
+          d_tree_segments, d_tree_split_types, d_cat_tree_segments,
+          d_cat_node_segments, d_categories};
       bst_uint out_prediction_idx = global_idx * num_group + tree_group;
-      auto tree_cat_ptrs = d_cat_node_segments.subspan(
-          d_tree_segments[tree_idx - tree_begin],
-          d_tree_segments[tree_idx - tree_begin + 1] -
-              d_tree_segments[tree_idx - tree_begin]);
-      auto tree_categories =
-          d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
-                               d_cat_tree_segments[tree_idx - tree_begin + 1] -
-                                   d_cat_tree_segments[tree_idx - tree_begin]);
       d_out_predictions[out_prediction_idx] +=
-          GetLeafWeight(global_idx, d_tree, d_tree_split_types,
-                        tree_cat_ptrs,
-                        tree_categories,
-                        &loader);
+          GetLeafWeight<has_missing>(global_idx, d_tree, &loader);
     }
   }
 }
@@ -515,7 +515,7 @@ class GPUPredictor : public xgboost::Predictor {
                        DeviceModel const& model,
                        size_t num_features,
                        HostDeviceVector<bst_float>* predictions,
-                       size_t batch_offset) const {
+                       size_t batch_offset, bool is_dense) const {
     batch.offset.SetDevice(generic_param_->gpu_id);
     batch.data.SetDevice(generic_param_->gpu_id);
     const uint32_t BLOCK_THREADS = 128;
@@ -529,16 +529,24 @@
     size_t entry_start = 0;
     SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                         num_features);
-    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
-        PredictKernel<SparsePageLoader, SparsePageView>, data,
-        model.nodes.ConstDeviceSpan(),
-        predictions->DeviceSpan().subspan(batch_offset),
-        model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
-        model.split_types.ConstDeviceSpan(),
-        model.categories_tree_segments.ConstDeviceSpan(),
-        model.categories_node_segments.ConstDeviceSpan(),
-        model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
-        num_features, num_rows, entry_start, use_shared, model.num_group, nan(""));
+    auto const kernel = [&](auto predict_fn) {
+      dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
+          predict_fn, data, model.nodes.ConstDeviceSpan(),
+          predictions->DeviceSpan().subspan(batch_offset),
+          model.tree_segments.ConstDeviceSpan(),
+          model.tree_group.ConstDeviceSpan(),
+          model.split_types.ConstDeviceSpan(),
+          model.categories_tree_segments.ConstDeviceSpan(),
+          model.categories_node_segments.ConstDeviceSpan(),
+          model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
+          num_features, num_rows, entry_start, use_shared, model.num_group,
+          nan(""));
+    };
+    if (is_dense) {
+      kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
+    } else {
+      kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
+    }
   }
   void PredictInternal(EllpackDeviceAccessor const& batch,
                        DeviceModel const& model,
@@ -578,7 +586,7 @@
       size_t batch_offset = 0;
       for (auto &batch : dmat->GetBatches<SparsePage>()) {
         this->PredictInternal(batch, d_model, model.learner_model_param->num_feature,
-                              out_preds, batch_offset);
+                              out_preds, batch_offset, dmat->IsDense());
         batch_offset += batch.Size() * model.learner_model_param->num_output_group;
       }
     } else {
@@ -846,6 +854,12 @@ class GPUPredictor : public xgboost::Predictor {
           d_model.nodes.ConstDeviceSpan(),
           predictions->DeviceSpan().subspan(batch_offset),
           d_model.tree_segments.ConstDeviceSpan(),
+          d_model.split_types.ConstDeviceSpan(),
+          d_model.categories_tree_segments.ConstDeviceSpan(),
+          d_model.categories_node_segments.ConstDeviceSpan(),
+          d_model.categories.ConstDeviceSpan(),
           d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
           entry_start, use_shared, nan(""));
       batch_offset += batch.Size();
@@ -862,6 +876,12 @@ class GPUPredictor : public xgboost::Predictor {
           d_model.nodes.ConstDeviceSpan(),
           predictions->DeviceSpan().subspan(batch_offset),
          d_model.tree_segments.ConstDeviceSpan(),
+          d_model.split_types.ConstDeviceSpan(),
+          d_model.categories_tree_segments.ConstDeviceSpan(),
+          d_model.categories_node_segments.ConstDeviceSpan(),
+          d_model.categories.ConstDeviceSpan(),
           d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
           entry_start, use_shared, nan(""));
       batch_offset += batch.Size();
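The TreeView introduced above carves one tree's nodes and category bitsets out of the flat device arrays using per-tree segment offsets. A simplified host-side sketch of that slicing (stand-in types only; the real code returns common::Span subspans and also slices split types and category pointers the same way):

// Stand-in for the slicing done in TreeView's constructor: every tree's nodes
// live in one flat array, and the segment array holds one offset per tree plus
// a terminating end offset, so tree i owns [segments[i], segments[i + 1]).
#include <cstddef>
#include <iostream>
#include <vector>

struct Slice { std::size_t begin; std::size_t size; };

Slice TreeNodes(const std::vector<std::size_t> &tree_segments,
                std::size_t tree_begin, std::size_t tree_idx) {
  std::size_t i = tree_idx - tree_begin;
  return {tree_segments[i], tree_segments[i + 1] - tree_segments[i]};
}

int main() {
  // Three trees with 7, 5 and 9 nodes -> offsets 0, 7, 12, 21.
  std::vector<std::size_t> tree_segments{0, 7, 12, 21};
  Slice s = TreeNodes(tree_segments, /*tree_begin=*/0, /*tree_idx=*/1);
  std::cout << "tree 1 nodes: [" << s.begin << ", " << s.begin + s.size << ")\n";
  return 0;
}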
src/predictor/predict_fn.h (new file, 31 lines)

@@ -0,0 +1,31 @@
+/*!
+ * Copyright 2021 by XGBoost Contributors
+ */
+#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_
+#define XGBOOST_PREDICTOR_PREDICT_FN_H_
+#include "../common/categorical.h"
+#include "xgboost/tree_model.h"
+
+namespace xgboost {
+namespace predictor {
+template <bool has_missing, bool has_categorical>
+inline XGBOOST_DEVICE bst_node_t
+GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue,
+            bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
+  if (has_missing && is_missing) {
+    return node.DefaultChild();
+  } else {
+    if (has_categorical && common::IsCat(cats.split_type, nid)) {
+      auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg,
+                                                     cats.node_ptr[nid].size);
+      return Decision(node_categories, common::AsCat(fvalue))
+                 ? node.LeftChild()
+                 : node.RightChild();
+    } else {
+      return node.LeftChild() + !(fvalue < node.SplitCond());
+    }
+  }
+}
+}  // namespace predictor
+}  // namespace xgboost
+#endif  // XGBOOST_PREDICTOR_PREDICT_FN_H_
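Every call site touched by this PR (the CPU and GPU GetLeafIndex helpers above and the tree_model.cc / updater_refresh.cc changes below) reduces to the same loop: start at the root and repeatedly ask the next-node helper for a child until a leaf is reached. A standalone sketch of that loop shape, with a tiny hard-coded tree instead of the real RegTree:

// Stand-in node and walk loop; the real traversal calls
// predictor::GetNextNode<has_missing, has_categorical>(...) at each step.
#include <iostream>
#include <vector>

struct Node {
  int left, right;    // -1 marks a leaf
  float split_cond;   // threshold for the numerical rule used in this sketch
  float leaf_value;
  bool IsLeaf() const { return left == -1; }
};

int WalkToLeaf(const std::vector<Node> &tree, const std::vector<float> &row) {
  int nid = 0;
  while (!tree[nid].IsLeaf()) {
    // Numerical rule only: go left when fvalue < split_cond (feature 0 here).
    float fvalue = row[0];
    nid = fvalue < tree[nid].split_cond ? tree[nid].left : tree[nid].right;
  }
  return nid;  // PredictLeaf stores this node id; PredValue reads its leaf value
}

int main() {
  // Root split at 0.5 with two leaves holding weights 1.3 and 1.7.
  std::vector<Node> tree{{1, 2, 0.5f, 0.f}, {-1, -1, 0.f, 1.3f}, {-1, -1, 0.f, 1.7f}};
  int leaf = WalkToLeaf(tree, {0.2f});
  std::cout << "leaf id " << leaf << ", value " << tree[leaf].leaf_value << "\n";  // 1, 1.3
  leaf = WalkToLeaf(tree, {0.9f});
  std::cout << "leaf id " << leaf << ", value " << tree[leaf].leaf_value << "\n";  // 2, 1.7
  return 0;
}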
src/tree/tree_model.cc

@@ -19,6 +19,7 @@
 #include "param.h"
 #include "../common/common.h"
 #include "../common/categorical.h"
+#include "../predictor/predict_fn.h"

 namespace xgboost {
 // register tree parameter
@@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
     // nothing to do anymore
     return;
   }

   bst_node_t nid = 0;
+  auto cats = this->GetCategoriesMatrix();
+
   while (!(*this)[nid].IsLeaf()) {
     split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
+    nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
+                                             feat.GetFvalue(split_index),
+                                             feat.IsMissing(split_index), cats);
     bst_float new_value = this->node_mean_values_[nid];
     // update feature weight
     out_contribs[split_index] += new_value - node_value;
src/tree/updater_refresh.cc

@@ -14,6 +14,7 @@
 #include "./param.h"
 #include "../common/io.h"
 #include "../common/threading_utils.h"
+#include "../predictor/predict_fn.h"

 namespace xgboost {
 namespace tree {
@@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater {
       // start from groups that belongs to current data
       auto pid = 0;
       gstats[pid].Add(gpair[ridx]);
+      auto const& cats = tree.GetCategoriesMatrix();
       // traverse tree
       while (!tree[pid].IsLeaf()) {
         unsigned split_index = tree[pid].SplitIndex();
-        pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
+        pid = predictor::GetNextNode<true, true>(
+            tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
+            cats);
         gstats[pid].Add(gpair[ridx]);
       }
     }
tests/cpp/predictor/test_cpu_predictor.cc

@@ -229,9 +229,17 @@ void TestUpdatePredictionCache(bool use_subsampling) {
   }
 }

+TEST(CPUPredictor, CategoricalPrediction) {
+  TestCategoricalPrediction("cpu_predictor");
+}
+
+TEST(CPUPredictor, CategoricalPredictLeaf) {
+  TestCategoricalPredictLeaf(StringView{"cpu_predictor"});
+}
+
 TEST(CpuPredictor, UpdatePredictionCache) {
   TestUpdatePredictionCache(false);
   TestUpdatePredictionCache(true);
 }

 TEST(CpuPredictor, LesserFeatures) {
tests/cpp/predictor/test_gpu_predictor.cu

@@ -228,6 +228,10 @@ TEST(GPUPredictor, CategoricalPrediction) {
   TestCategoricalPrediction("gpu_predictor");
 }

+TEST(GPUPredictor, CategoricalPredictLeaf) {
+  TestCategoricalPredictLeaf(StringView{"gpu_predictor"});
+}
+
 TEST(GPUPredictor, PredictLeafBasic) {
   size_t constexpr kRows = 5, kCols = 5;
   auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix();
tests/cpp/predictor/test_predictor.cc

@@ -180,6 +180,25 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
 #endif  // defined(XGBOOST_USE_CUDA)
 }

+void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
+                        bst_cat_t split_cat, float left_weight,
+                        float right_weight) {
+  PredictionCacheEntry out_predictions;
+
+  std::vector<std::unique_ptr<RegTree>> trees;
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
+  auto& p_tree = trees.front();
+
+  std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
+  LBitField32 cats_bits(split_cats);
+  cats_bits.Set(split_cat);
+
+  p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
+                            left_weight, right_weight,
+                            3.0f, 2.2f, 7.0f, 9.0f);
+  model->CommitModel(std::move(trees), 0);
+}
+
 void TestCategoricalPrediction(std::string name) {
   size_t constexpr kCols = 10;
   PredictionCacheEntry out_predictions;
@@ -189,25 +208,13 @@ void TestCategoricalPrediction(std::string name) {
   param.num_output_group = 1;
   param.base_score = 0.5;

-  gbm::GBTreeModel model(&param);
-
-  std::vector<std::unique_ptr<RegTree>> trees;
-  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
-  auto& p_tree = trees.front();
-
   uint32_t split_ind = 3;
   bst_cat_t split_cat = 4;
   float left_weight = 1.3f;
   float right_weight = 1.7f;

-  std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
-  LBitField32 cats_bits(split_cats);
-  cats_bits.Set(split_cat);
-
-  p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
-                            left_weight, right_weight,
-                            3.0f, 2.2f, 7.0f, 9.0f);
-  model.CommitModel(std::move(trees), 0);
+  gbm::GBTreeModel model(&param);
+  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

   GenericParameter runtime;
   runtime.gpu_id = 0;
@@ -232,4 +239,43 @@ void TestCategoricalPrediction(std::string name) {
   ASSERT_EQ(out_predictions.predictions.HostVector()[0],
             left_weight + param.base_score);
 }
+
+void TestCategoricalPredictLeaf(StringView name) {
+  size_t constexpr kCols = 10;
+  PredictionCacheEntry out_predictions;
+
+  LearnerModelParam param;
+  param.num_feature = kCols;
+  param.num_output_group = 1;
+  param.base_score = 0.5;
+
+  uint32_t split_ind = 3;
+  bst_cat_t split_cat = 4;
+  float left_weight = 1.3f;
+  float right_weight = 1.7f;
+
+  gbm::GBTreeModel model(&param);
+  GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
+
+  GenericParameter runtime;
+  runtime.gpu_id = 0;
+  std::unique_ptr<Predictor> predictor{
+      Predictor::Create(name.c_str(), &runtime)};
+
+  std::vector<float> row(kCols);
+  row[split_ind] = split_cat;
+  auto m = GetDMatrixFromData(row, 1, kCols);
+
+  predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
+  CHECK_EQ(out_predictions.predictions.Size(), 1);
+  // go to left if it doesn't match the category, otherwise right.
+  ASSERT_EQ(out_predictions.predictions.HostVector()[0], 2);
+
+  row[split_ind] = split_cat + 1;
+  m = GetDMatrixFromData(row, 1, kCols);
+  out_predictions.version = 0;
+  predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
+  predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
+  ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
+}
 }  // namespace xgboost
tests/cpp/predictor/test_predictor.h

@@ -66,6 +66,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
 void TestPredictionWithLesserFeatures(std::string preditor_name);

 void TestCategoricalPrediction(std::string name);
+
+void TestCategoricalPredictLeaf(StringView name);
 }  // namespace xgboost

 #endif  // XGBOOST_TEST_PREDICTOR_H_