Integer gradient summation for GPU histogram algorithm. (#2681)

This commit is contained in:
Rory Mitchell
2017-09-08 15:07:29 +12:00
committed by GitHub
parent 15267eedf2
commit e6a9063344
15 changed files with 182 additions and 128 deletions

View File

@@ -87,65 +87,116 @@ typedef uint64_t bst_ulong; // NOLINT(*)
typedef float bst_float;
/*! \brief Implementation of gradient statistics pair */
namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation
* may be used to overload different gradients types e.g. low precision, high
* precision, integer, floating point. */
template <typename T>
struct bst_gpair_internal {
class bst_gpair_internal {
/*! \brief gradient statistics */
T grad;
T grad_;
/*! \brief second order gradient statistics */
T hess;
T hess_;
XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
: grad(grad), hess(hess) {}
public:
typedef T value_t;
XGBOOST_DEVICE bst_gpair_internal() : grad_(0), hess_(0) {}
XGBOOST_DEVICE bst_gpair_internal(float grad, float hess) {
SetGrad(grad);
SetHess(hess);
}
// Copy constructor if of same value type
XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T> &g)
: grad_(g.grad_), hess_(g.hess_) {}
// Copy constructor if different value type - use getters and setters to
// perform conversion
template <typename T2>
XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2>&g)
: grad(g.grad), hess(g.hess) {}
XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g) {
SetGrad(g.GetGrad());
SetHess(g.GetHess());
}
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(const bst_gpair_internal<T> &rhs) {
grad += rhs.grad;
hess += rhs.hess;
XGBOOST_DEVICE float GetGrad() const { return grad_; }
XGBOOST_DEVICE float GetHess() const { return hess_; }
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
const bst_gpair_internal<T> &rhs) {
grad_ += rhs.grad_;
hess_ += rhs.hess_;
return *this;
}
XGBOOST_DEVICE bst_gpair_internal<T> operator+(const bst_gpair_internal<T> &rhs) const {
XGBOOST_DEVICE bst_gpair_internal<T> operator+(
const bst_gpair_internal<T> &rhs) const {
bst_gpair_internal<T> g;
g.grad = grad + rhs.grad;
g.hess = hess + rhs.hess;
g.grad_ = grad_ + rhs.grad_;
g.hess_ = hess_ + rhs.hess_;
return g;
}
XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(const bst_gpair_internal<T> &rhs) {
grad -= rhs.grad;
hess -= rhs.hess;
XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
const bst_gpair_internal<T> &rhs) {
grad_ -= rhs.grad_;
hess_ -= rhs.hess_;
return *this;
}
XGBOOST_DEVICE bst_gpair_internal<T> operator-(const bst_gpair_internal<T> &rhs) const {
XGBOOST_DEVICE bst_gpair_internal<T> operator-(
const bst_gpair_internal<T> &rhs) const {
bst_gpair_internal<T> g;
g.grad = grad - rhs.grad;
g.hess = hess - rhs.hess;
g.grad_ = grad_ - rhs.grad_;
g.hess_ = hess_ - rhs.hess_;
return g;
}
XGBOOST_DEVICE bst_gpair_internal(int value) {
*this = bst_gpair_internal<T>(static_cast<float>(value), static_cast<float>(value));
*this = bst_gpair_internal<T>(static_cast<float>(value),
static_cast<float>(value));
}
friend std::ostream &operator<<(std::ostream &os,
const bst_gpair_internal<T> &g) {
os << g.grad << "/" << g.hess;
os << g.grad_ << "/" << g.hess_;
return os;
}
};
template<>
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
return grad_ * 1e-5;
}
template<>
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
return hess_ * 1e-5;
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
grad_ = g * 1e5;
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
hess_ = h * 1e5;
}
} // namespace detail
/*! \brief gradient statistics pair usually needed in gradient boosting */
typedef bst_gpair_internal<float> bst_gpair;
typedef detail::bst_gpair_internal<float> bst_gpair;
/*! \brief High precision gradient statistics pair */
typedef bst_gpair_internal<double> bst_gpair_precise;
typedef detail::bst_gpair_internal<double> bst_gpair_precise;
/*! \brief High precision gradient statistics pair with integer backed
* storage. Operators are associative where floating point versions are not
* associative. */
typedef detail::bst_gpair_internal<int64_t> bst_gpair_integer;
/*! \brief small eps gap for minimum split decision. */
const bst_float rt_eps = 1e-6f;