Implement categorical prediction for CPU and GPU predict leaf. (#7001)
* Categorical prediction with CPU predictor and GPU predict leaf. * Implement categorical prediction for CPU prediction. * Implement categorical prediction for GPU predict leaf. * Refactor the prediction functions to have a unified get next node function. Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
#include "param.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
|
||||
namespace xgboost {
|
||||
// register tree parameter
|
||||
@@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
// nothing to do anymore
|
||||
return;
|
||||
}
|
||||
|
||||
bst_node_t nid = 0;
|
||||
auto cats = this->GetCategoriesMatrix();
|
||||
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
|
||||
feat.GetFvalue(split_index),
|
||||
feat.IsMissing(split_index), cats);
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "./param.h"
|
||||
#include "../common/io.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../predictor/predict_fn.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater {
|
||||
// start from groups that belongs to current data
|
||||
auto pid = 0;
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
auto const& cats = tree.GetCategoriesMatrix();
|
||||
// traverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
unsigned split_index = tree[pid].SplitIndex();
|
||||
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
pid = predictor::GetNextNode<true, true>(
|
||||
tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
|
||||
cats);
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user