Integer gradient summation for GPU histogram algorithm. (#2681)

This commit is contained in:
Rory Mitchell 2017-09-08 15:07:29 +12:00 committed by GitHub
parent 15267eedf2
commit e6a9063344
15 changed files with 182 additions and 128 deletions

View File

@ -87,65 +87,116 @@ typedef uint64_t bst_ulong; // NOLINT(*)
typedef float bst_float; typedef float bst_float;
/*! \brief Implementation of gradient statistics pair */ namespace detail {
/*! \brief Implementation of gradient statistics pair. Template specialisation
* may be used to overload different gradients types e.g. low precision, high
* precision, integer, floating point. */
template <typename T> template <typename T>
struct bst_gpair_internal { class bst_gpair_internal {
/*! \brief gradient statistics */ /*! \brief gradient statistics */
T grad; T grad_;
/*! \brief second order gradient statistics */ /*! \brief second order gradient statistics */
T hess; T hess_;
XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {} XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
XGBOOST_DEVICE bst_gpair_internal(T grad, T hess) public:
: grad(grad), hess(hess) {} typedef T value_t;
XGBOOST_DEVICE bst_gpair_internal() : grad_(0), hess_(0) {}
XGBOOST_DEVICE bst_gpair_internal(float grad, float hess) {
SetGrad(grad);
SetHess(hess);
}
// Copy constructor if of same value type
XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T> &g)
: grad_(g.grad_), hess_(g.hess_) {}
// Copy constructor if different value type - use getters and setters to
// perform conversion
template <typename T2> template <typename T2>
XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2>&g) XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g) {
: grad(g.grad), hess(g.hess) {} SetGrad(g.GetGrad());
SetHess(g.GetHess());
}
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(const bst_gpair_internal<T> &rhs) { XGBOOST_DEVICE float GetGrad() const { return grad_; }
grad += rhs.grad; XGBOOST_DEVICE float GetHess() const { return hess_; }
hess += rhs.hess;
XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
const bst_gpair_internal<T> &rhs) {
grad_ += rhs.grad_;
hess_ += rhs.hess_;
return *this; return *this;
} }
XGBOOST_DEVICE bst_gpair_internal<T> operator+(const bst_gpair_internal<T> &rhs) const { XGBOOST_DEVICE bst_gpair_internal<T> operator+(
const bst_gpair_internal<T> &rhs) const {
bst_gpair_internal<T> g; bst_gpair_internal<T> g;
g.grad = grad + rhs.grad; g.grad_ = grad_ + rhs.grad_;
g.hess = hess + rhs.hess; g.hess_ = hess_ + rhs.hess_;
return g; return g;
} }
XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(const bst_gpair_internal<T> &rhs) { XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
grad -= rhs.grad; const bst_gpair_internal<T> &rhs) {
hess -= rhs.hess; grad_ -= rhs.grad_;
hess_ -= rhs.hess_;
return *this; return *this;
} }
XGBOOST_DEVICE bst_gpair_internal<T> operator-(const bst_gpair_internal<T> &rhs) const { XGBOOST_DEVICE bst_gpair_internal<T> operator-(
const bst_gpair_internal<T> &rhs) const {
bst_gpair_internal<T> g; bst_gpair_internal<T> g;
g.grad = grad - rhs.grad; g.grad_ = grad_ - rhs.grad_;
g.hess = hess - rhs.hess; g.hess_ = hess_ - rhs.hess_;
return g; return g;
} }
XGBOOST_DEVICE bst_gpair_internal(int value) { XGBOOST_DEVICE bst_gpair_internal(int value) {
*this = bst_gpair_internal<T>(static_cast<float>(value), static_cast<float>(value)); *this = bst_gpair_internal<T>(static_cast<float>(value),
static_cast<float>(value));
} }
friend std::ostream &operator<<(std::ostream &os, friend std::ostream &operator<<(std::ostream &os,
const bst_gpair_internal<T> &g) { const bst_gpair_internal<T> &g) {
os << g.grad << "/" << g.hess; os << g.grad_ << "/" << g.hess_;
return os; return os;
} }
}; };
template<>
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
return grad_ * 1e-5;
}
template<>
inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
return hess_ * 1e-5;
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
grad_ = g * 1e5;
}
template<>
inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
hess_ = h * 1e5;
}
} // namespace detail
/*! \brief gradient statistics pair usually needed in gradient boosting */ /*! \brief gradient statistics pair usually needed in gradient boosting */
typedef bst_gpair_internal<float> bst_gpair; typedef detail::bst_gpair_internal<float> bst_gpair;
/*! \brief High precision gradient statistics pair */ /*! \brief High precision gradient statistics pair */
typedef bst_gpair_internal<double> bst_gpair_precise; typedef detail::bst_gpair_internal<double> bst_gpair_precise;
/*! \brief High precision gradient statistics pair with integer backed
* storage. Operators are associative where floating point versions are not
* associative. */
typedef detail::bst_gpair_internal<int64_t> bst_gpair_integer;
/*! \brief small eps gap for minimum split decision. */ /*! \brief small eps gap for minimum split decision. */
const bst_float rt_eps = 1e-6f; const bst_float rt_eps = 1e-6f;

