Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf. * Implement categorical prediction for CPU prediction. * Implement categorical prediction for GPU predict leaf. * Refactor the prediction functions to have a unified get next node function. Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
@@ -445,6 +445,10 @@ class RegTree : public Model {
|
||||
bst_float right_leaf_weight, bst_float loss_change,
|
||||
float sum_hess, float left_sum, float right_sum);
|
||||
|
||||
bool HasCategoricalSplit() const {
|
||||
return !split_categories_.empty();
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
@@ -537,13 +541,6 @@ class RegTree : public Model {
|
||||
std::vector<Entry> data_;
|
||||
bool has_missing_;
|
||||
};
|
||||
/*!
|
||||
* \brief get the leaf index
|
||||
* \tparam has_missing when false, the definition never consults feat.IsMissing()
*         and every feature value is compared against the split condition directly.
* \param feat dense feature vector, if the feature is missing the field is set to NaN
|
||||
* \return the leaf index of the given feature
|
||||
*/
|
||||
template <bool has_missing = true>
|
||||
int GetLeafIndex(const FVec& feat) const;
|
||||
|
||||
/*!
|
||||
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
|
||||
@@ -582,14 +579,6 @@ class RegTree : public Model {
|
||||
*/
|
||||
void CalculateContributionsApprox(const RegTree::FVec& feat,
|
||||
bst_float* out_contribs) const;
|
||||
/*!
|
||||
* \brief get next position of the tree given current pid
|
||||
* \param pid Current node id.
|
||||
* \param fvalue feature value if not missing.
|
||||
* \param is_unknown Whether current required feature is missing.
|
||||
* \tparam has_missing when false, is_unknown is ignored and the next node is
*         chosen purely by comparing fvalue with the split condition.
* \return id of the child node to move to: the default child when has_missing
*         and is_unknown are both true, otherwise the left child if
*         fvalue < split condition and the right child if not.
*/
|
||||
template <bool has_missing = true>
|
||||
inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
|
||||
/*!
|
||||
* \brief dump the model in the requested format as a text string
|
||||
* \param fmap feature map that may help give interpretations of feature
|
||||
@@ -627,6 +616,20 @@ class RegTree : public Model {
|
||||
size_t size {0};
|
||||
};
|
||||
|
||||
struct CategoricalSplitMatrix {
|
||||
common::Span<FeatureType const> split_type;
|
||||
common::Span<uint32_t const> categories;
|
||||
common::Span<Segment const> node_ptr;
|
||||
};
|
||||
|
||||
/*!
 * \brief Collect the categorical split data of this tree into one struct of spans.
 * \return a CategoricalSplitMatrix whose split_type, categories and node_ptr view
 *         GetSplitTypes(), GetSplitCategories() and split_categories_segments_
 *         respectively. No data is copied.
 */
CategoricalSplitMatrix GetCategoriesMatrix() const {
  common::Span<FeatureType const> const types(this->GetSplitTypes());
  common::Span<uint32_t const> const cats = this->GetSplitCategories();
  common::Span<Segment const> const segments(split_categories_segments_);

  CategoricalSplitMatrix matrix;
  matrix.split_type = types;
  matrix.categories = cats;
  matrix.node_ptr = segments;
  return matrix;
}
|
||||
|
||||
private:
|
||||
void LoadCategoricalSplit(Json const& in);
|
||||
void SaveCategoricalSplit(Json* p_out) const;
|
||||
@@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const {
|
||||
inline bool RegTree::FVec::HasMissing() const {
|
||||
return has_missing_;
|
||||
}
|
||||
|
||||
template <bool has_missing>
|
||||
inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
unsigned split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext<has_missing>(nid, feat.GetFvalue(split_index),
|
||||
has_missing && feat.IsMissing(split_index));
|
||||
}
|
||||
return nid;
|
||||
}
|
||||
|
||||
/*! \brief get next position of the tree given current pid */
|
||||
template <bool has_missing>
|
||||
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
|
||||
if (has_missing) {
|
||||
if (is_unknown) {
|
||||
return (*this)[pid].DefaultChild();
|
||||
} else {
|
||||
if (fvalue < (*this)[pid].SplitCond()) {
|
||||
return (*this)[pid].LeftChild();
|
||||
} else {
|
||||
return (*this)[pid].RightChild();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 35% speed up due to reduced miss branch predictions
|
||||
// The following expression returns the left child if (fvalue < split_cond);
|
||||
// the right child otherwise.
|
||||
return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_MODEL_H_
|
||||
|
||||
Reference in New Issue
Block a user