Predict on Ellpack. (#5327)
* Unify GPU prediction node. * Add `PageExists`. * Dispatch prediction on input data for GPU Predictor.
This commit is contained in:
@@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
|
||||
bst_node_t nid = 0;
|
||||
while (!(*this)[nid].IsLeaf()) {
|
||||
split_index = (*this)[nid].SplitIndex();
|
||||
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
bst_float new_value = this->node_mean_values_[nid];
|
||||
// update feature weight
|
||||
out_contribs[split_index] += new_value - node_value;
|
||||
@@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
|
||||
unsigned hot_index = 0;
|
||||
if (feat.IsMissing(split_index)) {
|
||||
hot_index = node.DefaultChild();
|
||||
} else if (feat.Fvalue(split_index) < node.SplitCond()) {
|
||||
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
|
||||
hot_index = node.LeftChild();
|
||||
} else {
|
||||
hot_index = node.RightChild();
|
||||
|
||||
@@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
|
||||
[=] __device__(bst_uint ridx) {
|
||||
// given a row index, returns the node id it belongs to
|
||||
bst_float cut_value =
|
||||
d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
if (isnan(cut_value)) {
|
||||
@@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
|
||||
auto node = d_nodes[position];
|
||||
|
||||
while (!node.IsLeaf()) {
|
||||
bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
|
||||
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(element)) {
|
||||
position = node.DefaultChild();
|
||||
|
||||
@@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
|
||||
// tranverse tree
|
||||
while (!tree[pid].IsLeaf()) {
|
||||
unsigned split_index = tree[pid].SplitIndex();
|
||||
pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
|
||||
pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
|
||||
gstats[pid].Add(gpair[ridx]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user