View File

@ -33,8 +33,8 @@ struct GHistEntry {
/*! \brief add a bst_gpair to the sum */ /*! \brief add a bst_gpair to the sum */
inline void Add(const bst_gpair& e) { inline void Add(const bst_gpair& e) {
sum_grad += e.grad; sum_grad += e.GetGrad();
sum_hess += e.hess; sum_hess += e.GetHess();
} }
/*! \brief add a GHistEntry to the sum */ /*! \brief add a GHistEntry to the sum */

View File

@ -120,8 +120,9 @@ class GBLinear : public GradientBooster {
#pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess) #pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid]; bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.hess >= 0.0f) { if (p.GetHess() >= 0.0f) {
sum_grad += p.grad; sum_hess += p.hess; sum_grad += p.GetGrad();
sum_hess += p.GetHess();
} }
} }
// remove bias effect // remove bias effect
@ -132,8 +133,8 @@ class GBLinear : public GradientBooster {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) { for (bst_omp_uint i = 0; i < ndata; ++i) {
bst_gpair &p = gpair[rowset[i] * ngroup + gid]; bst_gpair &p = gpair[rowset[i] * ngroup + gid];
if (p.hess >= 0.0f) { if (p.GetHess() >= 0.0f) {
p.grad += p.hess * dw; p += bst_gpair(p.GetHess() * dw, 0);
} }
} }
} }
@ -151,9 +152,9 @@ class GBLinear : public GradientBooster {
for (bst_uint j = 0; j < col.length; ++j) { for (bst_uint j = 0; j < col.length; ++j) {
const bst_float v = col[j].fvalue; const bst_float v = col[j].fvalue;
bst_gpair &p = gpair[col[j].index * ngroup + gid]; bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue; if (p.GetHess() < 0.0f) continue;
sum_grad += p.grad * v; sum_grad += p.GetGrad() * v;
sum_hess += p.hess * v * v; sum_hess += p.GetHess() * v * v;
} }
bst_float &w = model[fid][gid]; bst_float &w = model[fid][gid];
bst_float dw = static_cast<bst_float>(param.learning_rate * bst_float dw = static_cast<bst_float>(param.learning_rate *
@ -162,8 +163,8 @@ class GBLinear : public GradientBooster {
// update grad value // update grad value
for (bst_uint j = 0; j < col.length; ++j) { for (bst_uint j = 0; j < col.length; ++j) {
bst_gpair &p = gpair[col[j].index * ngroup + gid]; bst_gpair &p = gpair[col[j].index * ngroup + gid];
if (p.hess < 0.0f) continue; if (p.GetHess() < 0.0f) continue;
p.grad += p.hess * col[j].fvalue * dw; p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0);
} }
} }
} }

View File

@ -109,10 +109,8 @@ class LambdaRankObj : public ObjFunction {
bst_float g = p - 1.0f; bst_float g = p - 1.0f;
bst_float h = std::max(p * (1.0f - p), eps); bst_float h = std::max(p * (1.0f - p), eps);
// accumulate gradient and hessian in both pid, and nid // accumulate gradient and hessian in both pid, and nid
gpair[pos.rindex].grad += g * w; gpair[pos.rindex] += bst_gpair(g * w, 2.0f*w*h);
gpair[pos.rindex].hess += 2.0f * w * h; gpair[neg.rindex] += bst_gpair(-g * w, 2.0f*w*h);
gpair[neg.rindex].grad -= g * w;
gpair[neg.rindex].hess += 2.0f * w * h;
} }
} }
} }

