Reduce tree expand boilerplate code (#4008)
This commit is contained in:
parent
84c99f86f4
commit
f75a21af25
@ -301,18 +301,31 @@ class RegTree {
|
||||
fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
|
||||
fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
|
||||
}
|
||||
/*!
|
||||
* \brief add child nodes to node
|
||||
* \param nid node id to add children to
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes
|
||||
*
|
||||
* \param nid The node index to expand.
|
||||
* \param split_index Feature index of the split.
|
||||
* \param split_value The split condition.
|
||||
* \param default_left True to default left.
|
||||
*/
|
||||
void AddChilds(int nid) {
|
||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
nodes_[nid].SetLeftChild(pleft);
|
||||
nodes_[nid].SetRightChild(pright);
|
||||
nodes_[nodes_[nid].LeftChild() ].SetParent(nid, true);
|
||||
nodes_[nodes_[nid].RightChild()].SetParent(nid, false);
|
||||
auto &node = nodes_[nid];
|
||||
CHECK(node.IsLeaf());
|
||||
node.SetLeftChild(pleft);
|
||||
node.SetRightChild(pright);
|
||||
nodes_[node.LeftChild()].SetParent(nid, true);
|
||||
nodes_[node.RightChild()].SetParent(nid, false);
|
||||
node.SetSplit(split_index, split_value,
|
||||
default_left);
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
nodes_[pleft].SetLeaf(0.0f, 0);
|
||||
nodes_[pright].SetLeaf(0.0f, 0);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief get current depth
|
||||
* \param nid node id
|
||||
|
||||
@ -637,11 +637,7 @@ class ColMaker: public TreeUpdater {
|
||||
NodeEntry &e = snode_[nid];
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (e.best.loss_chg > kRtEps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
|
||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
|
||||
} else {
|
||||
(*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
|
||||
}
|
||||
|
||||
@ -296,13 +296,10 @@ inline void Dense2SparseTree(RegTree* p_tree,
|
||||
for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
|
||||
const DeviceNodeStats& n = h_nodes[gpu_nid];
|
||||
if (!n.IsUnused() && !n.IsLeaf()) {
|
||||
tree.AddChilds(nid);
|
||||
tree[nid].SetSplit(n.fidx, n.fvalue, n.dir == kLeftDir);
|
||||
tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
|
||||
tree.Stat(nid).loss_chg = n.root_gain;
|
||||
tree.Stat(nid).base_weight = n.weight;
|
||||
tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
|
||||
tree[tree[nid].LeftChild()].SetLeaf(0);
|
||||
tree[tree[nid].RightChild()].SetLeaf(0);
|
||||
nid++;
|
||||
} else if (n.IsLeaf()) {
|
||||
tree[nid].SetLeaf(n.weight * param.learning_rate);
|
||||
|
||||
@ -1184,10 +1184,9 @@ class GPUHistMakerSpecialised{
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
// Add new leaves
|
||||
RegTree& tree = *p_tree;
|
||||
tree.AddChilds(candidate.nid);
|
||||
auto& parent = tree[candidate.nid];
|
||||
parent.SetSplit(candidate.split.findex, candidate.split.fvalue,
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
|
||||
candidate.split.dir == kLeftDir);
|
||||
auto& parent = tree[candidate.nid];
|
||||
tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
|
||||
|
||||
// Set up child constraints
|
||||
|
||||
@ -243,12 +243,8 @@ class HistMaker: public BaseMaker {
|
||||
p_tree->Stat(nid).loss_chg = best.loss_chg;
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (best.loss_chg > kRtEps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].SetSplit(best.SplitIndex(),
|
||||
best.split_value, best.DefaultLeft());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
|
||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||
best.DefaultLeft());
|
||||
// right side sum
|
||||
TStats right_sum;
|
||||
right_sum.SetSubstract(node_sum, left_sum[wid]);
|
||||
|
||||
@ -429,14 +429,8 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,
|
||||
|
||||
/* 1. Create child nodes */
|
||||
NodeEntry& e = snode_[nid];
|
||||
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].SetSplit(e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
int cleft = (*p_tree)[nid].LeftChild();
|
||||
int cright = (*p_tree)[nid].RightChild();
|
||||
(*p_tree)[cleft].SetLeaf(0.0f, 0);
|
||||
(*p_tree)[cright].SetLeaf(0.0f, 0);
|
||||
p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
|
||||
e.best.DefaultLeft());
|
||||
|
||||
/* 2. Categorize member rows */
|
||||
const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
|
||||
|
||||
@ -285,12 +285,8 @@ class SketchMaker: public BaseMaker {
|
||||
this->SetStats(nid, node_stats_[nid], p_tree);
|
||||
// now we know the solution in snode[nid], set split
|
||||
if (best.loss_chg > kRtEps) {
|
||||
p_tree->AddChilds(nid);
|
||||
(*p_tree)[nid].SetSplit(best.SplitIndex(),
|
||||
best.split_value, best.DefaultLeft());
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
(*p_tree)[(*p_tree)[nid].LeftChild()].SetLeaf(0.0f, 0);
|
||||
(*p_tree)[(*p_tree)[nid].RightChild()].SetLeaf(0.0f, 0);
|
||||
p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
|
||||
best.DefaultLeft());
|
||||
} else {
|
||||
(*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ TEST(Updater, Prune) {
|
||||
pruner->Init(cfg);
|
||||
|
||||
// loss_chg < min_split_loss;
|
||||
tree.AddChilds(0);
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
int cleft = tree[0].LeftChild();
|
||||
int cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.3f, 0);
|
||||
@ -48,7 +48,7 @@ TEST(Updater, Prune) {
|
||||
ASSERT_EQ(tree.NumExtraNodes(), 0);
|
||||
|
||||
// loss_chg > min_split_loss;
|
||||
tree.AddChilds(0);
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
cleft = tree[0].LeftChild();
|
||||
cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.3f, 0);
|
||||
|
||||
@ -29,7 +29,7 @@ TEST(Updater, Refresh) {
|
||||
std::vector<RegTree*> trees {&tree};
|
||||
std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));
|
||||
|
||||
tree.AddChilds(0);
|
||||
tree.ExpandNode(0, 0, 0, true);
|
||||
int cleft = tree[0].LeftChild();
|
||||
int cright = tree[0].RightChild();
|
||||
tree[cleft].SetLeaf(0.2f, 0);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user