Shared memory atomics while building histogram (#3384)
* Use shared memory atomics for building histograms, whenever possible
parent 2c4359e914
commit 0e78034607
@@ -44,6 +44,14 @@ __device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
   atomicAdd(dst_ptr, static_cast<double>(gpair.GetGrad()));
   atomicAdd(dst_ptr + 1, static_cast<double>(gpair.GetHess()));
 }
+
+// used by shared-memory atomics code
+__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
+                                               const GradientPairPrecise& gpair) {
+  auto dst_ptr = reinterpret_cast<double*>(dest);
+  atomicAdd(dst_ptr, gpair.GetGrad());
+  atomicAdd(dst_ptr + 1, gpair.GetHess());
+}

 // For integer gradients
 __device__ __forceinline__ void AtomicAddGpair(GradientPairInteger* dest,
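Background note: the new GradientPairPrecise overload added above calls atomicAdd on double directly, which is only natively available on GPUs of compute capability 6.0 or newer. For older devices the usual workaround is the CAS loop from the CUDA programming guide; the sketch below shows that fallback (the name atomicAddDouble is illustrative and is not part of this diff).

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
// Standard CAS-based fallback for atomicAdd on double (pre-Pascal GPUs).
__device__ double atomicAddDouble(double* address, double val) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    // Retry until no other thread has modified the value in between.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}
#endif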
@@ -274,6 +274,33 @@ __global__ void compress_bin_ellpack_k
     wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }

+__global__ void sharedMemHistKernel(size_t row_stride,
+                                    const bst_uint* d_ridx,
+                                    common::CompressedIterator<uint32_t> d_gidx,
+                                    int null_gidx_value,
+                                    GradientPairSumT* d_node_hist,
+                                    const GradientPair* d_gpair,
+                                    size_t segment_begin,
+                                    size_t n_elements) {
+  extern __shared__ char smem[];
+  GradientPairSumT* smem_arr = reinterpret_cast<GradientPairSumT*>(smem);  // NOLINT
+  for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
+    smem_arr[i] = GradientPairSumT();
+  }
+  __syncthreads();
+  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
+    int ridx = d_ridx[idx / row_stride + segment_begin];
+    int gidx = d_gidx[ridx * row_stride + idx % row_stride];
+    if (gidx != null_gidx_value) {
+      AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
+    }
+  }
+  __syncthreads();
+  for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
+    AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+  }
+}
+
 // Manage memory for a single GPU
 struct DeviceShard {
   struct Segment {
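The kernel above is the core of the change: each thread block zero-initialises a per-block histogram in dynamic shared memory, accumulates into it with cheap block-local atomics, and only then flushes one atomic add per bin to the global histogram. A minimal standalone sketch of the same pattern, using plain int counters instead of gradient pairs (all names here are illustrative, not taken from the codebase):

#include <cuda_runtime.h>

__global__ void SharedHistSketch(const int* bins, size_t n, int num_bins,
                                 int* global_hist) {
  extern __shared__ int smem_hist[];
  // Zero the per-block histogram cooperatively.
  for (int i = threadIdx.x; i < num_bins; i += blockDim.x) {
    smem_hist[i] = 0;
  }
  __syncthreads();
  // Accumulate into shared memory with block-local atomics.
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += static_cast<size_t>(gridDim.x) * blockDim.x) {
    atomicAdd(&smem_hist[bins[idx]], 1);
  }
  __syncthreads();
  // Flush to global memory: one atomic per bin per block instead of one per element.
  for (int i = threadIdx.x; i < num_bins; i += blockDim.x) {
    atomicAdd(&global_hist[i], smem_hist[i]);
  }
}
// Launched with the histogram as dynamic shared memory, e.g.
// SharedHistSketch<<<grid, 256, num_bins * sizeof(int)>>>(bins, n, num_bins, global_hist);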
@@ -304,7 +331,7 @@ struct DeviceShard {
   std::vector<GradientPair> node_sum_gradients;
   dh::DVec<GradientPair> node_sum_gradients_d;
   common::CompressedIterator<uint32_t> gidx;
-  int row_stride;
+  size_t row_stride;
   bst_uint row_begin_idx;  // The row offset for this shard
   bst_uint row_end_idx;
   bst_uint n_rows;
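Aside on the int to size_t change above (my reading of the intent, not stated in the diff): products such as segment.Size() * row_stride and ridx * row_stride feed index arithmetic over the whole compressed matrix, and a 32-bit stride would let those products overflow on large, wide shards. A tiny illustration with made-up numbers:

#include <cstddef>
#include <cstdio>

int main() {
  std::size_t n_rows = 20000000;   // 20M rows in a shard (illustrative)
  std::size_t row_stride = 128;    // 128 feature slots per row (illustrative)
  // With int operands, n_rows * row_stride = 2,560,000,000 would exceed
  // INT_MAX (2,147,483,647) and overflow; as size_t the count stays exact.
  std::size_t n_elements = n_rows * row_stride;
  std::printf("n_elements = %zu\n", n_elements);
  return 0;
}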
@@ -313,6 +340,7 @@ struct DeviceShard {
   DeviceHistogram hist;
   TrainParam param;
   bool prediction_cache_initialised;
+  bool can_use_smem_atomics;

   int64_t* tmp_pinned;  // Small amount of staging memory

@@ -330,7 +358,8 @@ struct DeviceShard {
         n_bins(n_bins),
         null_gidx_value(n_bins),
         param(param),
-        prediction_cache_initialised(false) {}
+        prediction_cache_initialised(false),
+        can_use_smem_atomics(false) {}

   void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
     // copy cuts to the GPU
@@ -430,6 +459,12 @@ struct DeviceShard {
     node_sum_gradients.resize(max_nodes);
     ridx_segments.resize(max_nodes);

+    // check if we can use shared memory for building histograms
+    // (assuming we need at least 2 CTAs per SM to maintain decent latency hiding)
+    auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value;
+    auto max_smem = dh::MaxSharedMemory(device_idx);
+    can_use_smem_atomics = histogram_size <= max_smem;
+
     // Init histogram
     hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);

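The check above compares the per-node histogram footprint against the device limit returned by dh::MaxSharedMemory. Assuming that helper wraps a device-attribute query (an assumption; the helper's body is not part of this diff), such a query would look roughly like:

#include <cstddef>
#include <cuda_runtime.h>

// Sketch of a shared-memory capacity query; error checking omitted for brevity.
inline size_t MaxSharedMemoryPerBlock(int device_idx) {
  int smem_bytes = 0;
  cudaDeviceGetAttribute(&smem_bytes, cudaDevAttrMaxSharedMemoryPerBlock,
                         device_idx);
  return static_cast<size_t>(smem_bytes);
}

// Mirroring the check above: the per-node histogram fits in shared memory if
// sizeof(GradientPairSumT) * n_bins <= MaxSharedMemoryPerBlock(device_idx).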
@@ -477,7 +512,7 @@ struct DeviceShard {
     hist.Reset();
   }

-  void BuildHist(int nidx) {
+  void BuildHistUsingGlobalMem(int nidx) {
     auto segment = ridx_segments[nidx];
     auto d_node_hist = hist.GetHistPtr(nidx);
     auto d_gidx = gidx;
@@ -496,6 +531,41 @@ struct DeviceShard {
       }
     });
   }

+  void BuildHistUsingSharedMem(int nidx) {
+    auto segment = ridx_segments[nidx];
+    auto segment_begin = segment.begin;
+    auto d_node_hist = hist.GetHistPtr(nidx);
+    auto d_gidx = gidx;
+    auto d_ridx = ridx.Current();
+    auto d_gpair = gpair.Data();
+    auto row_stride = this->row_stride;
+    auto null_gidx_value = this->null_gidx_value;
+    auto n_elements = segment.Size() * row_stride;
+
+    const size_t smem_size = sizeof(GradientPairSumT) * null_gidx_value;
+    const int items_per_thread = 8;
+    const int block_threads = 256;
+    const int grid_size =
+        static_cast<int>(dh::DivRoundUp(n_elements,
+                                        items_per_thread * block_threads));
+    if (grid_size <= 0) {
+      return;
+    }
+    dh::safe_cuda(cudaSetDevice(device_idx));
+    sharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
+        (row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist, d_gpair,
+         segment_begin, n_elements);
+  }
+
+  void BuildHist(int nidx) {
+    if (can_use_smem_atomics) {
+      BuildHistUsingSharedMem(nidx);
+    } else {
+      BuildHistUsingGlobalMem(nidx);
+    }
+  }
+
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
                         int nidx_subtraction) {
     auto d_node_hist_parent = hist.GetHistPtr(nidx_parent);
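For the launch configuration in BuildHistUsingSharedMem above, the grid is sized so that each thread of a 256-thread block covers about items_per_thread elements through the kernel's grid-stride loop, and smem_size bytes of dynamic shared memory hold one GradientPairSumT per bin. A small sketch of the rounding arithmetic, with illustrative numbers and a stand-in for dh::DivRoundUp (assumed to be a round-up integer division helper):

#include <cstddef>
#include <cstdio>

// Round-up integer division, mirroring what dh::DivRoundUp is assumed to do.
inline size_t DivRoundUpSketch(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  size_t n_elements = 1000000;     // rows-in-node * row_stride (illustrative)
  const int items_per_thread = 8;  // elements handled per thread
  const int block_threads = 256;   // threads per block
  size_t grid_size =
      DivRoundUpSketch(n_elements, items_per_thread * block_threads);
  // 1,000,000 / 2,048 rounded up = 489 blocks; the grid-stride loop in the
  // kernel then mops up any remainder.
  std::printf("grid_size = %zu\n", grid_size);
  return 0;
}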