View File

@ -313,7 +313,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
* \brief accumulate statistics * \brief accumulate statistics
* \param p the gradient pair * \param p the gradient pair
*/ */
inline void Add(bst_gpair p) { this->Add(p.grad, p.hess); } inline void Add(bst_gpair p) { this->Add(p.GetGrad(), p.GetHess()); }
/*! /*!
* \brief accumulate statistics, more complicated version * \brief accumulate statistics, more complicated version
* \param gpair the vector storing the gradient statistics * \param gpair the vector storing the gradient statistics
@ -323,7 +323,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
inline void Add(const std::vector<bst_gpair>& gpair, const MetaInfo& info, inline void Add(const std::vector<bst_gpair>& gpair, const MetaInfo& info,
bst_uint ridx) { bst_uint ridx) {
const bst_gpair& b = gpair[ridx]; const bst_gpair& b = gpair[ridx];
this->Add(b.grad, b.hess); this->Add(b.GetGrad(), b.GetHess());
} }
/*! \brief calculate leaf weight */ /*! \brief calculate leaf weight */
inline double CalcWeight(const TrainParam& param) const { inline double CalcWeight(const TrainParam& param) const {

View File

@ -140,14 +140,14 @@ class BaseMaker: public TreeUpdater {
} }
// mark delete for the deleted datas // mark delete for the deleted datas
for (size_t i = 0; i < position.size(); ++i) { for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) position[i] = ~position[i]; if (gpair[i].GetHess() < 0.0f) position[i] = ~position[i];
} }
// mark subsample // mark subsample
if (param.subsample < 1.0f) { if (param.subsample < 1.0f) {
std::bernoulli_distribution coin_flip(param.subsample); std::bernoulli_distribution coin_flip(param.subsample);
auto& rnd = common::GlobalRandom(); auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < position.size(); ++i) { for (size_t i = 0; i < position.size(); ++i) {
if (gpair[i].hess < 0.0f) continue; if (gpair[i].GetHess() < 0.0f) continue;
if (!coin_flip(rnd)) position[i] = ~position[i]; if (!coin_flip(rnd)) position[i] = ~position[i];
} }
} }

View File

@ -136,7 +136,7 @@ class ColMaker: public TreeUpdater {
// mark delete for the deleted datas // mark delete for the deleted datas
for (size_t i = 0; i < rowset.size(); ++i) { for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) position[ridx] = ~position[ridx]; if (gpair[ridx].GetHess() < 0.0f) position[ridx] = ~position[ridx];
} }
// mark subsample // mark subsample
if (param.subsample < 1.0f) { if (param.subsample < 1.0f) {
@ -144,7 +144,7 @@ class ColMaker: public TreeUpdater {
auto& rnd = common::GlobalRandom(); auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < rowset.size(); ++i) { for (size_t i = 0; i < rowset.size(); ++i) {
const bst_uint ridx = rowset[i]; const bst_uint ridx = rowset[i];
if (gpair[ridx].hess < 0.0f) continue; if (gpair[ridx].GetHess() < 0.0f) continue;
if (!coin_flip(rnd)) position[ridx] = ~position[ridx]; if (!coin_flip(rnd)) position[ridx] = ~position[ridx];
} }
} }

View File

