Use quantised gradients in gpu_hist histograms (#8246)

This commit is contained in:
Rory Mitchell
2022-09-26 17:35:35 +02:00
committed by GitHub
parent 4056974e37
commit 8f77677193
14 changed files with 394 additions and 336 deletions

View File

@@ -259,10 +259,61 @@ class GradientPairInternal {
using GradientPair = detail::GradientPairInternal<float>;
/*! \brief High precision gradient statistics pair */
using GradientPairPrecise = detail::GradientPairInternal<double>;
/*! \brief Fixed point representation for gradient pair. */
using GradientPairInt32 = detail::GradientPairInternal<int>;
/*! \brief Fixed point representation for high precision gradient pair. */
using GradientPairInt64 = detail::GradientPairInternal<int64_t>;
/*! \brief Fixed point representation for high precision gradient pair. Has a different interface so
* we don't accidentally use it in gain calculations.*/
class GradientPairInt64 {
using T = int64_t;
T grad_;
T hess_;
public:
using ValueT = T;
XGBOOST_DEVICE GradientPairInt64(T grad, T hess) : grad_(grad), hess_(hess) {}
GradientPairInt64() = default;
// Copy constructor if of same value type, marked as default to be trivially_copyable
GradientPairInt64(const GradientPairInt64 &g) = default;
XGBOOST_DEVICE T GetQuantisedGrad() const { return grad_; }
XGBOOST_DEVICE T GetQuantisedHess() const { return hess_; }
XGBOOST_DEVICE GradientPairInt64 &operator+=(const GradientPairInt64 &rhs) {
grad_ += rhs.grad_;
hess_ += rhs.hess_;
return *this;
}
XGBOOST_DEVICE GradientPairInt64 operator+(const GradientPairInt64 &rhs) const {
GradientPairInt64 g;
g.grad_ = grad_ + rhs.grad_;
g.hess_ = hess_ + rhs.hess_;
return g;
}
XGBOOST_DEVICE GradientPairInt64 &operator-=(const GradientPairInt64 &rhs) {
grad_ -= rhs.grad_;
hess_ -= rhs.hess_;
return *this;
}
XGBOOST_DEVICE GradientPairInt64 operator-(const GradientPairInt64 &rhs) const {
GradientPairInt64 g;
g.grad_ = grad_ - rhs.grad_;
g.hess_ = hess_ - rhs.hess_;
return g;
}
XGBOOST_DEVICE bool operator==(const GradientPairInt64 &rhs) const {
return grad_ == rhs.grad_ && hess_ == rhs.hess_;
}
friend std::ostream &operator<<(std::ostream &os,
const GradientPairInt64 &g) {
os << g.GetQuantisedGrad() << "/" << g.GetQuantisedHess();
return os;
}
};
using Args = std::vector<std::pair<std::string, std::string> >;