Use quantised gradients in gpu_hist histograms (#8246)
This commit is contained in:
@@ -259,10 +259,61 @@ class GradientPairInternal {
|
||||
using GradientPair = detail::GradientPairInternal<float>;
|
||||
/*! \brief High precision gradient statistics pair */
|
||||
using GradientPairPrecise = detail::GradientPairInternal<double>;
|
||||
/*! \brief Fixed point representation for gradient pair. */
|
||||
using GradientPairInt32 = detail::GradientPairInternal<int>;
|
||||
/*! \brief Fixed point representation for high precision gradient pair. */
|
||||
using GradientPairInt64 = detail::GradientPairInternal<int64_t>;
|
||||
|
||||
/*! \brief Fixed point representation for high precision gradient pair. Has a different interface so
|
||||
* we don't accidentally use it in gain calculations.*/
|
||||
class GradientPairInt64 {
|
||||
using T = int64_t;
|
||||
T grad_;
|
||||
T hess_;
|
||||
|
||||
public:
|
||||
using ValueT = T;
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64(T grad, T hess) : grad_(grad), hess_(hess) {}
|
||||
GradientPairInt64() = default;
|
||||
|
||||
// Copy constructor if of same value type, marked as default to be trivially_copyable
|
||||
GradientPairInt64(const GradientPairInt64 &g) = default;
|
||||
|
||||
XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
|
||||
XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) {
|
||||
grad_ += rhs.grad_;
|
||||
hess_ += rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 operator+(const GradientPairInt64 &rhs) const {
|
||||
GradientPairInt64 g;
|
||||
g.grad_ = grad_ + rhs.grad_;
|
||||
g.hess_ = hess_ + rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 &operator-=(const GradientPairInt64 &rhs) {
|
||||
grad_ -= rhs.grad_;
|
||||
hess_ -= rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE GradientPairInt64 operator-(const GradientPairInt64 &rhs) const {
|
||||
GradientPairInt64 g;
|
||||
g.grad_ = grad_ - rhs.grad_;
|
||||
g.hess_ = hess_ - rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bool operator==(const GradientPairInt64 &rhs) const {
|
||||
return grad_ == rhs.grad_ && hess_ == rhs.hess_;
|
||||
}
|
||||
friend std::ostream &operator<<(std::ostream &os,
|
||||
const GradientPairInt64 &g) {
|
||||
os << g.GetQuantisedGrad() << "/" << g.GetQuantisedHess();
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
using Args = std::vector<std::pair<std::string, std::string> >;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user