Require leaf statistics when expanding tree (#4015)

* Cache left and right gradient sums

* Require leaf statistics when expanding tree
This commit is contained in:
Rory Mitchell
2019-01-18 07:12:20 +02:00
committed by Philip Hyunsu Cho
parent 0f8af85f64
commit 1fc37e4749
11 changed files with 143 additions and 85 deletions

View File

@@ -303,14 +303,22 @@ class RegTree {
}
/**
* \brief Expands a leaf node into two additional leaf nodes
* \brief Expands a leaf node into two additional leaf nodes.
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
* \param base_weight The base weight, before learning rate.
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
*/
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess) {
int pleft = this->AllocNode();
int pright = this->AllocNode();
auto &node = nodes_[nid];
@@ -322,8 +330,12 @@ class RegTree {
node.SetSplit(split_index, split_value,
default_left);
// mark right child as 0, to indicate fresh leaf
nodes_[pleft].SetLeaf(0.0f, 0);
nodes_[pright].SetLeaf(0.0f, 0);
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
nodes_[pright].SetLeaf(right_leaf_weight, 0);
this->Stat(nid).loss_chg = loss_change;
this->Stat(nid).base_weight = base_weight;
this->Stat(nid).sum_hess = sum_hess;
}
/*!