Implement categorical prediction for CPU and GPU predict leaf. (#7001)

* Categorical prediction with CPU predictor and GPU predict leaf.

* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to have a unified get next node function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
Jiaming Yuan
2021-06-11 10:11:45 +08:00
committed by GitHub
parent 72f9daf9b6
commit f79cc4a7a4
10 changed files with 340 additions and 200 deletions

View File

@@ -445,6 +445,10 @@ class RegTree : public Model {
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);
bool HasCategoricalSplit() const {
return !split_categories_.empty();
}
/*!
* \brief get current depth
* \param nid node id
@@ -537,13 +541,6 @@ class RegTree : public Model {
std::vector<Entry> data_;
bool has_missing_;
};
/*!
* \brief get the leaf index
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \return the leaf index of the given feature
*/
template <bool has_missing = true>
int GetLeafIndex(const FVec& feat) const;
/*!
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
@@ -582,14 +579,6 @@ class RegTree : public Model {
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
bst_float* out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
* \param fvalue feature value if not missing.
* \param is_unknown Whether current required feature is missing.
*/
template <bool has_missing = true>
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
/*!
* \brief dump the model in the requested format as a text string
* \param fmap feature map that may help give interpretations of feature
@@ -627,6 +616,20 @@ class RegTree : public Model {
size_t size {0};
};
struct CategoricalSplitMatrix {
common::Span<FeatureType const> split_type;
common::Span<uint32_t const> categories;
common::Span<Segment const> node_ptr;
};
CategoricalSplitMatrix GetCategoriesMatrix() const {
CategoricalSplitMatrix view;
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
view.categories = this->GetSplitCategories();
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
return view;
}
private:
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
@@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}
template <bool has_missing>
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
has_missing && feat.IsMissing(split_index));
}
return nid;
}
/*! \brief get next position of the tree given current pid */
template <bool has_missing>
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
if (has_missing) {
if (is_unknown) {
return (*this)[pid].DefaultChild();
} else {
if (fvalue < (*this)[pid].SplitCond()) {
return (*this)[pid].LeftChild();
} else {
return (*this)[pid].RightChild();
}
}
} else {
// 35% speed up due to reduced miss branch predictions
// The following expression returns the left child if (fvalue < split_cond);
// the right child otherwise.
return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
}
}
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_