@ -372,13 +372,13 @@ class FastHistMaker: public TreeUpdater {
std::bernoulli_distribution coin_flip(param.subsample); std::bernoulli_distribution coin_flip(param.subsample);
auto& rnd = common::GlobalRandom(); auto& rnd = common::GlobalRandom();
for (size_t i = 0; i < info.num_row; ++i) { for (size_t i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f && coin_flip(rnd)) { if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
row_indices.push_back(i); row_indices.push_back(i);
} }
} }
} else { } else {
for (size_t i = 0; i < info.num_row; ++i) { for (size_t i = 0; i < info.num_row; ++i) {
if (gpair[i].hess >= 0.0f) { if (gpair[i].GetHess() >= 0.0f) {
row_indices.push_back(i); row_indices.push_back(i);
} }
} }

View File

@ -82,8 +82,8 @@ struct DeviceDenseNode {
fvalue(0.f), fvalue(0.f),
fidx(UNUSED_NODE), fidx(UNUSED_NODE),
idx(nidx) { idx(nidx) {
this->root_gain = CalcGain(param, sum_gradients.grad, sum_gradients.hess); this->root_gain = CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
this->weight = CalcWeight(param, sum_gradients.grad, sum_gradients.hess); this->weight = CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
} }
HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) { HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) {
@ -113,8 +113,8 @@ __device__ inline float device_calc_loss_chg(
gpair_t right = parent_sum - left; gpair_t right = parent_sum - left;
float left_gain = CalcGain(param, left.grad, left.hess); float left_gain = CalcGain(param, left.GetGrad(), left.GetHess());
float right_gain = CalcGain(param, right.grad, right.hess); float right_gain = CalcGain(param, right.GetGrad(), right.GetHess());
return left_gain + right_gain - parent_gain; return left_gain + right_gain - parent_gain;
} }
@ -181,13 +181,13 @@ inline void dense2sparse_tree(RegTree* p_tree,
tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir); tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir);
tree.stat(nid).loss_chg = n.root_gain; tree.stat(nid).loss_chg = n.root_gain;
tree.stat(nid).base_weight = n.weight; tree.stat(nid).base_weight = n.weight;
tree.stat(nid).sum_hess = n.sum_gradients.hess; tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
tree[tree[nid].cleft()].set_leaf(0); tree[tree[nid].cleft()].set_leaf(0);
tree[tree[nid].cright()].set_leaf(0); tree[tree[nid].cright()].set_leaf(0);
nid++; nid++;
} else if (n.IsLeaf()) { } else if (n.IsLeaf()) {
tree[nid].set_leaf(n.weight * param.learning_rate); tree[nid].set_leaf(n.weight * param.learning_rate);
tree.stat(nid).sum_hess = n.sum_gradients.hess; tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
nid++; nid++;
} }
} }

View File

