Accept other gradient types for split entry. (#5467)

This commit is contained in:
Jiaming Yuan 2020-04-03 10:38:44 +08:00 committed by GitHub
parent 86beb68ce8
commit 939973630d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -286,19 +286,6 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
return CalcGain(p, stat.GetGrad(), stat.GetHess());
}
/*!
 * \brief Calculate the cost of the loss function from four statistics.
 *
 * The weight is computed from (sum_grad, sum_hess), the base gain is obtained from
 * CalcGainGivenWeight on (test_grad, test_hess), and an L1 term proportional to the
 * absolute weight is added when p.reg_alpha is non-zero.
 *
 * \param p         training parameters; reg_alpha is read here
 * \param sum_grad  gradient statistic used to compute the weight
 * \param sum_hess  hessian statistic used to compute the weight
 * \param test_grad gradient statistic forwarded to CalcGainGivenWeight
 * \param test_hess hessian statistic forwarded to CalcGainGivenWeight
 * \return the gain, plus reg_alpha * |w| when reg_alpha != 0
 */
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
                                 T test_grad, T test_hess) {
  // CalcWeight is declared with the training parameters as its first argument, so
  // they have to be forwarded; the previous call omitted them and could not match
  // that declaration once this template was instantiated.
  T w = CalcWeight(p, sum_grad, sum_hess);
  T ret = CalcGainGivenWeight(p, test_grad, test_hess);
  if (p.reg_alpha == 0.0f) {
    return ret;
  } else {
    return ret + p.reg_alpha * std::abs(w);
  }
}
// calculate weight given the statistics
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
@ -340,6 +327,11 @@ struct XGBOOST_ALIGNAS(16) GradStats {
/*! \return the value held in sum_grad, as a double */
XGBOOST_DEVICE double GetGrad() const {
  return this->sum_grad;
}
/*! \return the value held in sum_hess, as a double */
XGBOOST_DEVICE double GetHess() const {
  return this->sum_hess;
}
/*! \brief Write the statistics to a stream in the form "<grad>/<hess>". */
friend std::ostream& operator<<(std::ostream& os, GradStats s) {
  return os << s.GetGrad() << "/" << s.GetHess();
}
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
static_assert(sizeof(GradStats) == 16,
"Size of GradStats is not 16 bytes.");
@ -383,28 +375,42 @@ struct XGBOOST_ALIGNAS(16) GradStats {
* \brief statistics that are helpful to store
* and represent a split solution for the tree
*/
struct SplitEntry {
template<typename GradientT>
struct SplitEntryContainer {
/*! \brief loss change after splitting this node */
bst_float loss_chg {0.0f};
/*! \brief split index */
unsigned sindex{0};
bst_feature_t sindex{0};
bst_float split_value{0.0f};
GradStats left_sum;
GradStats right_sum;
/*! \brief constructor */
SplitEntry() = default;
GradientT left_sum;
GradientT right_sum;
SplitEntryContainer() = default;
/*! \brief Print every field of the entry as comma separated "name: value" pairs. */
friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
  os << "loss_chg: " << s.loss_chg << ", ";
  // The decoded feature index is printed rather than the raw sindex.
  os << "split index: " << s.SplitIndex() << ", ";
  os << "split value: " << s.split_value << ", ";
  os << "left_sum: " << s.left_sum << ", ";
  os << "right_sum: " << s.right_sum;
  return os;
}
/*!\return feature index to split on: sindex with everything above its low 31 bits masked off */
bst_feature_t SplitIndex() const {
  constexpr unsigned kIndexMask = (1U << 31) - 1U;
  return sindex & kIndexMask;
}
/*!\return whether missing value goes to left branch (true when sindex >> 31 is non-zero) */
bool DefaultLeft() const {
  auto const direction_bits = sindex >> 31;
  return direction_bits != 0;
}
/*!
* \brief decides whether we can replace current entry with the given
* statistics
* This function gives better priority to lower index when loss_chg ==
* new_loss_chg.
* \brief decides whether we can replace current entry with the given statistics
*
* This function gives better priority to lower index when loss_chg == new_loss_chg.
* Not the best way, but helps to give consistent result during multi-thread
* execution.
* execution.
*
* \param new_loss_chg the loss reduction obtained from the split
* \param split_index the feature index where the split is on
*/
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
if (this->SplitIndex() <= split_index) {
return new_loss_chg > this->loss_chg;
} else {
@ -416,7 +422,7 @@ struct SplitEntry {
* \param e candidate split solution
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(const SplitEntry &e) {
inline bool Update(const SplitEntryContainer &e) {
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
this->loss_chg = e.loss_chg;
this->sindex = e.sindex;
@ -436,9 +442,10 @@ struct SplitEntry {
* \param default_left whether the missing value goes to left
* \return whether the proposed split is better and can replace current split
*/
inline bool Update(bst_float new_loss_chg, unsigned split_index,
bst_float new_split_value, bool default_left,
const GradStats &left_sum, const GradStats &right_sum) {
bool Update(bst_float new_loss_chg, unsigned split_index,
bst_float new_split_value, bool default_left,
const GradientT &left_sum,
const GradientT &right_sum) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left) {
@ -453,17 +460,16 @@ struct SplitEntry {
return false;
}
}
/*! \brief same as update, used by AllReduce*/
inline static void Reduce(SplitEntry &dst, // NOLINT(*)
const SplitEntry &src) { // NOLINT(*)
inline static void Reduce(SplitEntryContainer &dst,          // NOLINT(*)
                          const SplitEntryContainer &src) {  // NOLINT(*)
  // Update() reports whether dst was replaced; a reducer has no use for that flag.
  static_cast<void>(dst.Update(src));
}
/*!\return feature index to split on: sindex with everything above its low 31 bits masked off */
inline unsigned SplitIndex() const {
  constexpr unsigned kIndexMask = (1U << 31) - 1U;
  return sindex & kIndexMask;
}
/*!\return whether missing value goes to left branch (true when sindex >> 31 is non-zero) */
inline bool DefaultLeft() const {
  auto const direction_bits = sindex >> 31;
  return direction_bits != 0;
}
};
/*! \brief Split entry whose left/right sums are stored as GradStats. */
using SplitEntry = SplitEntryContainer<GradStats>;
} // namespace tree
} // namespace xgboost