Requires setting leaf stat when expanding tree. (#5501)

* Fix GPU Hist feature importance.
This commit is contained in:
Jiaming Yuan
2020-04-10 12:27:03 +08:00
committed by GitHub
parent dc2950fd90
commit 7d52c0b8c2
11 changed files with 179 additions and 50 deletions

View File

@@ -607,6 +607,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
return new GraphvizGenerator(fmap, attrs, with_stats);
});
constexpr bst_node_t RegTree::kRoot;
std::string RegTree::DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const {
@@ -623,26 +625,40 @@ bool RegTree::Equal(const RegTree& b) const {
if (NumExtraNodes() != b.NumExtraNodes()) {
return false;
}
std::stack<bst_node_t> nodes;
nodes.push(0);
auto& self = *this;
while (!nodes.empty()) {
auto nid = nodes.top();
nodes.pop();
if (!(self.nodes_.at(nid) == b.nodes_.at(nid))) {
auto const& self = *this;
bool ret { true };
this->WalkTree([&self, &b, &ret](bst_node_t nidx) {
if (!(self.nodes_.at(nidx) == b.nodes_.at(nidx))) {
ret = false;
return false;
}
auto left = self[nid].LeftChild();
auto right = self[nid].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
return true;
return true;
});
return ret;
}
bst_node_t RegTree::GetNumLeaves() const {
bst_node_t leaves { 0 };
auto const& self = *this;
this->WalkTree([&leaves, &self](bst_node_t nidx) {
if (self[nidx].IsLeaf()) {
leaves++;
}
return true;
});
return leaves;
}
bst_node_t RegTree::GetNumSplitNodes() const {
bst_node_t splits { 0 };
auto const& self = *this;
this->WalkTree([&splits, &self](bst_node_t nidx) {
if (!self[nidx].IsLeaf()) {
splits++;
}
return true;
});
return splits;
}
void RegTree::Load(dmlc::Stream* fi) {

View File

@@ -499,7 +499,9 @@ class ColMaker: public TreeUpdater {
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg,
e.stats.sum_hess, 0);
e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess(),
0);
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}

View File

@@ -814,7 +814,8 @@ struct GPUHistMakerDevice {
tree.ExpandNode(candidate.nid, candidate.split.findex,
candidate.split.fvalue, candidate.split.dir == kLeftDir,
base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.sum_hess);
candidate.split.loss_chg, parent_sum.sum_hess,
left_stats.GetHess(), right_stats.GetHess());
// Set up child constraints
node_value_constraints.resize(tree.GetNodes().size());
node_value_constraints[candidate.nid].SetChild(

View File

@@ -249,7 +249,8 @@ class HistMaker: public BaseMaker {
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_sum.sum_hess);
node_sum.sum_hess,
best.left_sum.GetHess(), best.right_sum.GetHess());
GradStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);
auto left_child = (*p_tree)[nid].LeftChild();

View File

@@ -263,7 +263,8 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
int left_id = (*p_tree)[nid].LeftChild();
int right_id = (*p_tree)[nid].RightChild();
@@ -410,7 +411,8 @@ void QuantileHistMaker::Builder::ExpandWithLossGuide(
spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft(), e.weight, left_leaf_weight,
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);
right_leaf_weight, e.best.loss_chg, e.stats.sum_hess,
e.best.left_sum.GetHess(), e.best.right_sum.GetHess());
this->ApplySplit({candidate}, gmat, column_matrix, hist_, p_tree);

View File

@@ -289,7 +289,8 @@ class SketchMaker: public BaseMaker {
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft(), base_weight, left_leaf_weight,
right_leaf_weight, best.loss_chg,
node_stats_[nid].sum_hess);
node_stats_[nid].sum_hess,
best.left_sum.GetHess(), best.right_sum.GetHess());
} else {
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
}