Requires setting leaf stat when expanding tree. (#5501)
* Fix GPU Hist feature importance.
This commit is contained in:
@@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
||||
return new GraphvizGenerator(fmap, attrs, with_stats);
|
||||
});
|
||||
|
||||
constexpr bst_node_t RegTree::kRoot;
|
||||
|
||||
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const {
|
||||
@@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const {
|
||||
if (NumExtraNodes() != b.NumExtraNodes()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::stack<bst_node_t> nodes;
|
||||
nodes.push(0);
|
||||
auto& self = *this;
|
||||
while (!nodes.empty()) {
|
||||
auto nid = nodes.top();
|
||||
nodes.pop();
|
||||
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
|
||||
auto const& self = *this;
|
||||
bool ret { true };
|
||||
this->WalkTree([&self, &b, &ret](bst_node_t nidx) {
|
||||
if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) {
|
||||
ret = false;
|
||||
return false;
|
||||
}
|
||||
auto left = self[nid].LeftChild();
|
||||
auto right = self[nid].RightChild();
|
||||
if (left != RegTree::kInvalidNodeId) {
|
||||
nodes.push(left);
|
||||
}
|
||||
if (right != RegTree::kInvalidNodeId) {
|
||||
nodes.push(right);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return true;
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
bst_node_t RegTree::GetNumLeaves() const {
|
||||
bst_node_t leaves { 0 };
|
||||
auto const& self = *this;
|
||||
this->WalkTree([&leaves, &self](bst_node_t nidx) {
|
||||
if (self[nidx].IsLeaf()) {
|
||||
leaves++;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return leaves;
|
||||
}
|
||||
|
||||
bst_node_t RegTree::GetNumSplitNodes() const {
|
||||
bst_node_t splits { 0 };
|
||||
auto const& self = *this;
|
||||
this->WalkTree([&splits, &self](bst_node_t nidx) {
|
||||
if (!self[nidx].IsLeaf()) {
|
||||
splits++;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return splits;
|
||||
}
|
||||
|
||||
void RegTree::Load(dmlc::Stream* fi) {
|
||||
|
||||
@@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater {
|
||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||
right_leaf_weight, e.best.loss_chg,
|
||||
e.stats.sum_hess, 0);
|
||||
e.stats.sum_hess,
|
||||
e.best.left_sum.GetHess(), e.best.right_sum.GetHess(),
|
||||
0);
|
||||
} else {
|
||||
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
||||
}
|
||||
|
||||
@@ -814,7 +814,8 @@ struct GPUHistMakerDevice {
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.sum_hess);
|
||||
candidate.split.loss_chg, parent_sum.sum_hess,
|
||||
left_stats.GetHess(), right_stats.GetHess());
|
||||
// Set up child constraints
|
||||
node_value_constraints.resize(tree.GetNodes().size());
|
||||
node_value_constraints[candidate.nid].SetChild(
|
||||
|
||||
@@ -249,7 +249,8 @@ class HistMaker: public BaseMaker {
|
||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||
best.DefaultLeft(), base_weight, left_leaf_weight,
|
||||
right_leaf_weight, best.loss_chg,
|
||||
node_sum.sum_hess);
|
||||
node_sum.sum_hess,
|
||||
best.left_sum.GetHess(), best.right_sum.GetHess());
|
||||
GradStats right_sum;
|
||||
right_sum.SetSubstract(node_sum, left_sum[wid]);
|
||||
auto left_child = (*p_tree)[nid].LeftChild();
|
||||
|
||||
@@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
|
||||
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
|
||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
|
||||
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
|
||||
|
||||
int left_id = (*p_tree)[nid].LeftChild();
|
||||
int right_id = (*p_tree)[nid].RightChild();
|
||||
@@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
|
||||
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
|
||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||
e.best.DefaultLeft(), e.weight, left_leaf_weight,
|
||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
|
||||
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
|
||||
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
|
||||
|
||||
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);
|
||||
|
||||
|
||||
@@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker {
|
||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||
best.DefaultLeft(), base_weight, left_leaf_weight,
|
||||
right_leaf_weight, best.loss_chg,
|
||||
node_stats_[nid].sum_hess);
|
||||
node_stats_[nid].sum_hess,
|
||||
best.left_sum.GetHess(), best.right_sum.GetHess());
|
||||
} else {
|
||||
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user