Accept other gradient types for split entry. (#5467)
This commit is contained in:
parent
86beb68ce8
commit
939973630d
@ -286,19 +286,6 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
|
||||
return CalcGain(p, stat.GetGrad(), stat.GetHess());
|
||||
}
|
||||
|
||||
// calculate cost of loss function with four statistics
|
||||
template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
|
||||
T test_grad, T test_hess) {
|
||||
T w = CalcWeight(sum_grad, sum_hess);
|
||||
T ret = CalcGainGivenWeight(p, test_grad, test_hess);
|
||||
if (p.reg_alpha == 0.0f) {
|
||||
return ret;
|
||||
} else {
|
||||
return ret + p.reg_alpha * std::abs(w);
|
||||
}
|
||||
}
|
||||
|
||||
// calculate weight given the statistics
|
||||
template <typename TrainingParams, typename T>
|
||||
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
||||
@ -340,6 +327,11 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
||||
/*! \return the sum_grad member, as a double */
XGBOOST_DEVICE double GetGrad() const {
  return sum_grad;
}
|
||||
/*! \return the sum_hess member, as a double */
XGBOOST_DEVICE double GetHess() const {
  return sum_hess;
}
|
||||
|
||||
/*! \brief Stream the statistics as GetGrad(), a "/" separator, then GetHess(). */
friend std::ostream& operator<<(std::ostream& os, GradStats s) {
  return os << s.GetGrad() << "/" << s.GetHess();
}
|
||||
|
||||
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
||||
static_assert(sizeof(GradStats) == 16,
|
||||
"Size of GradStats is not 16 bytes.");
|
||||
@ -383,28 +375,42 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
||||
* \brief statistics that are helpful to store
|
||||
* and represent a split solution for the tree
|
||||
*/
|
||||
struct SplitEntry {
|
||||
template<typename GradientT>
|
||||
struct SplitEntryContainer {
|
||||
/*! \brief loss change after split this node */
|
||||
bst_float loss_chg {0.0f};
|
||||
/*! \brief split index */
|
||||
unsigned sindex{0};
|
||||
bst_feature_t sindex{0};
|
||||
bst_float split_value{0.0f};
|
||||
GradStats left_sum;
|
||||
GradStats right_sum;
|
||||
|
||||
/*! \brief constructor */
|
||||
SplitEntry() = default;
|
||||
GradientT left_sum;
|
||||
GradientT right_sum;
|
||||
|
||||
SplitEntryContainer() = default;
|
||||
|
||||
/*!
 * \brief Stream a readable summary of the entry: loss change, split index
 *        (via SplitIndex()), split value and the left/right sums, separated
 *        by ", ".
 */
friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
  os << "loss_chg: " << s.loss_chg << ", ";
  os << "split index: " << s.SplitIndex() << ", ";
  os << "split value: " << s.split_value << ", ";
  os << "left_sum: " << s.left_sum << ", ";
  os << "right_sum: " << s.right_sum;
  return os;
}
|
||||
/*!\return feature index to split on */
|
||||
bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
||||
/*!\return whether missing value goes to left branch */
|
||||
bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
||||
/*!
|
||||
* \brief decides whether we can replace current entry with the given
|
||||
* statistics
|
||||
* This function gives better priority to lower index when loss_chg ==
|
||||
* new_loss_chg.
|
||||
* \brief decides whether we can replace current entry with the given statistics
|
||||
*
|
||||
* This function gives better priority to lower index when loss_chg == new_loss_chg.
|
||||
* Not the best way, but helps to give consistent result during multi-thread
|
||||
* execution.
|
||||
* execution.
|
||||
*
|
||||
* \param new_loss_chg the loss reduction obtained through the split
|
||||
* \param split_index the feature index that the split is on
|
||||
*/
|
||||
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
||||
bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
||||
if (this->SplitIndex() <= split_index) {
|
||||
return new_loss_chg > this->loss_chg;
|
||||
} else {
|
||||
@ -416,7 +422,7 @@ struct SplitEntry {
|
||||
* \param e candidate split solution
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
inline bool Update(const SplitEntry &e) {
|
||||
inline bool Update(const SplitEntryContainer &e) {
|
||||
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
|
||||
this->loss_chg = e.loss_chg;
|
||||
this->sindex = e.sindex;
|
||||
@ -436,9 +442,10 @@ struct SplitEntry {
|
||||
* \param default_left whether the missing value goes to left
|
||||
* \return whether the proposed split is better and can replace current split
|
||||
*/
|
||||
inline bool Update(bst_float new_loss_chg, unsigned split_index,
|
||||
bst_float new_split_value, bool default_left,
|
||||
const GradStats &left_sum, const GradStats &right_sum) {
|
||||
bool Update(bst_float new_loss_chg, unsigned split_index,
|
||||
bst_float new_split_value, bool default_left,
|
||||
const GradientT &left_sum,
|
||||
const GradientT &right_sum) {
|
||||
if (this->NeedReplace(new_loss_chg, split_index)) {
|
||||
this->loss_chg = new_loss_chg;
|
||||
if (default_left) {
|
||||
@ -453,17 +460,16 @@ struct SplitEntry {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief same as update, used by AllReduce*/
|
||||
inline static void Reduce(SplitEntry &dst, // NOLINT(*)
|
||||
const SplitEntry &src) { // NOLINT(*)
|
||||
/*! \brief Apply dst.Update(src); the boolean result of Update is discarded. */
inline static void Reduce(SplitEntryContainer &dst,         // NOLINT(*)
                          const SplitEntryContainer &src) {  // NOLINT(*)
  static_cast<void>(dst.Update(src));
}
|
||||
/*!\return feature index to split on */
|
||||
inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
||||
/*!\return whether missing value goes to left branch */
|
||||
inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
||||
};
|
||||
|
||||
/*! \brief SplitEntryContainer instantiated with GradStats as the left/right sum type. */
using SplitEntry = SplitEntryContainer<GradStats>;
|
||||
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user