Accept other gradient types for split entry. (#5467)
This commit is contained in:
parent
86beb68ce8
commit
939973630d
@ -286,19 +286,6 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, StatT stat) {
|
|||||||
return CalcGain(p, stat.GetGrad(), stat.GetHess());
|
return CalcGain(p, stat.GetGrad(), stat.GetHess());
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculate cost of loss function with four statistics
|
|
||||||
template <typename TrainingParams, typename T>
|
|
||||||
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
|
|
||||||
T test_grad, T test_hess) {
|
|
||||||
T w = CalcWeight(sum_grad, sum_hess);
|
|
||||||
T ret = CalcGainGivenWeight(p, test_grad, test_hess);
|
|
||||||
if (p.reg_alpha == 0.0f) {
|
|
||||||
return ret;
|
|
||||||
} else {
|
|
||||||
return ret + p.reg_alpha * std::abs(w);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculate weight given the statistics
|
// calculate weight given the statistics
|
||||||
template <typename TrainingParams, typename T>
|
template <typename TrainingParams, typename T>
|
||||||
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
|
||||||
@ -340,6 +327,11 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
|||||||
XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
|
XGBOOST_DEVICE double GetGrad() const { return sum_grad; }
|
||||||
XGBOOST_DEVICE double GetHess() const { return sum_hess; }
|
XGBOOST_DEVICE double GetHess() const { return sum_hess; }
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, GradStats s) {
|
||||||
|
os << s.GetGrad() << "/" << s.GetHess();
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
XGBOOST_DEVICE GradStats() : sum_grad{0}, sum_hess{0} {
|
||||||
static_assert(sizeof(GradStats) == 16,
|
static_assert(sizeof(GradStats) == 16,
|
||||||
"Size of GradStats is not 16 bytes.");
|
"Size of GradStats is not 16 bytes.");
|
||||||
@ -383,28 +375,42 @@ struct XGBOOST_ALIGNAS(16) GradStats {
|
|||||||
* \brief statistics that is helpful to store
|
* \brief statistics that is helpful to store
|
||||||
* and represent a split solution for the tree
|
* and represent a split solution for the tree
|
||||||
*/
|
*/
|
||||||
struct SplitEntry {
|
template<typename GradientT>
|
||||||
|
struct SplitEntryContainer {
|
||||||
/*! \brief loss change after split this node */
|
/*! \brief loss change after split this node */
|
||||||
bst_float loss_chg {0.0f};
|
bst_float loss_chg {0.0f};
|
||||||
/*! \brief split index */
|
/*! \brief split index */
|
||||||
unsigned sindex{0};
|
bst_feature_t sindex{0};
|
||||||
bst_float split_value{0.0f};
|
bst_float split_value{0.0f};
|
||||||
GradStats left_sum;
|
|
||||||
GradStats right_sum;
|
|
||||||
|
|
||||||
/*! \brief constructor */
|
GradientT left_sum;
|
||||||
SplitEntry() = default;
|
GradientT right_sum;
|
||||||
|
|
||||||
|
SplitEntryContainer() = default;
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
|
||||||
|
os << "loss_chg: " << s.loss_chg << ", "
|
||||||
|
<< "split index: " << s.SplitIndex() << ", "
|
||||||
|
<< "split value: " << s.split_value << ", "
|
||||||
|
<< "left_sum: " << s.left_sum << ", "
|
||||||
|
<< "right_sum: " << s.right_sum;
|
||||||
|
return os;
|
||||||
|
}
|
||||||
|
/*!\return feature index to split on */
|
||||||
|
bst_feature_t SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
||||||
|
/*!\return whether missing value goes to left branch */
|
||||||
|
bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
||||||
/*!
|
/*!
|
||||||
* \brief decides whether we can replace current entry with the given
|
* \brief decides whether we can replace current entry with the given statistics
|
||||||
* statistics
|
*
|
||||||
* This function gives better priority to lower index when loss_chg ==
|
* This function gives better priority to lower index when loss_chg == new_loss_chg.
|
||||||
* new_loss_chg.
|
|
||||||
* Not the best way, but helps to give consistent result during multi-thread
|
* Not the best way, but helps to give consistent result during multi-thread
|
||||||
* execution.
|
* execution.
|
||||||
|
*
|
||||||
* \param new_loss_chg the loss reduction get through the split
|
* \param new_loss_chg the loss reduction get through the split
|
||||||
* \param split_index the feature index where the split is on
|
* \param split_index the feature index where the split is on
|
||||||
*/
|
*/
|
||||||
inline bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
bool NeedReplace(bst_float new_loss_chg, unsigned split_index) const {
|
||||||
if (this->SplitIndex() <= split_index) {
|
if (this->SplitIndex() <= split_index) {
|
||||||
return new_loss_chg > this->loss_chg;
|
return new_loss_chg > this->loss_chg;
|
||||||
} else {
|
} else {
|
||||||
@ -416,7 +422,7 @@ struct SplitEntry {
|
|||||||
* \param e candidate split solution
|
* \param e candidate split solution
|
||||||
* \return whether the proposed split is better and can replace current split
|
* \return whether the proposed split is better and can replace current split
|
||||||
*/
|
*/
|
||||||
inline bool Update(const SplitEntry &e) {
|
inline bool Update(const SplitEntryContainer &e) {
|
||||||
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
|
if (this->NeedReplace(e.loss_chg, e.SplitIndex())) {
|
||||||
this->loss_chg = e.loss_chg;
|
this->loss_chg = e.loss_chg;
|
||||||
this->sindex = e.sindex;
|
this->sindex = e.sindex;
|
||||||
@ -436,9 +442,10 @@ struct SplitEntry {
|
|||||||
* \param default_left whether the missing value goes to left
|
* \param default_left whether the missing value goes to left
|
||||||
* \return whether the proposed split is better and can replace current split
|
* \return whether the proposed split is better and can replace current split
|
||||||
*/
|
*/
|
||||||
inline bool Update(bst_float new_loss_chg, unsigned split_index,
|
bool Update(bst_float new_loss_chg, unsigned split_index,
|
||||||
bst_float new_split_value, bool default_left,
|
bst_float new_split_value, bool default_left,
|
||||||
const GradStats &left_sum, const GradStats &right_sum) {
|
const GradientT &left_sum,
|
||||||
|
const GradientT &right_sum) {
|
||||||
if (this->NeedReplace(new_loss_chg, split_index)) {
|
if (this->NeedReplace(new_loss_chg, split_index)) {
|
||||||
this->loss_chg = new_loss_chg;
|
this->loss_chg = new_loss_chg;
|
||||||
if (default_left) {
|
if (default_left) {
|
||||||
@ -453,17 +460,16 @@ struct SplitEntry {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief same as update, used by AllReduce*/
|
/*! \brief same as update, used by AllReduce*/
|
||||||
inline static void Reduce(SplitEntry &dst, // NOLINT(*)
|
inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*)
|
||||||
const SplitEntry &src) { // NOLINT(*)
|
const SplitEntryContainer &src) { // NOLINT(*)
|
||||||
dst.Update(src);
|
dst.Update(src);
|
||||||
}
|
}
|
||||||
/*!\return feature index to split on */
|
|
||||||
inline unsigned SplitIndex() const { return sindex & ((1U << 31) - 1U); }
|
|
||||||
/*!\return whether missing value goes to left branch */
|
|
||||||
inline bool DefaultLeft() const { return (sindex >> 31) != 0; }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
using SplitEntry = SplitEntryContainer<GradStats>;
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user