Reduce tree expand boilerplate code (#4008)

This commit is contained in:
Rory Mitchell 2018-12-20 15:52:28 +13:00 committed by GitHub
parent 84c99f86f4
commit f75a21af25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 34 additions and 43 deletions

View File

@ -301,18 +301,31 @@ class RegTree {
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
}
/*!
* \brief add child nodes to node
* \param nid node id to add children to
/**
* \brief Expands a leaf node into two additional leaf nodes
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
*/
void AddChilds(int nid) {
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
nodes_[nid].SetLeftChild(pleft);
nodes_[nid].SetRightChild(pright);
nodes_[nodes_[nid].LeftChild() ].SetParent(nid, true);
nodes_[nodes_[nid].RightChild()].SetParent(nid, false);
auto &node = nodes_[nid];
CHECK(node.IsLeaf());
node.SetLeftChild(pleft);
node.SetRightChild(pright);
nodes_[node.LeftChild()].SetParent(nid, true);
nodes_[node.RightChild()].SetParent(nid, false);
node.SetSplit(split_index, split_value,
default_left);
// mark right child as 0, to indicate fresh leaf
nodes_[pleft].SetLeaf(0.0f, 0);
nodes_[pright].SetLeaf(0.0f, 0);
}
/*!
* \brief get current depth
* \param nid node id

View File

@ -637,11 +637,7 @@ class ColMaker: public TreeUpdater {
NodeEntry &e = snode_[nid];
// now we know the solution in snode[nid], set split
if (e.best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
} else {
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
}

View File

@ -296,13 +296,10 @@ inline void Dense2SparseTree(RegTree* p_tree,
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
const DeviceNodeStats& n = h_nodes[gpu_nid];
if (!n.IsUnused() && !n.IsLeaf()) {
tree.AddChilds(nid);
tree[nid].SetSplit(n.fidx, n.fvalue, n.dir == kLeftDir);
tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
tree.Stat(nid).loss_chg = n.root_gain;
tree.Stat(nid).base_weight = n.weight;
tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
tree[tree[nid].LeftChild()].SetLeaf(0);
tree[tree[nid].RightChild()].SetLeaf(0);
nid++;
} else if (n.IsLeaf()) {
tree[nid].SetLeaf(n.weight * param.learning_rate);

View File

@ -1184,10 +1184,9 @@ class GPUHistMakerSpecialised{
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
// Add new leaves
RegTree& tree = *p_tree;
tree.AddChilds(candidate.nid);
auto& parent = tree[candidate.nid];
parent.SetSplit(candidate.split.findex, candidate.split.fvalue,
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir);
auto& parent = tree[candidate.nid];
tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
// Set up child constraints

View File

@ -243,12 +243,8 @@ class HistMaker: public BaseMaker {
p_tree->Stat(nid).loss_chg = best.loss_chg;
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(best.SplitIndex(),
best.split_value, best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
// right side sum
TStats right_sum;
right_sum.SetSubstract(node_sum, left_sum[wid]);

View File

@ -429,14 +429,8 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,
/* 1. Create child nodes */
NodeEntry& e = snode_[nid];
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
int cleft = (*p_tree)[nid].LeftChild();
int cright = (*p_tree)[nid].RightChild();
(*p_tree)[cleft].SetLeaf(0.0f, 0);
(*p_tree)[cright].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
e.best.DefaultLeft());
/* 2. Categorize member rows */
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);

View File

@ -285,12 +285,8 @@ class SketchMaker: public BaseMaker {
this->SetStats(nid, node_stats_[nid], p_tree);
// now we know the solution in snode[nid], set split
if (best.loss_chg > kRtEps) {
p_tree->AddChilds(nid);
(*p_tree)[nid].SetSplit(best.SplitIndex(),
best.split_value, best.DefaultLeft());
// mark right child as 0, to indicate fresh leaf
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
best.DefaultLeft());
} else {
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
}

View File

@ -38,7 +38,7 @@ TEST(Updater, Prune) {
pruner->Init(cfg);
// loss_chg < min_split_loss;
tree.AddChilds(0);
tree.ExpandNode(0, 0, 0, true);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.3f, 0);
@ -48,7 +48,7 @@ TEST(Updater, Prune) {
ASSERT_EQ(tree.NumExtraNodes(), 0);
// loss_chg > min_split_loss;
tree.AddChilds(0);
tree.ExpandNode(0, 0, 0, true);
cleft = tree[0].LeftChild();
cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.3f, 0);

View File

@ -29,7 +29,7 @@ TEST(Updater, Refresh) {
std::vector<RegTree*> trees {&tree};
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
tree.AddChilds(0);
tree.ExpandNode(0, 0, 0, true);
int cleft = tree[0].LeftChild();
int cright = tree[0].RightChild();
tree[cleft].SetLeaf(0.2f, 0);