Require leaf statistics when expanding tree (#4015)
* Cache left and right gradient sums
* Require leaf statistics when expanding tree

parent 0f8af85f64
commit 1fc37e4749
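The core change below is to RegTree::ExpandNode: instead of creating the two children as zero-weight leaves and relying on each updater to patch the node statistics afterwards, the method now requires the leaf statistics at expansion time. A minimal sketch of the old and new call shapes (variable names follow the diff; the values themselves are illustrative, not taken from any one updater):

    // Before: children start as 0.0f leaves and stats are patched after the fact.
    tree.ExpandNode(nid, split_index, split_value, default_left);
    tree.Stat(nid).loss_chg = loss_change;
    tree[tree[nid].LeftChild()].SetLeaf(left_leaf_weight, 0);
    tree[tree[nid].RightChild()].SetLeaf(right_leaf_weight, 0);

    // After: the statistics are required up front.
    tree.ExpandNode(nid, split_index, split_value, default_left,
                    base_weight,        // node weight, before learning rate
                    left_leaf_weight,   // already scaled by the learning rate
                    right_leaf_weight,  // already scaled by the learning rate
                    loss_change, sum_hess);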
@@ -303,14 +303,22 @@ class RegTree {
   }

   /**
-   * \brief Expands a leaf node into two additional leaf nodes
+   * \brief Expands a leaf node into two additional leaf nodes.
    *
    * \param nid The node index to expand.
    * \param split_index Feature index of the split.
    * \param split_value The split condition.
    * \param default_left True to default left.
+   * \param base_weight The base weight, before learning rate.
+   * \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
+   * \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
+   * \param loss_change The loss change.
+   * \param sum_hess The sum hess.
    */
-  void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left) {
+  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
+                  bool default_left, bst_float base_weight,
+                  bst_float left_leaf_weight, bst_float right_leaf_weight,
+                  bst_float loss_change, float sum_hess) {
     int pleft = this->AllocNode();
     int pright = this->AllocNode();
     auto &node = nodes_[nid];
@@ -322,8 +330,12 @@ class RegTree {
     node.SetSplit(split_index, split_value,
                   default_left);
     // mark right child as 0, to indicate fresh leaf
-    nodes_[pleft].SetLeaf(0.0f, 0);
-    nodes_[pright].SetLeaf(0.0f, 0);
+    nodes_[pleft].SetLeaf(left_leaf_weight, 0);
+    nodes_[pright].SetLeaf(right_leaf_weight, 0);
+
+    this->Stat(nid).loss_chg = loss_change;
+    this->Stat(nid).base_weight = base_weight;
+    this->Stat(nid).sum_hess = sum_hess;
   }

   /*!
@@ -354,6 +354,8 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   static const int kSimpleStats = 1;
   /*! \brief constructor, the object must be cleared during construction */
   explicit GradStats(const TrainParam& param) { this->Clear(); }
+  explicit GradStats(double sum_grad, double sum_hess)
+      : sum_grad(sum_grad), sum_hess(sum_hess) {}

   template <typename GpairT>
   XGBOOST_DEVICE explicit GradStats(const GpairT &sum)
@@ -490,8 +492,10 @@ struct SplitEntry {
   bst_float loss_chg{0.0f};
   /*! \brief split index */
   unsigned sindex{0};
-  /*! \brief split value */
   bst_float split_value{0.0f};
+  GradStats left_sum;
+  GradStats right_sum;
+
   /*! \brief constructor */
   SplitEntry() = default;
   /*!
@@ -521,6 +525,8 @@ struct SplitEntry {
       this->loss_chg = e.loss_chg;
       this->sindex = e.sindex;
       this->split_value = e.split_value;
+      this->left_sum = e.left_sum;
+      this->right_sum = e.right_sum;
       return true;
     } else {
       return false;
@@ -535,7 +541,8 @@ struct SplitEntry {
    * \return whether the proposed split is better and can replace current split
    */
   inline bool Update(bst_float new_loss_chg, unsigned split_index,
-                     bst_float new_split_value, bool default_left) {
+                     bst_float new_split_value, bool default_left,
+                     const GradStats &left_sum, const GradStats &right_sum) {
     if (this->NeedReplace(new_loss_chg, split_index)) {
       this->loss_chg = new_loss_chg;
       if (default_left) {
@@ -543,6 +550,8 @@ struct SplitEntry {
       }
       this->sindex = split_index;
       this->split_value = new_split_value;
+      this->left_sum = left_sum;
+      this->right_sum = right_sum;
       return true;
     } else {
       return false;
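The SplitEntry and GradStats changes above support this: Update now receives the left and right gradient sums and caches them on the entry, so the tree updaters below can read best.left_sum and best.right_sum to compute leaf weights when they expand the node. A sketch under the new signature (the GradStats values are illustrative, using the two-argument constructor added above):

    xgboost::tree::GradStats left_sum(/*sum_grad=*/-4.0, /*sum_hess=*/8.0);
    xgboost::tree::GradStats right_sum(/*sum_grad=*/2.0, /*sum_hess=*/6.0);
    xgboost::tree::SplitEntry best;
    best.Update(loss_chg, split_index, split_value, /*default_left=*/true,
                left_sum, right_sum);
    // best.left_sum and best.right_sum now feed ComputeWeight at expansion time.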
@@ -311,7 +311,7 @@ class ColMaker: public TreeUpdater {
             auto loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
-            e.best.Update(loss_chg, fid, fsplit, false);
+            e.best.Update(loss_chg, fid, fsplit, false, e.stats, c);
           }
         }
         if (need_backward) {
@@ -322,7 +322,7 @@ class ColMaker: public TreeUpdater {
             auto loss_chg = static_cast<bst_float>(
                 spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
                 snode_[nid].root_gain);
-            e.best.Update(loss_chg, fid, fsplit, true);
+            e.best.Update(loss_chg, fid, fsplit, true, tmp, c);
           }
         }
       }
@@ -335,7 +335,7 @@ class ColMaker: public TreeUpdater {
           auto loss_chg = static_cast<bst_float>(
               spliteval_->ComputeSplitScore(nid, fid, tmp, c) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true);
+          e.best.Update(loss_chg, fid, e.last_fvalue + kRtEps, true, tmp, c);
         }
       }
     }
@@ -368,7 +368,7 @@ class ColMaker: public TreeUpdater {
                 spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                 snode_[nid].root_gain);
             e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f,
-                          false);
+                          false, e.stats, c);
           }
         }
         if (need_backward) {
@@ -379,7 +379,7 @@ class ColMaker: public TreeUpdater {
            auto loss_chg = static_cast<bst_float>(
                spliteval_->ComputeSplitScore(nid, fid, c, cright) -
                snode_[nid].root_gain);
-           e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true);
+           e.best.Update(loss_chg, fid, (fvalue + e.first_fvalue) * 0.5f, true, c, cright);
          }
        }
      }
@@ -410,13 +410,15 @@ class ColMaker: public TreeUpdater {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                  snode_[nid].root_gain);
+              e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                            d_step == -1, c, e.stats);
            } else {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                  snode_[nid].root_gain);
-           }
              e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
-                          d_step == -1);
+                            d_step == -1, e.stats, c);
+           }
          }
        }
        // update the statistics
@@ -486,18 +488,21 @@ class ColMaker: public TreeUpdater {
          if (e.stats.sum_hess >= param_.min_child_weight &&
              c.sum_hess >= param_.min_child_weight) {
            bst_float loss_chg;
+           const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
+           const bst_float delta = d_step == +1 ? gap: -gap;
            if (d_step == -1) {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
+                           e.stats);
            } else {
              loss_chg = static_cast<bst_float>(
                  spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                  snode_[nid].root_gain);
+             e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
+                           e.stats, c);
            }
-           const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
-           const bst_float delta = d_step == +1 ? gap: -gap;
-           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
          }
        }
      }
@@ -545,12 +550,15 @@ class ColMaker: public TreeUpdater {
            loss_chg = static_cast<bst_float>(
                spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
                snode_[nid].root_gain);
+           e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                         d_step == -1, c, e.stats);
          } else {
            loss_chg = static_cast<bst_float>(
                spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
                snode_[nid].root_gain);
+           e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f,
+                         d_step == -1, e.stats, c);
          }
-         e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, d_step == -1);
        }
      }
      // update the statistics
@@ -565,18 +573,21 @@ class ColMaker: public TreeUpdater {
        if (e.stats.sum_hess >= param_.min_child_weight &&
            c.sum_hess >= param_.min_child_weight) {
          bst_float loss_chg;
+         GradStats left_sum;
+         GradStats right_sum;
          if (d_step == -1) {
-           loss_chg = static_cast<bst_float>(
-               spliteval_->ComputeSplitScore(nid, fid, c, e.stats) -
-               snode_[nid].root_gain);
+           left_sum = c;
+           right_sum = e.stats;
          } else {
-           loss_chg = static_cast<bst_float>(
-               spliteval_->ComputeSplitScore(nid, fid, e.stats, c) -
-               snode_[nid].root_gain);
+           left_sum = e.stats;
+           right_sum = c;
          }
+         loss_chg = static_cast<bst_float>(
+             spliteval_->ComputeSplitScore(nid, fid, left_sum, right_sum) -
+             snode_[nid].root_gain);
          const bst_float gap = std::abs(e.last_fvalue) + kRtEps;
          const bst_float delta = d_step == +1 ? gap: -gap;
-         e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1);
+         e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, left_sum, right_sum);
        }
      }
    }
@@ -637,7 +648,16 @@ class ColMaker: public TreeUpdater {
      NodeEntry &e = snode_[nid];
      // now we know the solution in snode[nid], set split
      if (e.best.loss_chg > kRtEps) {
-       p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value, e.best.DefaultLeft());
+       bst_float left_leaf_weight =
+           spliteval_->ComputeWeight(nid, e.best.left_sum) *
+           param_.learning_rate;
+       bst_float right_leaf_weight =
+           spliteval_->ComputeWeight(nid, e.best.right_sum) *
+           param_.learning_rate;
+       p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
+                          e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                          right_leaf_weight, e.best.loss_chg,
+                          e.stats.sum_hess);
      } else {
        (*p_tree)[nid].SetLeaf(e.weight * param_.learning_rate);
      }
@@ -296,7 +296,8 @@ inline void Dense2SparseTree(RegTree* p_tree,
   for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) {
     const DeviceNodeStats& n = h_nodes[gpu_nid];
     if (!n.IsUnused() && !n.IsLeaf()) {
-      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir);
+      tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir, n.weight, 0.0f,
+                      0.0f, n.root_gain, n.sum_gradients.GetHess());
       tree.Stat(nid).loss_chg = n.root_gain;
       tree.Stat(nid).base_weight = n.weight;
       tree.Stat(nid).sum_hess = n.sum_gradients.GetHess();
@@ -1182,42 +1182,35 @@ class GPUHistMakerSpecialised{
   }

   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
-    // Add new leaves
     RegTree& tree = *p_tree;
-    tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
-                    candidate.split.dir == kLeftDir);
-    auto& parent = tree[candidate.nid];
-    tree.Stat(candidate.nid).loss_chg = candidate.split.loss_chg;
-
-    // Set up child constraints
-    node_value_constraints_.resize(tree.GetNodes().size());
     GradStats left_stats(param_);
     left_stats.Add(candidate.split.left_sum);
     GradStats right_stats(param_);
     right_stats.Add(candidate.split.right_sum);
-    node_value_constraints_[candidate.nid].SetChild(
-        param_, parent.SplitIndex(), left_stats, right_stats,
-        &node_value_constraints_[parent.LeftChild()],
-        &node_value_constraints_[parent.RightChild()]);
-
-    // Configure left child
+    GradStats parent_sum(param_);
+    parent_sum.Add(left_stats);
+    parent_sum.Add(right_stats);
+    node_value_constraints_.resize(tree.GetNodes().size());
+    auto base_weight = node_value_constraints_[candidate.nid].CalcWeight(param_, parent_sum);
     auto left_weight =
-        node_value_constraints_[parent.LeftChild()].CalcWeight(param_, left_stats);
-    tree[parent.LeftChild()].SetLeaf(left_weight * param_.learning_rate, 0);
-    tree.Stat(parent.LeftChild()).base_weight = left_weight;
-    tree.Stat(parent.LeftChild()).sum_hess = candidate.split.left_sum.GetHess();
-
-    // Configure right child
+        node_value_constraints_[candidate.nid].CalcWeight(param_, left_stats)*param_.learning_rate;
     auto right_weight =
-        node_value_constraints_[parent.RightChild()].CalcWeight(param_, right_stats);
-    tree[parent.RightChild()].SetLeaf(right_weight * param_.learning_rate, 0);
-    tree.Stat(parent.RightChild()).base_weight = right_weight;
-    tree.Stat(parent.RightChild()).sum_hess = candidate.split.right_sum.GetHess();
+        node_value_constraints_[candidate.nid].CalcWeight(param_, right_stats)*param_.learning_rate;
+    tree.ExpandNode(candidate.nid, candidate.split.findex,
+                    candidate.split.fvalue, candidate.split.dir == kLeftDir,
+                    base_weight, left_weight, right_weight,
+                    candidate.split.loss_chg, parent_sum.sum_hess);
+    // Set up child constraints
+    node_value_constraints_.resize(tree.GetNodes().size());
+    node_value_constraints_[candidate.nid].SetChild(
+        param_, tree[candidate.nid].SplitIndex(), left_stats, right_stats,
+        &node_value_constraints_[tree[candidate.nid].LeftChild()],
+        &node_value_constraints_[tree[candidate.nid].RightChild()]);

     // Store sum gradients
     for (auto& shard : shards_) {
-      shard->node_sum_gradients[parent.LeftChild()] = candidate.split.left_sum;
-      shard->node_sum_gradients[parent.RightChild()] = candidate.split.right_sum;
+      shard->node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum;
+      shard->node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum;
     }
   }
@@ -192,7 +192,8 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
+                           false, s, c)) {
             *left_sum = s;
           }
         }
@@ -205,7 +206,7 @@ class HistMaker: public BaseMaker {
         c.SetSubstract(node_sum, s);
         if (c.sum_hess >= param_.min_child_weight) {
           double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true)) {
+          if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
             *left_sum = c;
           }
         }
@@ -243,8 +244,18 @@ class HistMaker: public BaseMaker {
       p_tree->Stat(nid).loss_chg = best.loss_chg;
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > kRtEps) {
+        bst_float base_weight = node_sum.CalcWeight(param_);
+        bst_float left_leaf_weight =
+            CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            CalcWeight(param_, best.right_sum.sum_grad,
+                       best.right_sum.sum_hess) *
+            param_.learning_rate;
         p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
-                           best.DefaultLeft());
+                           best.DefaultLeft(), base_weight, left_leaf_weight,
+                           right_leaf_weight, best.loss_chg,
+                           node_sum.sum_hess);
         // right side sum
         TStats right_sum;
         right_sum.SetSubstract(node_sum, left_sum[wid]);
@@ -429,8 +429,13 @@ void QuantileHistMaker::Builder::ApplySplit(int nid,

   /* 1. Create child nodes */
   NodeEntry& e = snode_[nid];
+  bst_float left_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.left_sum) * param_.learning_rate;
+  bst_float right_leaf_weight =
+      spliteval_->ComputeWeight(nid, e.best.right_sum) * param_.learning_rate;
   p_tree->ExpandNode(nid, e.best.SplitIndex(), e.best.split_value,
-                     e.best.DefaultLeft());
+                     e.best.DefaultLeft(), e.weight, left_leaf_weight,
+                     right_leaf_weight, e.best.loss_chg, e.stats.sum_hess);

   /* 2. Categorize member rows */
   const auto nthread = static_cast<bst_omp_uint>(this->nthread_);
@@ -698,6 +703,7 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
             spliteval_->ComputeSplitScore(nodeID, fid, e, c) -
             snode.root_gain);
         split_pt = cut_val[i];
+        best.Update(loss_chg, fid, split_pt, d_step == -1, e, c);
       } else {
         // backward enumeration: split at left bound of each bin
         loss_chg = static_cast<bst_float>(
@@ -709,8 +715,8 @@ void QuantileHistMaker::Builder::EnumerateSplit(int d_step,
         } else {
           split_pt = cut_val[i - 1];
         }
+        best.Update(loss_chg, fid, split_pt, d_step == -1, c, e);
       }
-      best.Update(loss_chg, fid, split_pt, d_step == -1);
     }
   }
 }
@@ -281,12 +281,21 @@ class SketchMaker: public BaseMaker {
       const int nid = qexpand_[wid];
       const SplitEntry &best = sol[wid];
       // set up the values
-      p_tree->Stat(nid).loss_chg = best.loss_chg;
       this->SetStats(nid, node_stats_[nid], p_tree);
       // now we know the solution in snode[nid], set split
       if (best.loss_chg > kRtEps) {
+        bst_float base_weight = node_stats_[nid].CalcWeight(param_);
+        bst_float left_leaf_weight =
+            CalcWeight(param_, best.left_sum.sum_grad, best.left_sum.sum_hess) *
+            param_.learning_rate;
+        bst_float right_leaf_weight =
+            CalcWeight(param_, best.right_sum.sum_grad,
+                       best.right_sum.sum_hess) *
+            param_.learning_rate;
         p_tree->ExpandNode(nid, best.SplitIndex(), best.split_value,
-                           best.DefaultLeft());
+                           best.DefaultLeft(), base_weight, left_leaf_weight,
+                           right_leaf_weight, best.loss_chg,
+                           node_stats_[nid].sum_hess);
       } else {
         (*p_tree)[nid].SetLeaf(p_tree->Stat(nid).base_weight * param_.learning_rate);
       }
@@ -336,7 +345,9 @@ class SketchMaker: public BaseMaker {
       if (s.sum_hess >= param_.min_child_weight &&
           c.sum_hess >= param_.min_child_weight) {
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false);
+        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], false,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
       // backward
       c.SetSubstract(feat_sum, s);
@@ -344,7 +355,9 @@ class SketchMaker: public BaseMaker {
       if (s.sum_hess >= param_.min_child_weight &&
          c.sum_hess >= param_.min_child_weight) {
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true);
+        best->Update(static_cast<bst_float>(loss_chg), fid, fsplits[i], true,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
     }
     {
@@ -355,8 +368,10 @@ class SketchMaker: public BaseMaker {
           c.sum_hess >= param_.min_child_weight) {
         bst_float cpt = fsplits.back();
         double loss_chg = s.CalcGain(param_) + c.CalcGain(param_) - root_gain;
-        best->Update(static_cast<bst_float>(loss_chg),
-                     fid, cpt + std::abs(cpt) + 1.0f, false);
+        best->Update(static_cast<bst_float>(loss_chg), fid,
+                     cpt + std::abs(cpt) + 1.0f, false,
+                     GradStats(s.pos_grad - s.neg_grad, s.sum_hess),
+                     GradStats(c.pos_grad - c.neg_grad, c.sum_hess));
       }
     }
   }
@@ -82,12 +82,15 @@ TEST(Param, SplitEntry) {

   xgboost::tree::SplitEntry se2;
   EXPECT_FALSE(se1.Update(se2));
-  EXPECT_FALSE(se2.Update(-1, 100, 0, true));
-  ASSERT_TRUE(se2.Update(1, 100, 0, true));
+  EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
+                          xgboost::tree::GradStats()));
+  ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
+                         xgboost::tree::GradStats()));
   ASSERT_TRUE(se1.Update(se2));

   xgboost::tree::SplitEntry se3;
-  se3.Update(2, 101, 0, false);
+  se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
+             xgboost::tree::GradStats());
   xgboost::tree::SplitEntry::Reduce(se2, se3);
   EXPECT_EQ(se2.SplitIndex(), 101);
   EXPECT_FALSE(se2.DefaultLeft());
@@ -38,22 +38,13 @@ TEST(Updater, Prune) {
   pruner->Init(cfg);

   // loss_chg < min_split_loss;
-  tree.ExpandNode(0, 0, 0, true);
-  int cleft = tree[0].LeftChild();
-  int cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.3f, 0);
-  tree[cright].SetLeaf(0.4f, 0);
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 0.0f, 0.0f);
   pruner->Update(&gpair, dmat->get(), trees);

   ASSERT_EQ(tree.NumExtraNodes(), 0);

   // loss_chg > min_split_loss;
-  tree.ExpandNode(0, 0, 0, true);
-  cleft = tree[0].LeftChild();
-  cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.3f, 0);
-  tree[cright].SetLeaf(0.4f, 0);
-  tree.Stat(0).loss_chg = 11;
+  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f, 11.0f, 0.0f);
   pruner->Update(&gpair, dmat->get(), trees);

   ASSERT_EQ(tree.NumExtraNodes(), 2);
@@ -29,12 +29,9 @@ TEST(Updater, Refresh) {
   std::vector<RegTree*> trees {&tree};
   std::unique_ptr<TreeUpdater> refresher(TreeUpdater::Create("refresh"));

-  tree.ExpandNode(0, 0, 0, true);
+  tree.ExpandNode(0, 2, 0.2f, false, 0.0, 0.2f, 0.8f, 0.0f, 0.0f);
   int cleft = tree[0].LeftChild();
   int cright = tree[0].RightChild();
-  tree[cleft].SetLeaf(0.2f, 0);
-  tree[cright].SetLeaf(0.8f, 0);
-  tree[0].SetSplit(2, 0.2f);

   tree.Stat(cleft).base_weight = 1.2;
   tree.Stat(cright).base_weight = 1.3;