Implement categorical prediction for CPU and GPU predict leaf. (#7001)

* Categorical prediction with CPU predictor and GPU predict leaf.

* Implement categorical prediction for CPU prediction.
* Implement categorical prediction for GPU predict leaf.
* Refactor the prediction functions to share a unified `GetNextNode` function.

Co-authored-by: Shvets Kirill <kirill.shvets@intel.com>
This commit is contained in:
Jiaming Yuan
2021-06-11 10:11:45 +08:00
committed by GitHub
parent 72f9daf9b6
commit f79cc4a7a4
10 changed files with 340 additions and 200 deletions

View File

@@ -19,6 +19,7 @@
#include "param.h"
#include "../common/common.h"
#include "../common/categorical.h"
#include "../predictor/predict_fn.h"
namespace xgboost {
// register tree parameter
@@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
// nothing to do anymore
return;
}
bst_node_t nid = 0;
auto cats = this->GetCategoriesMatrix();
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;

View File

@@ -14,6 +14,7 @@
#include "./param.h"
#include "../common/io.h"
#include "../common/threading_utils.h"
#include "../predictor/predict_fn.h"
namespace xgboost {
namespace tree {
@@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater {
// start from groups that belong to current data
auto pid = 0;
gstats[pid].Add(gpair[ridx]);
auto const& cats = tree.GetCategoriesMatrix();
// traverse tree
while (!tree[pid].IsLeaf()) {
unsigned split_index = tree[pid].SplitIndex();
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
pid = predictor::GetNextNode<true, true>(
tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index),
cats);
gstats[pid].Add(gpair[ridx]);
}
}