Reduce tree expand boilerplate code (#4008)

This commit is contained in:
Rory Mitchell
2018-12-20 15:52:28 +13:00
committed by GitHub
parent 84c99f86f4
commit f75a21af25
9 changed files with 34 additions and 43 deletions

View File

@@ -637,11 +637,7 @@ class ColMaker: public TreeUpdater {
NodeEntry &e = snode_[nid];
// now we know the solution in snode[nid], set split
if (e.best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}

View File

@@ -296,13 +296,10 @@ inline void Dense2SparseTree(RegTree* p_tree,
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
const DeviceNodeStats& n = h_nodes[gpu_nid];
if (!n.IsUnused() && !n.IsLeaf()) {
tree.AddChilds(nid);
tree[nid].SetSplit(n.fidx, n.fvalue, n.dir == kLeftDir);
tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
tree.Stat(nid).loss_chg = n.root_gain;
tree.Stat(nid).base_weight = n.weight;
tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
tree[tree[nid].LeftChild()].SetLeaf(0);
tree[tree[nid].RightChild()].SetLeaf(0);
nid++;
} else if (n.IsLeaf()) {
tree[nid].SetLeaf(n.weight * param.learning_rate);

View File

@@ -1184,10 +1184,9 @@ class GPUHistMakerSpecialised{
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
// Add new leaves
RegTree& tree = *p_tree;
tree.AddChilds(candidate.nid);
auto& parent = tree[candidate.nid];
parent.SetSplit(candidate.split.findex, candidate.split.fvalue,
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir);
auto& parent = tree[candidate.nid];
tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
// Set up child constraints

View File

@@ -243,12 +243,8 @@ class HistMaker: public BaseMaker {
p_tree->Stat(nid).loss_chg = best.loss_chg;
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(best.SplitIndex(),
best.split_value, best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
// right side sum
TStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);

View File

@@ -429,14 +429,8 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,
/* 1. Create child nodes */
NodeEntry& e = snode_[nid];
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
int cleft = (*p_tree)[nid].LeftChild();
int cright = (*p_tree)[nid].RightChild();
(*p_tree)[cleft].SetLeaf(0.0f, 0);
(*p_tree)[cright].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft());
/* 2. Categorize member rows */
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);

View File

@@ -285,12 +285,8 @@ class SketchMaker: public BaseMaker {
this->SetStats(nid, node_stats_[nid], p_tree);
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(best.SplitIndex(),
best.split_value, best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
} else {
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
}