Require leaf statistics when expanding tree (#4015)
* Cache left and right gradient sums * Require leaf statistics when expanding tree
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
0f8af85f64
commit
1fc37e4749
@@ -303,14 +303,22 @@ class RegTree {
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Expands a leaf node into two additional leaf nodes
|
||||
* \brief Expands a leaf node into two additional leaf nodes.
|
||||
*
|
||||
* \param nid The node index to expand.
|
||||
* \param split_index Feature index of the split.
|
||||
* \param split_value The split condition.
|
||||
* \param default_left True to default left.
|
||||
* \param nid The node index to expand.
|
||||
* \param split_index Feature index of the split.
|
||||
* \param split_value The split condition.
|
||||
* \param default_left True to default left.
|
||||
* \param base_weight The base weight, before learning rate.
|
||||
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
|
||||
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
|
||||
* \param loss_change The loss change.
|
||||
* \param sum_hess The sum hess.
|
||||
*/
|
||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
|
||||
void ExpandNode(int nid, unsigned split_index, bst_float split_value,
|
||||
bool default_left, bst_float base_weight,
|
||||
bst_float left_leaf_weight, bst_float right_leaf_weight,
|
||||
bst_float loss_change, float sum_hess) {
|
||||
int pleft = this->AllocNode();
|
||||
int pright = this->AllocNode();
|
||||
auto &node = nodes_[nid];
|
||||
@@ -322,8 +330,12 @@ class RegTree {
|
||||
node.SetSplit(split_index, split_value,
|
||||
default_left);
|
||||
// mark right child as 0, to indicate fresh leaf
|
||||
nodes_[pleft].SetLeaf(0.0f, 0);
|
||||
nodes_[pright].SetLeaf(0.0f, 0);
|
||||
nodes_[pleft].SetLeaf(left_leaf_weight, 0);
|
||||
nodes_[pright].SetLeaf(right_leaf_weight, 0);
|
||||
|
||||
this->Stat(nid).loss_chg = loss_change;
|
||||
this->Stat(nid).base_weight = base_weight;
|
||||
this->Stat(nid).sum_hess = sum_hess;
|
||||
}
|
||||
|
||||
/*!
|
||||
|
||||
Reference in New Issue
Block a user