Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
Jiaming Yuan
2020-02-23 06:27:03 +08:00
committed by GitHub
parent 70a91ec3ba
commit 655cf17b60
19 changed files with 320 additions and 134 deletions

View File

@@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
bst_float new_value = this->node_mean_values_[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
@@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
unsigned hot_index = 0;
if (feat.IsMissing(split_index)) {
hot_index = node.DefaultChild();
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
hot_index = node.LeftChild();
} else {
hot_index = node.RightChild();

View File

@@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
[=] __device__(bst_uint ridx) {
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetElement(ridx, split_node.SplitIndex());
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
// Missing value
int new_position = 0;
if (isnan(cut_value)) {
@@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
auto node = d_nodes[position];
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();

View File

@@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
// tranverse tree
while (!tree[pid].IsLeaf()) {
unsigned split_index = tree[pid].SplitIndex();
pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
gstats[pid].Add(gpair[ridx]);
}
}