Integer gradient summation for GPU histogram algorithm. (#2681)
This commit is contained in:
@@ -87,65 +87,116 @@ typedef uint64_t bst_ulong; // NOLINT(*)
|
||||
typedef float bst_float;
|
||||
|
||||
|
||||
/*! \brief Implementation of gradient statistics pair */
|
||||
namespace detail {
|
||||
/*! \brief Implementation of gradient statistics pair. Template specialisation
|
||||
* may be used to overload different gradients types e.g. low precision, high
|
||||
* precision, integer, floating point. */
|
||||
template <typename T>
|
||||
struct bst_gpair_internal {
|
||||
class bst_gpair_internal {
|
||||
/*! \brief gradient statistics */
|
||||
T grad;
|
||||
T grad_;
|
||||
/*! \brief second order gradient statistics */
|
||||
T hess;
|
||||
T hess_;
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
|
||||
XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
|
||||
XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
|
||||
: grad(grad), hess(hess) {}
|
||||
public:
|
||||
typedef T value_t;
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal() : grad_(0), hess_(0) {}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal(float grad, float hess) {
|
||||
SetGrad(grad);
|
||||
SetHess(hess);
|
||||
}
|
||||
|
||||
// Copy constructor if of same value type
|
||||
XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T> &g)
|
||||
: grad_(g.grad_), hess_(g.hess_) {}
|
||||
|
||||
// Copy constructor if different value type - use getters and setters to
|
||||
// perform conversion
|
||||
template <typename T2>
|
||||
XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2>&g)
|
||||
: grad(g.grad), hess(g.hess) {}
|
||||
XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g) {
|
||||
SetGrad(g.GetGrad());
|
||||
SetHess(g.GetHess());
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(const bst_gpair_internal<T> &rhs) {
|
||||
grad += rhs.grad;
|
||||
hess += rhs.hess;
|
||||
XGBOOST_DEVICE float GetGrad() const { return grad_; }
|
||||
XGBOOST_DEVICE float GetHess() const { return hess_; }
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
|
||||
const bst_gpair_internal<T> &rhs) {
|
||||
grad_ += rhs.grad_;
|
||||
hess_ += rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> operator+(const bst_gpair_internal<T> &rhs) const {
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> operator+(
|
||||
const bst_gpair_internal<T> &rhs) const {
|
||||
bst_gpair_internal<T> g;
|
||||
g.grad = grad + rhs.grad;
|
||||
g.hess = hess + rhs.hess;
|
||||
g.grad_ = grad_ + rhs.grad_;
|
||||
g.hess_ = hess_ + rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(const bst_gpair_internal<T> &rhs) {
|
||||
grad -= rhs.grad;
|
||||
hess -= rhs.hess;
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
|
||||
const bst_gpair_internal<T> &rhs) {
|
||||
grad_ -= rhs.grad_;
|
||||
hess_ -= rhs.hess_;
|
||||
return *this;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> operator-(const bst_gpair_internal<T> &rhs) const {
|
||||
XGBOOST_DEVICE bst_gpair_internal<T> operator-(
|
||||
const bst_gpair_internal<T> &rhs) const {
|
||||
bst_gpair_internal<T> g;
|
||||
g.grad = grad - rhs.grad;
|
||||
g.hess = hess - rhs.hess;
|
||||
g.grad_ = grad_ - rhs.grad_;
|
||||
g.hess_ = hess_ - rhs.hess_;
|
||||
return g;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_gpair_internal(int value) {
|
||||
*this = bst_gpair_internal<T>(static_cast<float>(value), static_cast<float>(value));
|
||||
*this = bst_gpair_internal<T>(static_cast<float>(value),
|
||||
static_cast<float>(value));
|
||||
}
|
||||
|
||||
friend std::ostream &operator<<(std::ostream &os,
|
||||
const bst_gpair_internal<T> &g) {
|
||||
os << g.grad << "/" << g.hess;
|
||||
os << g.grad_ << "/" << g.hess_;
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
|
||||
return grad_ * 1e-5;
|
||||
}
|
||||
template<>
|
||||
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
|
||||
return hess_ * 1e-5;
|
||||
}
|
||||
template<>
|
||||
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
|
||||
grad_ = g * 1e5;
|
||||
}
|
||||
template<>
|
||||
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
|
||||
hess_ = h * 1e5;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/*! \brief gradient statistics pair usually needed in gradient boosting */
|
||||
typedef bst_gpair_internal<float> bst_gpair;
|
||||
typedef detail::bst_gpair_internal<float> bst_gpair;
|
||||
|
||||
/*! \brief High precision gradient statistics pair */
|
||||
typedef bst_gpair_internal<double> bst_gpair_precise;
|
||||
typedef detail::bst_gpair_internal<double> bst_gpair_precise;
|
||||
|
||||
/*! \brief High precision gradient statistics pair with integer backed
|
||||
* storage. Operators are associative where floating point versions are not
|
||||
* associative. */
|
||||
typedef detail::bst_gpair_internal<int64_t> bst_gpair_integer;
|
||||
|
||||
/*! \brief small eps gap for minimum split decision. */
|
||||
const bst_float rt_eps = 1e-6f;
|
||||
|
||||
Reference in New Issue
Block a user