Support learning rate for zero-hessian objectives. (#8866)

Author: Jiaming Yuan
Date: 2023-03-06 20:33:28 +08:00 (committed by GitHub)
Parent: 173096a6a7
Commit: 228a46e8ad
34 changed files with 464 additions and 434 deletions
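
The change that matters in the file below: HistEvaluator used to hold its own copy of TrainParam, so a parameter adjusted after construction was invisible to the evaluator; per the commit title, that is what blocked applying the learning rate for zero-hessian objectives. The commit switches the member to TrainParam const*, so the evaluator always reads the caller's live parameters. A minimal sketch of the difference, using hypothetical types rather than XGBoost's:

    #include <cassert>

    struct Param { float learning_rate{0.3f}; };

    struct ByCopy { Param p; };        // snapshot taken at construction
    struct ByPtr { Param const* p; };  // always observes the live value

    int main() {
      Param param;
      ByCopy copy{param};
      ByPtr ptr{&param};
      param.learning_rate = 1.0f;            // adjusted after construction
      assert(copy.p.learning_rate == 0.3f);  // stale copy
      assert(ptr.p->learning_rate == 1.0f);  // stays in sync
    }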

src/tree/hist/evaluate_splits.h

@@ -17,13 +17,11 @@
#include "../../common/random.h"
#include "../../data/gradient_index.h"
#include "../constraints.h"
#include "../param.h"
#include "../param.h" // for TrainParam
#include "../split_evaluator.h"
#include "xgboost/context.h"
namespace xgboost {
namespace tree {
namespace xgboost::tree {
template <typename ExpandEntry>
class HistEvaluator {
private:
@@ -36,7 +34,7 @@ class HistEvaluator {
  private:
   Context const* ctx_;
-  TrainParam param_;
+  TrainParam const* param_;
   std::shared_ptr<common::ColumnSampler> column_sampler_;
   TreeEvaluator tree_evaluator_;
   bool is_col_split_{false};
@@ -55,8 +53,9 @@ class HistEvaluator {
     }
   }
 
-  bool IsValid(GradStats const &left, GradStats const &right) const {
-    return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
+  [[nodiscard]] bool IsValid(GradStats const &left, GradStats const &right) const {
+    return left.GetHess() >= param_->min_child_weight &&
+           right.GetHess() >= param_->min_child_weight;
   }
 
   /**
@@ -95,9 +94,10 @@ class HistEvaluator {
       right_sum = GradStats{hist[i]};
       left_sum.SetSubstract(parent.stats, right_sum);
       if (IsValid(left_sum, right_sum)) {
-        auto missing_left_chg = static_cast<float>(
-            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
-            parent.root_gain);
+        auto missing_left_chg =
+            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
+                                                       GradStats{right_sum}) -
+                               parent.root_gain);
         best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
       }
@@ -105,9 +105,10 @@ class HistEvaluator {
       right_sum.Add(missing);
       left_sum.SetSubstract(parent.stats, right_sum);
       if (IsValid(left_sum, right_sum)) {
-        auto missing_right_chg = static_cast<float>(
-            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
-            parent.root_gain);
+        auto missing_right_chg =
+            static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
+                                                       GradStats{right_sum}) -
+                               parent.root_gain);
         best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
       }
     }
@@ -152,7 +153,7 @@ class HistEvaluator {
     bst_bin_t f_begin = cut_ptr[fidx];
     bst_bin_t f_end = cut_ptr[fidx + 1];
     bst_bin_t n_bins_feature{f_end - f_begin};
-    auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);
+    auto n_bins = std::min(param_->max_cat_threshold, n_bins_feature);
 
     // statistics on both sides of split
     GradStats left_sum;
@@ -181,9 +182,9 @@ class HistEvaluator {
         right_sum.SetSubstract(parent.stats, left_sum);  // missing on right
       }
       if (IsValid(left_sum, right_sum)) {
-        auto loss_chg =
-            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
-            parent.root_gain;
+        auto loss_chg = evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
+                                                GradStats{right_sum}) -
+                        parent.root_gain;
         // We don't have a numeric split point, nan here is a dummy split.
         if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
                         left_sum, right_sum)) {
@@ -256,7 +257,7 @@ class HistEvaluator {
         if (d_step > 0) {
           // forward enumeration: split at right bound of each bin
           loss_chg =
-              static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
+              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum},
                                                          GradStats{right_sum}) -
                                  parent.root_gain);
           split_pt = cut_val[i];  // not used for partition based
@@ -264,7 +265,7 @@ class HistEvaluator {
         } else {
           // backward enumeration: split at left bound of each bin
           loss_chg =
-              static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
+              static_cast<float>(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum},
                                                          GradStats{left_sum}) -
                                  parent.root_gain);
           if (i == imin) {
@@ -326,7 +327,7 @@ class HistEvaluator {
     }
     if (is_cat) {
       auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
-      if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
+      if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
         EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
       } else {
         std::vector<size_t> sorted_idx(n_bins);
@@ -334,8 +335,8 @@ class HistEvaluator {
         auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
         // Sort the histogram to get contiguous partitions.
         std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
-          auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
-                     evaluator.CalcWeightCat(param_, feat_hist[r]);
+          auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
+                     evaluator.CalcWeightCat(*param_, feat_hist[r]);
           return ret;
         });
         EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
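
For context on the sort above: CalcWeightCat gives each category's optimal leaf weight, and ordering the bins by that value means the best partition can be found with a single linear scan over the sorted sequence. A simplified sketch of the same idea, using plain grad/hess pairs instead of XGBoost's GradStats (the real CalcWeightCat takes its regularization terms from TrainParam):

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    struct Bin { double grad; double hess; };

    // Order category indices by each bin's optimal weight -G / (H + lambda),
    // so that contiguous prefixes of the result are the candidate partitions.
    std::vector<std::size_t> SortedCategories(std::vector<Bin> const& hist, double lambda) {
      std::vector<std::size_t> idx(hist.size());
      std::iota(idx.begin(), idx.end(), 0);
      auto weight = [&](std::size_t i) { return -hist[i].grad / (hist[i].hess + lambda); };
      std::stable_sort(idx.begin(), idx.end(),
                       [&](std::size_t l, std::size_t r) { return weight(l) < weight(r); });
      return idx;
    }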
@@ -382,24 +383,22 @@ class HistEvaluator {
     GradStats parent_sum = candidate.split.left_sum;
     parent_sum.Add(candidate.split.right_sum);
-    auto base_weight =
-        evaluator.CalcWeight(candidate.nid, param_, GradStats{parent_sum});
+    auto base_weight = evaluator.CalcWeight(candidate.nid, *param_, GradStats{parent_sum});
     auto left_weight =
-        evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.left_sum});
+        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.left_sum});
     auto right_weight =
-        evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
+        evaluator.CalcWeight(candidate.nid, *param_, GradStats{candidate.split.right_sum});
 
     if (candidate.split.is_cat) {
       tree.ExpandCategorical(
           candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
-          candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
-          right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.DefaultLeft(), base_weight, left_weight * param_->learning_rate,
+          right_weight * param_->learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
           candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
     } else {
       tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight,
-                      left_weight * param_.learning_rate, right_weight * param_.learning_rate,
+                      left_weight * param_->learning_rate, right_weight * param_->learning_rate,
                       candidate.split.loss_chg, parent_sum.GetHess(),
                       candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
     }
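
Note where the learning rate enters: gains and base_weight use the raw Newton weights, and only the child weights written into the tree are shrunk by learning_rate. With the usual XGBoost leaf weight w* = -G / (H + lambda), the recorded child value is roughly the following (a simplified sketch; the real CalcWeight also handles L1 regularization and monotone constraints, and this is not its actual signature):

    // grad/hess sums, lambda, and eta are assumed inputs here.
    float LeafValue(double sum_grad, double sum_hess, double lambda, double eta) {
      double raw = -sum_grad / (sum_hess + lambda);  // Newton step w*
      return static_cast<float>(raw * eta);          // shrunk by the learning rate
    }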
@@ -415,11 +414,11 @@ class HistEvaluator {
     max_node = std::max(candidate.nid, max_node);
     snode_.resize(tree.GetNodes().size());
     snode_.at(left_child).stats = candidate.split.left_sum;
-    snode_.at(left_child).root_gain = evaluator.CalcGain(
-        candidate.nid, param_, GradStats{candidate.split.left_sum});
+    snode_.at(left_child).root_gain =
+        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.left_sum});
     snode_.at(right_child).stats = candidate.split.right_sum;
-    snode_.at(right_child).root_gain = evaluator.CalcGain(
-        candidate.nid, param_, GradStats{candidate.split.right_sum});
+    snode_.at(right_child).root_gain =
+        evaluator.CalcGain(candidate.nid, *param_, GradStats{candidate.split.right_sum});
 
     interaction_constraints_.Split(candidate.nid,
                                    tree[candidate.nid].SplitIndex(), left_child,
@@ -429,31 +428,31 @@ class HistEvaluator {
   auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
   auto const& Stats() const { return snode_; }
 
-  float InitRoot(GradStats const& root_sum) {
+  float InitRoot(GradStats const &root_sum) {
     snode_.resize(1);
     auto root_evaluator = tree_evaluator_.GetEvaluator();
 
     snode_[0].stats = GradStats{root_sum.GetGrad(), root_sum.GetHess()};
-    snode_[0].root_gain = root_evaluator.CalcGain(RegTree::kRoot, param_,
-                                                  GradStats{snode_[0].stats});
-    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, param_,
-                                            GradStats{snode_[0].stats});
+    snode_[0].root_gain =
+        root_evaluator.CalcGain(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
+    auto weight = root_evaluator.CalcWeight(RegTree::kRoot, *param_, GradStats{snode_[0].stats});
     return weight;
   }
 
  public:
   // The column sampler must be constructed by caller since we need to preserve the rng
   // for the entire training session.
-  explicit HistEvaluator(Context const* ctx, TrainParam const &param, MetaInfo const &info,
+  explicit HistEvaluator(Context const *ctx, TrainParam const *param, MetaInfo const &info,
                          std::shared_ptr<common::ColumnSampler> sampler)
-      : ctx_{ctx}, param_{param},
+      : ctx_{ctx},
+        param_{param},
         column_sampler_{std::move(sampler)},
-        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
+        tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), Context::kCpuId},
         is_col_split_{info.data_split_mode == DataSplitMode::kCol} {
-    interaction_constraints_.Configure(param, info.num_col_);
+    interaction_constraints_.Configure(*param, info.num_col_);
     column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
-                          param_.colsample_bynode, param_.colsample_bylevel,
-                          param_.colsample_bytree);
+                          param_->colsample_bynode, param_->colsample_bylevel,
+                          param_->colsample_bytree);
   }
 };
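
Call sites have to follow the signature change and hand the parameters over by address, along the lines of the sketch below (everything except the HistEvaluator and TrainParam names is illustrative, not taken from this diff):

    TrainParam param;  // owned by the updater for the whole training session
    // before: HistEvaluator<CPUExpandEntry> evaluator{&ctx, param, info, sampler};
    HistEvaluator<CPUExpandEntry> evaluator{&ctx, &param, info, sampler};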
@@ -488,6 +487,5 @@ void UpdatePredictionCacheImpl(Context const *ctx, RegTree const *p_last_tree,
     });
   }
 }
-}  // namespace tree
-}  // namespace xgboost
+}  // namespace xgboost::tree
 #endif  // XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_