@ -5,17 +5,20 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "param.h"
#include "../common/compressed_iterator.h" #include "../common/compressed_iterator.h"
#include "../common/hist_util.h"
#include "updater_gpu_common.cuh"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
#include "../common/hist_util.h"
#include "param.h"
#include "updater_gpu_common.cuh"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
typedef bst_gpair_integer gpair_sum_t;
static const ncclDataType_t nccl_sum_t = ncclInt64;
// Helper for explicit template specialisation // Helper for explicit template specialisation
template <int N> template <int N>
struct Int {}; struct Int {};
@ -50,27 +53,29 @@ struct DeviceGMat {
}; };
struct HistHelper { struct HistHelper {
bst_gpair* d_hist; gpair_sum_t* d_hist;
int n_bins; int n_bins;
__host__ __device__ HistHelper(bst_gpair* ptr, int n_bins) __host__ __device__ HistHelper(gpair_sum_t* ptr, int n_bins)
: d_hist(ptr), n_bins(n_bins) {} : d_hist(ptr), n_bins(n_bins) {}
__device__ void Add(bst_gpair gpair, int gidx, int nidx) const { __device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
int hist_idx = nidx * n_bins + gidx; int hist_idx = nidx * n_bins + gidx;
atomicAdd(&(d_hist[hist_idx].grad), gpair.grad); // OPTMARK: This and below
// line lead to about 3X auto dst_ptr = reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]); // NOLINT
// slowdown due to memory gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
// dependency and access auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
// pattern issues.
atomicAdd(&(d_hist[hist_idx].hess), gpair.hess); atomicAdd(dst_ptr, static_cast<unsigned long long int>(*src_ptr)); // NOLINT
atomicAdd(dst_ptr + 1, static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
} }
__device__ bst_gpair Get(int gidx, int nidx) const { __device__ gpair_sum_t Get(int gidx, int nidx) const {
return d_hist[nidx * n_bins + gidx]; return d_hist[nidx * n_bins + gidx];
} }
}; };
struct DeviceHist { struct DeviceHist {
int n_bins; int n_bins;
dh::dvec<bst_gpair> data; dh::dvec<gpair_sum_t> data;
void Init(int n_bins_in) { void Init(int n_bins_in) {
this->n_bins = n_bins_in; this->n_bins = n_bins_in;
@ -79,12 +84,12 @@ struct DeviceHist {
void Reset(int device_idx) { void Reset(int device_idx) {
cudaSetDevice(device_idx); cudaSetDevice(device_idx);
data.fill(bst_gpair()); data.fill(gpair_sum_t());
} }
HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); } HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); }
bst_gpair* GetLevelPtr(int depth) { gpair_sum_t* GetLevelPtr(int depth) {
return data.data() + n_nodes(depth - 1) * n_bins; return data.data() + n_nodes(depth - 1) * n_bins;
} }
@ -96,18 +101,19 @@ struct SplitCandidate {
bool missing_left; bool missing_left;
float fvalue; float fvalue;
int findex; int findex;
bst_gpair left_sum; gpair_sum_t left_sum;
bst_gpair right_sum; gpair_sum_t right_sum;
__host__ __device__ SplitCandidate() __host__ __device__ SplitCandidate()
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {} : loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
__device__ void Update(float loss_chg_in, bool missing_left_in, __device__ void Update(float loss_chg_in, bool missing_left_in,
float fvalue_in, int findex_in, bst_gpair left_sum_in, float fvalue_in, int findex_in,
bst_gpair right_sum_in, gpair_sum_t left_sum_in, gpair_sum_t right_sum_in,
const GPUTrainingParam& param) { const GPUTrainingParam& param) {
if (loss_chg_in > loss_chg && left_sum_in.hess >= param.min_child_weight && if (loss_chg_in > loss_chg &&
right_sum_in.hess >= param.min_child_weight) { left_sum_in.GetHess() >= param.min_child_weight &&
right_sum_in.GetHess() >= param.min_child_weight) {
loss_chg = loss_chg_in; loss_chg = loss_chg_in;
missing_left = missing_left_in; missing_left = missing_left_in;
fvalue = fvalue_in; fvalue = fvalue_in;
@ -121,11 +127,11 @@ struct SplitCandidate {
struct GpairCallbackOp { struct GpairCallbackOp {
// Running prefix // Running prefix
bst_gpair running_total; gpair_sum_t running_total;
// Constructor // Constructor
__device__ GpairCallbackOp() : running_total(bst_gpair()) {} __device__ GpairCallbackOp() : running_total(gpair_sum_t()) {}
__device__ bst_gpair operator()(bst_gpair block_aggregate) { __device__ bst_gpair operator()(bst_gpair block_aggregate) {
bst_gpair old_prefix = running_total; gpair_sum_t old_prefix = running_total;
running_total += block_aggregate; running_total += block_aggregate;
return old_prefix; return old_prefix;
} }
@ -133,17 +139,16 @@ struct GpairCallbackOp {
template <int BLOCK_THREADS> template <int BLOCK_THREADS>
__global__ void find_split_kernel( __global__ void find_split_kernel(
const bst_gpair* d_level_hist, int* d_feature_segments, int depth, const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
int n_features, int n_bins, DeviceDenseNode* d_nodes, int n_features, int n_bins, DeviceDenseNode* d_nodes,
int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map, int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp, GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
bool colsample, int* d_feature_flags) { bool colsample, int* d_feature_flags) {
typedef cub::KeyValuePair<int, float> ArgMaxT; typedef cub::KeyValuePair<int, float> ArgMaxT;
typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, typedef cub::BlockScan<gpair_sum_t, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
cub::BLOCK_SCAN_WARP_SCANS>
BlockScanT; BlockScanT;
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT; typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT; typedef cub::BlockReduce<gpair_sum_t, BLOCK_THREADS> SumReduceT;
union TempStorage { union TempStorage {
typename BlockScanT::TempStorage scan; typename BlockScanT::TempStorage scan;
@ -153,8 +158,8 @@ __global__ void find_split_kernel(
__shared__ cub::Uninitialized<SplitCandidate> uninitialized_split; __shared__ cub::Uninitialized<SplitCandidate> uninitialized_split;
SplitCandidate& split = uninitialized_split.Alias(); SplitCandidate& split = uninitialized_split.Alias();
__shared__ cub::Uninitialized<bst_gpair> uninitialized_sum; __shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
bst_gpair& shared_sum = uninitialized_sum.Alias(); gpair_sum_t& shared_sum = uninitialized_sum.Alias();
__shared__ ArgMaxT block_max; __shared__ ArgMaxT block_max;
__shared__ TempStorage temp_storage; __shared__ TempStorage temp_storage;
@ -175,14 +180,13 @@ __global__ void find_split_kernel(
int begin = d_feature_segments[level_node_idx * n_features + fidx]; int begin = d_feature_segments[level_node_idx * n_features + fidx];
int end = d_feature_segments[level_node_idx * n_features + fidx + 1]; int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
bst_gpair feature_sum = bst_gpair(); gpair_sum_t feature_sum = gpair_sum_t();
for (int reduce_begin = begin; reduce_begin < end; for (int reduce_begin = begin; reduce_begin < end;
reduce_begin += BLOCK_THREADS) { reduce_begin += BLOCK_THREADS) {
bool thread_active = reduce_begin + threadIdx.x < end; bool thread_active = reduce_begin + threadIdx.x < end;
// Scan histogram // Scan histogram
bst_gpair bin = thread_active gpair_sum_t bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
? d_level_hist[reduce_begin + threadIdx.x] : gpair_sum_t();
: bst_gpair();
feature_sum += feature_sum +=
SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum()); SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@ -197,18 +201,17 @@ __global__ void find_split_kernel(
for (int scan_begin = begin; scan_begin < end; for (int scan_begin = begin; scan_begin < end;
scan_begin += BLOCK_THREADS) { scan_begin += BLOCK_THREADS) {
bool thread_active = scan_begin + threadIdx.x < end; bool thread_active = scan_begin + threadIdx.x < end;
bst_gpair bin = thread_active gpair_sum_t bin = thread_active ? d_level_hist[scan_begin + threadIdx.x]
? d_level_hist[scan_begin + threadIdx.x] : gpair_sum_t();
: bst_gpair();
BlockScanT(temp_storage.scan) BlockScanT(temp_storage.scan)
.ExclusiveScan(bin, bin, cub::Sum(), prefix_op); .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
// Calculate gain // Calculate gain
bst_gpair parent_sum = d_nodes[node_idx].sum_gradients; gpair_sum_t parent_sum = gpair_sum_t(d_nodes[node_idx].sum_gradients);
float parent_gain = d_nodes[node_idx].root_gain; float parent_gain = d_nodes[node_idx].root_gain;
bst_gpair missing = parent_sum - shared_sum; gpair_sum_t missing = parent_sum - shared_sum;
bool missing_left; bool missing_left;
float gain = thread_active float gain = thread_active
@ -239,8 +242,8 @@ __global__ void find_split_kernel(
fvalue = d_gidx_fvalue_map[gidx - 1]; fvalue = d_gidx_fvalue_map[gidx - 1];
} }
bst_gpair left = missing_left ? bin + missing : bin; gpair_sum_t left = missing_left ? bin + missing : bin;
bst_gpair right = parent_sum - left; gpair_sum_t right = parent_sum - left;
split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param); split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
} }
@ -263,7 +266,7 @@ __global__ void find_split_kernel(
DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param); DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param);
// Record smallest node // Record smallest node
if (split.left_sum.hess <= split.right_sum.hess) { if (split.left_sum.GetHess() <= split.right_sum.GetHess()) {
left_child_smallest = true; left_child_smallest = true;
} else { } else {
left_child_smallest = false; left_child_smallest = false;
@ -595,6 +598,7 @@ class GPUHistMaker : public TreeUpdater {
initialised = true; initialised = true;
} }
void BuildHist(int depth) { void BuildHist(int depth) {
for (int d_idx = 0; d_idx < n_devices; d_idx++) { for (int d_idx = 0; d_idx < n_devices; d_idx++) {
int device_idx = dList[d_idx]; int device_idx = dList[d_idx];
@ -650,9 +654,9 @@ class GPUHistMaker : public TreeUpdater {
dh::safe_nccl(ncclAllReduce( dh::safe_nccl(ncclAllReduce(
reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)), reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)), reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) / hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) /
sizeof(float), sizeof(gpair_sum_t::value_t),
ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx]))); nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx])));
} }
for (int d_idx = 0; d_idx < n_devices; d_idx++) { for (int d_idx = 0; d_idx < n_devices; d_idx++) {
@ -683,11 +687,12 @@ class GPUHistMaker : public TreeUpdater {
} }
int gidx = idx % hist_builder.n_bins; int gidx = idx % hist_builder.n_bins;
bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx)); gpair_sum_t parent = hist_builder.Get(gidx, parent_nidx(nidx));
int other_nidx = left_smallest ? nidx - 1 : nidx + 1; int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
bst_gpair other = hist_builder.Get(gidx, other_nidx); gpair_sum_t other = hist_builder.Get(gidx, other_nidx);
gpair_sum_t sub = parent - other;
hist_builder.Add( hist_builder.Add(
parent - other, gidx, bst_gpair(sub.GetGrad(), sub.GetHess()), gidx,
nidx); // OPTMARK: This is slow, could use shared nidx); // OPTMARK: This is slow, could use shared
// memory or cache results intead of writing to // memory or cache results intead of writing to
// global memory every time in atomic way. // global memory every time in atomic way.
@ -737,11 +742,11 @@ class GPUHistMaker : public TreeUpdater {
int nodes_offset_device = 0; int nodes_offset_device = 0;
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>( find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
(const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)), hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(),
feature_segments[d_idx].data(), depth, (info->num_col), depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(),
(hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_offset_device, nodes_offset_device, fidx_min_map[d_idx].data(),
fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample, left_child_smallest[d_idx].data(), colsample,
feature_flags[d_idx].data()); feature_flags[d_idx].data());
} }

View File

@ -568,7 +568,7 @@ class CQHistMaker: public HistMaker<TStats> {
const bst_uint ridx = c[j].index; const bst_uint ridx = c[j].index;
const int nid = this->position[ridx]; const int nid = this->position[ridx];
if (nid >= 0) { if (nid >= 0) {
sbuilder[nid].sum_total += gpair[ridx].hess; sbuilder[nid].sum_total += gpair[ridx].GetHess();
} }
} }
// if only one value, no need to do second pass // if only one value, no need to do second pass
@ -595,7 +595,7 @@ class CQHistMaker: public HistMaker<TStats> {
for (bst_uint i = 0; i < kBuffer; ++i) { for (bst_uint i = 0; i < kBuffer; ++i) {
bst_uint ridx = c[j + i].index; bst_uint ridx = c[j + i].index;
buf_position[i] = this->position[ridx]; buf_position[i] = this->position[ridx];
buf_hess[i] = gpair[ridx].hess; buf_hess[i] = gpair[ridx].GetHess();
} }
for (bst_uint i = 0; i < kBuffer; ++i) { for (bst_uint i = 0; i < kBuffer; ++i) {
const int nid = buf_position[i]; const int nid = buf_position[i];
@ -608,7 +608,7 @@ class CQHistMaker: public HistMaker<TStats> {
const bst_uint ridx = c[j].index; const bst_uint ridx = c[j].index;
const int nid = this->position[ridx]; const int nid = this->position[ridx];
if (nid >= 0) { if (nid >= 0) {
sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size); sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
} }
} }
} else { } else {
@ -616,7 +616,7 @@ class CQHistMaker: public HistMaker<TStats> {
const bst_uint ridx = c[j].index; const bst_uint ridx = c[j].index;
const int nid = this->position[ridx]; const int nid = this->position[ridx];
if (nid >= 0) { if (nid >= 0) {
sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size); sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
} }
} }
} }
@ -818,7 +818,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) { for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
const SparseBatch::Entry &e = col_data[i]; const SparseBatch::Entry &e = col_data[i];
const int wid = this->node2workindex[e.index]; const int wid = this->node2workindex[e.index];
sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess); sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
} }
} }
} }

