Integer gradient summation for GPU histogram algorithm. (#2681)
This commit is contained in:
@@ -5,17 +5,20 @@
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "param.h"
|
||||
#include "../common/compressed_iterator.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/hist_util.h"
|
||||
#include "param.h"
|
||||
#include "updater_gpu_common.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
|
||||
|
||||
typedef bst_gpair_integer gpair_sum_t;
|
||||
static const ncclDataType_t nccl_sum_t = ncclInt64;
|
||||
|
||||
// Helper for explicit template specialisation.
// Int<N> is an empty struct: it holds no data and only encodes the
// compile-time integer N in its type.
|
||||
template <int N>
|
||||
struct Int {};
|
||||
@@ -50,27 +53,29 @@ struct DeviceGMat {
|
||||
};
|
||||
|
||||
// Non-owning view over a flat histogram buffer.
// Stores a raw pointer (d_hist) and the number of bins per node (n_bins);
// the entry for bin `gidx` of node `nidx` lives at index nidx * n_bins + gidx.
// NOTE(review): this listing interleaves the pre-change (bst_gpair) and
// post-change (gpair_sum_t) lines of the commit, so some members and
// signatures appear twice — confirm which version is live before editing.
struct HistHelper {
|
||||
bst_gpair* d_hist;
|
||||
gpair_sum_t* d_hist;
|
||||
int n_bins;
|
||||
__host__ __device__ HistHelper(bst_gpair* ptr, int n_bins)
|
||||
__host__ __device__ HistHelper(gpair_sum_t* ptr, int n_bins)
|
||||
: d_hist(ptr), n_bins(n_bins) {}
|
||||
|
||||
// Atomically accumulates `gpair` into bin `gidx` of node `nidx`.
// No bounds checking is performed on the computed index.
__device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
|
||||
int hist_idx = nidx * n_bins + gidx;
|
||||
atomicAdd(&(d_hist[hist_idx].grad), gpair.grad); // OPTMARK: This and below
|
||||
// line lead to about 3X
|
||||
// slowdown due to memory
|
||||
// dependency and access
|
||||
// pattern issues.
|
||||
atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
|
||||
|
||||
// Integer path: build a gpair_sum_t from the pair's grad/hess, then view
// both the destination entry and the temporary as arrays of words and add
// them word by word with two separate atomicAdd calls (dst_ptr, dst_ptr + 1).
// NOTE(review): assumes gpair_sum_t is laid out as exactly two contiguous
// gpair_sum_t::value_t fields, each the size of unsigned long long —
// confirm against the gpair_sum_t definition.
auto dst_ptr = reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]); // NOLINT
|
||||
gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
|
||||
auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
|
||||
|
||||
atomicAdd(dst_ptr, static_cast<unsigned long long int>(*src_ptr)); // NOLINT
|
||||
atomicAdd(dst_ptr + 1, static_cast<unsigned long long int>(*(src_ptr + 1))); // NOLINT
|
||||
}
|
||||
// Returns a copy of the histogram entry for bin `gidx` of node `nidx`
// (plain, non-atomic read).
__device__ bst_gpair Get(int gidx, int nidx) const {
|
||||
__device__ gpair_sum_t Get(int gidx, int nidx) const {
|
||||
return d_hist[nidx * n_bins + gidx];
|
||||
}
|
||||
};
|
||||
|
||||
struct DeviceHist {
|
||||
int n_bins;
|
||||
dh::dvec<bst_gpair> data;
|
||||
dh::dvec<gpair_sum_t> data;
|
||||
|
||||
void Init(int n_bins_in) {
|
||||
this->n_bins = n_bins_in;
|
||||
@@ -79,12 +84,12 @@ struct DeviceHist {
|
||||
|
||||
void Reset(int device_idx) {
|
||||
cudaSetDevice(device_idx);
|
||||
data.fill(bst_gpair());
|
||||
data.fill(gpair_sum_t());
|
||||
}
|
||||
|
||||
HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); }
|
||||
|
||||
bst_gpair* GetLevelPtr(int depth) {
|
||||
gpair_sum_t* GetLevelPtr(int depth) {
|
||||
return data.data() + n_nodes(depth - 1) * n_bins;
|
||||
}
|
||||
|
||||
@@ -96,18 +101,19 @@ struct SplitCandidate {
|
||||
bool missing_left;
|
||||
float fvalue;
|
||||
int findex;
|
||||
bst_gpair left_sum;
|
||||
bst_gpair right_sum;
|
||||
gpair_sum_t left_sum;
|
||||
gpair_sum_t right_sum;
|
||||
|
||||
__host__ __device__ SplitCandidate()
|
||||
: loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
|
||||
|
||||
__device__ void Update(float loss_chg_in, bool missing_left_in,
|
||||
float fvalue_in, int findex_in, bst_gpair left_sum_in,
|
||||
bst_gpair right_sum_in,
|
||||
float fvalue_in, int findex_in,
|
||||
gpair_sum_t left_sum_in, gpair_sum_t right_sum_in,
|
||||
const GPUTrainingParam& param) {
|
||||
if (loss_chg_in > loss_chg && left_sum_in.hess >= param.min_child_weight &&
|
||||
right_sum_in.hess >= param.min_child_weight) {
|
||||
if (loss_chg_in > loss_chg &&
|
||||
left_sum_in.GetHess() >= param.min_child_weight &&
|
||||
right_sum_in.GetHess() >= param.min_child_weight) {
|
||||
loss_chg = loss_chg_in;
|
||||
missing_left = missing_left_in;
|
||||
fvalue = fvalue_in;
|
||||
@@ -121,11 +127,11 @@ struct SplitCandidate {
|
||||
|
||||
struct GpairCallbackOp {
|
||||
// Running prefix
|
||||
bst_gpair running_total;
|
||||
gpair_sum_t running_total;
|
||||
// Constructor
|
||||
__device__ GpairCallbackOp() : running_total(bst_gpair()) {}
|
||||
__device__ GpairCallbackOp() : running_total(gpair_sum_t()) {}
|
||||
__device__ bst_gpair operator()(bst_gpair block_aggregate) {
|
||||
bst_gpair old_prefix = running_total;
|
||||
gpair_sum_t old_prefix = running_total;
|
||||
running_total += block_aggregate;
|
||||
return old_prefix;
|
||||
}
|
||||
@@ -133,17 +139,16 @@ struct GpairCallbackOp {
|
||||
|
||||
template <int BLOCK_THREADS>
|
||||
__global__ void find_split_kernel(
|
||||
const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
|
||||
const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
|
||||
int n_features, int n_bins, DeviceDenseNode* d_nodes,
|
||||
int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
|
||||
GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
|
||||
bool colsample, int* d_feature_flags) {
|
||||
typedef cub::KeyValuePair<int, float> ArgMaxT;
|
||||
typedef cub::BlockScan<bst_gpair, BLOCK_THREADS,
|
||||
cub::BLOCK_SCAN_WARP_SCANS>
|
||||
typedef cub::BlockScan<gpair_sum_t, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
|
||||
BlockScanT;
|
||||
typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
|
||||
typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
|
||||
typedef cub::BlockReduce<gpair_sum_t, BLOCK_THREADS> SumReduceT;
|
||||
|
||||
union TempStorage {
|
||||
typename BlockScanT::TempStorage scan;
|
||||
@@ -153,8 +158,8 @@ __global__ void find_split_kernel(
|
||||
|
||||
__shared__ cub::Uninitialized<SplitCandidate> uninitialized_split;
|
||||
SplitCandidate& split = uninitialized_split.Alias();
|
||||
__shared__ cub::Uninitialized<bst_gpair> uninitialized_sum;
|
||||
bst_gpair& shared_sum = uninitialized_sum.Alias();
|
||||
__shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
|
||||
gpair_sum_t& shared_sum = uninitialized_sum.Alias();
|
||||
__shared__ ArgMaxT block_max;
|
||||
__shared__ TempStorage temp_storage;
|
||||
|
||||
@@ -175,14 +180,13 @@ __global__ void find_split_kernel(
|
||||
int begin = d_feature_segments[level_node_idx * n_features + fidx];
|
||||
int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
|
||||
|
||||
bst_gpair feature_sum = bst_gpair();
|
||||
gpair_sum_t feature_sum = gpair_sum_t();
|
||||
for (int reduce_begin = begin; reduce_begin < end;
|
||||
reduce_begin += BLOCK_THREADS) {
|
||||
bool thread_active = reduce_begin + threadIdx.x < end;
|
||||
// Scan histogram
|
||||
bst_gpair bin = thread_active
|
||||
? d_level_hist[reduce_begin + threadIdx.x]
|
||||
: bst_gpair();
|
||||
gpair_sum_t bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
|
||||
: gpair_sum_t();
|
||||
|
||||
feature_sum +=
|
||||
SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
|
||||
@@ -197,18 +201,17 @@ __global__ void find_split_kernel(
|
||||
for (int scan_begin = begin; scan_begin < end;
|
||||
scan_begin += BLOCK_THREADS) {
|
||||
bool thread_active = scan_begin + threadIdx.x < end;
|
||||
bst_gpair bin = thread_active
|
||||
? d_level_hist[scan_begin + threadIdx.x]
|
||||
: bst_gpair();
|
||||
gpair_sum_t bin = thread_active ? d_level_hist[scan_begin + threadIdx.x]
|
||||
: gpair_sum_t();
|
||||
|
||||
BlockScanT(temp_storage.scan)
|
||||
.ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
|
||||
|
||||
// Calculate gain
|
||||
bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
|
||||
gpair_sum_t parent_sum = gpair_sum_t(d_nodes[node_idx].sum_gradients);
|
||||
float parent_gain = d_nodes[node_idx].root_gain;
|
||||
|
||||
bst_gpair missing = parent_sum - shared_sum;
|
||||
gpair_sum_t missing = parent_sum - shared_sum;
|
||||
|
||||
bool missing_left;
|
||||
float gain = thread_active
|
||||
@@ -239,8 +242,8 @@ __global__ void find_split_kernel(
|
||||
fvalue = d_gidx_fvalue_map[gidx - 1];
|
||||
}
|
||||
|
||||
bst_gpair left = missing_left ? bin + missing : bin;
|
||||
bst_gpair right = parent_sum - left;
|
||||
gpair_sum_t left = missing_left ? bin + missing : bin;
|
||||
gpair_sum_t right = parent_sum - left;
|
||||
|
||||
split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
|
||||
}
|
||||
@@ -263,7 +266,7 @@ __global__ void find_split_kernel(
|
||||
DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param);
|
||||
|
||||
// Record smallest node
|
||||
if (split.left_sum.hess <= split.right_sum.hess) {
|
||||
if (split.left_sum.GetHess() <= split.right_sum.GetHess()) {
|
||||
left_child_smallest = true;
|
||||
} else {
|
||||
left_child_smallest = false;
|
||||
@@ -595,6 +598,7 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
initialised = true;
|
||||
}
|
||||
|
||||
void BuildHist(int depth) {
|
||||
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
|
||||
int device_idx = dList[d_idx];
|
||||
@@ -650,9 +654,9 @@ class GPUHistMaker : public TreeUpdater {
|
||||
dh::safe_nccl(ncclAllReduce(
|
||||
reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
|
||||
reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
|
||||
hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) /
|
||||
sizeof(float),
|
||||
ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
|
||||
hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) /
|
||||
sizeof(gpair_sum_t::value_t),
|
||||
nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx])));
|
||||
}
|
||||
|
||||
for (int d_idx = 0; d_idx < n_devices; d_idx++) {
|
||||
@@ -683,11 +687,12 @@ class GPUHistMaker : public TreeUpdater {
|
||||
}
|
||||
|
||||
int gidx = idx % hist_builder.n_bins;
|
||||
bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
|
||||
gpair_sum_t parent = hist_builder.Get(gidx, parent_nidx(nidx));
|
||||
int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
|
||||
bst_gpair other = hist_builder.Get(gidx, other_nidx);
|
||||
gpair_sum_t other = hist_builder.Get(gidx, other_nidx);
|
||||
gpair_sum_t sub = parent - other;
|
||||
hist_builder.Add(
|
||||
parent - other, gidx,
|
||||
bst_gpair(sub.GetGrad(), sub.GetHess()), gidx,
|
||||
nidx); // OPTMARK: This is slow, could use shared
|
||||
// memory or cache results instead of writing to
|
||||
// global memory every time in atomic way.
|
||||
@@ -737,11 +742,11 @@ class GPUHistMaker : public TreeUpdater {
|
||||
|
||||
int nodes_offset_device = 0;
|
||||
find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
|
||||
(const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
|
||||
feature_segments[d_idx].data(), depth, (info->num_col),
|
||||
(hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_offset_device,
|
||||
fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(),
|
||||
GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample,
|
||||
hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(),
|
||||
depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(),
|
||||
nodes_offset_device, fidx_min_map[d_idx].data(),
|
||||
gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
|
||||
left_child_smallest[d_idx].data(), colsample,
|
||||
feature_flags[d_idx].data());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user