Require leaf statistics when expanding tree (#4015)
* Cache left and right gradient sums * Require leaf statistics when expanding tree
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
0f8af85f64
commit
1fc37e4749
@@ -1182,42 +1182,35 @@ class GPUHistMakerSpecialised{
|
||||
}
|
||||
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
// Add new leaves
|
||||
RegTree& tree = *p_tree;
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
|
||||
candidate.split.dir == kLeftDir);
|
||||
auto& parent = tree[candidate.nid];
|
||||
tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
|
||||
|
||||
// Set up child constraints
|
||||
node_value_constraints_.resize(tree.GetNodes().size());
|
||||
GradStats left_stats(param_);
|
||||
left_stats.Add(candidate.split.left_sum);
|
||||
GradStats right_stats(param_);
|
||||
right_stats.Add(candidate.split.right_sum);
|
||||
node_value_constraints_[candidate.nid].SetChild(
|
||||
param_, parent.SplitIndex(), left_stats, right_stats,
|
||||
&node_value_constraints_[parent.LeftChild()],
|
||||
&node_value_constraints_[parent.RightChild()]);
|
||||
|
||||
// Configure left child
|
||||
GradStats parent_sum(param_);
|
||||
parent_sum.Add(left_stats);
|
||||
parent_sum.Add(right_stats);
|
||||
node_value_constraints_.resize(tree.GetNodes().size());
|
||||
auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
|
||||
auto left_weight =
|
||||
node_value_constraints_[parent.LeftChild()].CalcWeight(param_, left_stats);
|
||||
tree[parent.LeftChild()].SetLeaf(left_weight * param_.learning_rate, 0);
|
||||
tree.Stat(parent.LeftChild()).base_weight = left_weight;
|
||||
tree.Stat(parent.LeftChild()).sum_hess = candidate.split.left_sum.GetHess();
|
||||
|
||||
// Configure right child
|
||||
node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
|
||||
auto right_weight =
|
||||
node_value_constraints_[parent.RightChild()].CalcWeight(param_, right_stats);
|
||||
tree[parent.RightChild()].SetLeaf(right_weight * param_.learning_rate, 0);
|
||||
tree.Stat(parent.RightChild()).base_weight = right_weight;
|
||||
tree.Stat(parent.RightChild()).sum_hess = candidate.split.right_sum.GetHess();
|
||||
node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
|
||||
tree.ExpandNode(candidate.nid, candidate.split.findex,
|
||||
candidate.split.fvalue, candidate.split.dir == kLeftDir,
|
||||
base_weight, left_weight, right_weight,
|
||||
candidate.split.loss_chg, parent_sum.sum_hess);
|
||||
// Set up child constraints
|
||||
node_value_constraints_.resize(tree.GetNodes().size());
|
||||
node_value_constraints_[candidate.nid].SetChild(
|
||||
param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
|
||||
&node_value_constraints_[tree[candidate.nid].LeftChild()],
|
||||
&node_value_constraints_[tree[candidate.nid].RightChild()]);
|
||||
|
||||
// Store sum gradients
|
||||
for (auto& shard : shards_) {
|
||||
shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
|
||||
shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
|
||||
shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
|
||||
shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user