View File

@ -98,12 +98,12 @@ class SketchMaker: public BaseMaker {
const MetaInfo &info, const MetaInfo &info,
bst_uint ridx) { bst_uint ridx) {
const bst_gpair &b = gpair[ridx]; const bst_gpair &b = gpair[ridx];
if (b.grad >= 0.0f) { if (b.GetGrad() >= 0.0f) {
pos_grad += b.grad; pos_grad += b.GetGrad();
} else { } else {
neg_grad -= b.grad; neg_grad -= b.GetGrad();
} }
sum_hess += b.hess; sum_hess += b.GetHess();
} }
/*! \brief calculate gain of the solution */ /*! \brief calculate gain of the solution */
inline double CalcGain(const TrainParam &param) const { inline double CalcGain(const TrainParam &param) const {
@ -199,12 +199,12 @@ class SketchMaker: public BaseMaker {
const int nid = this->position[ridx]; const int nid = this->position[ridx];
if (nid >= 0) { if (nid >= 0) {
const bst_gpair &e = gpair[ridx]; const bst_gpair &e = gpair[ridx];
if (e.grad >= 0.0f) { if (e.GetGrad() >= 0.0f) {
sbuilder[3 * nid + 0].sum_total += e.grad; sbuilder[3 * nid + 0].sum_total += e.GetGrad();
} else { } else {
sbuilder[3 * nid + 1].sum_total -= e.grad; sbuilder[3 * nid + 1].sum_total -= e.GetGrad();
} }
sbuilder[3 * nid + 2].sum_total += e.hess; sbuilder[3 * nid + 2].sum_total += e.GetHess();
} }
} }
} else { } else {
@ -241,12 +241,12 @@ class SketchMaker: public BaseMaker {
const int nid = this->position[ridx]; const int nid = this->position[ridx];
if (nid >= 0) { if (nid >= 0) {
const bst_gpair &e = gpair[ridx]; const bst_gpair &e = gpair[ridx];
if (e.grad >= 0.0f) { if (e.GetGrad() >= 0.0f) {
sbuilder[3 * nid + 0].Push(c[j].fvalue, e.grad, max_size); sbuilder[3 * nid + 0].Push(c[j].fvalue, e.GetGrad(), max_size);
} else { } else {
sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.grad, max_size); sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.GetGrad(), max_size);
} }
sbuilder[3 * nid + 2].Push(c[j].fvalue, e.hess, max_size); sbuilder[3 * nid + 2].Push(c[j].fvalue, e.GetHess(), max_size);
} }
} }
for (size_t i = 0; i < this->qexpand.size(); ++i) { for (size_t i = 0; i < this->qexpand.size(); ++i) {

View File

@ -43,10 +43,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
ASSERT_EQ(gpair.size(), preds.size()); ASSERT_EQ(gpair.size(), preds.size());
for (int i = 0; i < static_cast<int>(gpair.size()); ++i) { for (int i = 0; i < static_cast<int>(gpair.size()); ++i) {
EXPECT_NEAR(gpair[i].grad, out_grad[i], 0.01) EXPECT_NEAR(gpair[i].GetGrad(), out_grad[i], 0.01)
<< "Unexpected grad for pred=" << preds[i] << " label=" << labels[i] << "Unexpected grad for pred=" << preds[i] << " label=" << labels[i]
<< " weight=" << weights[i]; << " weight=" << weights[i];
EXPECT_NEAR(gpair[i].hess, out_hess[i], 0.01) EXPECT_NEAR(gpair[i].GetHess(), out_hess[i], 0.01)
<< "Unexpected hess for pred=" << preds[i] << " label=" << labels[i] << "Unexpected hess for pred=" << preds[i] << " label=" << labels[i]
<< " weight=" << weights[i]; << " weight=" << weights[i];
} }

View File

@ -16,7 +16,7 @@ TEST(gpu_predictor, Test) {
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor")); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
std::vector<std::unique_ptr<RegTree>> trees; std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::make_unique<RegTree>()); trees.push_back(std::unique_ptr<RegTree>());
trees.back()->InitModel(); trees.back()->InitModel();
(*trees.back())[0].set_leaf(1.5f); (*trees.back())[0].set_leaf(1.5f);
gbm::GBTreeModel model(0.5); gbm::GBTreeModel model(0.5);

View File

@ -14,7 +14,6 @@ UNITTEST_DEPS=lib/libxgboost.a $(DMLC_CORE)/libdmlc.a $(RABIT)/lib/$(LIB_RABIT)
COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ)) COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ))
# the order of the below targets matter!
$(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc $(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc
@mkdir -p $(@D) @mkdir -p $(@D)
$(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $< $(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $<