diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index 9cf490ac7..ab83c0bb6 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -44,6 +44,14 @@ __device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
   atomicAdd(dst_ptr, static_cast<double>(gpair.GetGrad()));
   atomicAdd(dst_ptr + 1, static_cast<double>(gpair.GetHess()));
 }
+// used by the shared-memory atomics code path
+__device__ __forceinline__ void AtomicAddGpair(GradientPairPrecise* dest,
+                                               const GradientPairPrecise& gpair) {
+  auto dst_ptr = reinterpret_cast<double*>(dest);
+
+  atomicAdd(dst_ptr, gpair.GetGrad());
+  atomicAdd(dst_ptr + 1, gpair.GetHess());
+}
 
 // For integer gradients
 __device__ __forceinline__ void AtomicAddGpair(GradientPairInteger* dest,
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 0c75e4e14..f1d5249cc 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -274,6 +274,33 @@ __global__ void compress_bin_ellpack_k
   wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
 }
 
+__global__ void sharedMemHistKernel(size_t row_stride,
+                                    const bst_uint* d_ridx,
+                                    common::CompressedIterator<uint32_t> d_gidx,
+                                    int null_gidx_value,
+                                    GradientPairSumT* d_node_hist,
+                                    const GradientPair* d_gpair,
+                                    size_t segment_begin,
+                                    size_t n_elements) {
+  extern __shared__ char smem[];
+  GradientPairSumT* smem_arr = reinterpret_cast<GradientPairSumT*>(smem);  // NOLINT
+  for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
+    smem_arr[i] = GradientPairSumT();
+  }
+  __syncthreads();
+  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
+    int ridx = d_ridx[idx / row_stride + segment_begin];
+    int gidx = d_gidx[ridx * row_stride + idx % row_stride];
+    if (gidx != null_gidx_value) {
+      AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
+    }
+  }
+  __syncthreads();
+  for (auto i : dh::BlockStrideRange(0, null_gidx_value)) {
+    AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+  }
+}
+
 // Manage memory for a single GPU
 struct DeviceShard {
   struct Segment {
@@ -304,7 +331,7 @@ struct DeviceShard {
   std::vector<GradientPair> node_sum_gradients;
   dh::DVec<GradientPair> node_sum_gradients_d;
   common::CompressedIterator<uint32_t> gidx;
-  int row_stride;
+  size_t row_stride;
   bst_uint row_begin_idx;  // The row offset for this shard
   bst_uint row_end_idx;
   bst_uint n_rows;
@@ -313,6 +340,7 @@ struct DeviceShard {
   DeviceHistogram hist;
   TrainParam param;
  bool prediction_cache_initialised;
+  bool can_use_smem_atomics;
 
   int64_t* tmp_pinned;  // Small amount of staging memory
 
@@ -330,7 +358,8 @@ struct DeviceShard {
         n_bins(n_bins),
         null_gidx_value(n_bins),
         param(param),
-        prediction_cache_initialised(false) {}
+        prediction_cache_initialised(false),
+        can_use_smem_atomics(false) {}
 
   void Init(const common::HistCutMatrix& hmat, const SparsePage& row_batch) {
     // copy cuts to the GPU
@@ -430,6 +459,12 @@ struct DeviceShard {
     node_sum_gradients.resize(max_nodes);
     ridx_segments.resize(max_nodes);
 
+    // check if we can use shared memory for building histograms
+    // (assuming we need at least 2 CTAs per SM to maintain decent latency hiding)
+    auto histogram_size = sizeof(GradientPairSumT) * null_gidx_value;
+    auto max_smem = dh::MaxSharedMemory(device_idx);
+    can_use_smem_atomics = histogram_size <= max_smem;
+
     // Init histogram
     hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);
 
@@ -477,7 +512,7 @@ struct DeviceShard {
     hist.Reset();
   }
 
-  void BuildHist(int nidx) {
+  void BuildHistUsingGlobalMem(int nidx) {
     auto segment = ridx_segments[nidx];
     auto d_node_hist = hist.GetHistPtr(nidx);
     auto d_gidx = gidx;
@@ -496,6 +531,41 @@ struct DeviceShard {
       }
     });
   }
+
+  void BuildHistUsingSharedMem(int nidx) {
+    auto segment = ridx_segments[nidx];
+    auto segment_begin = segment.begin;
+    auto d_node_hist = hist.GetHistPtr(nidx);
+    auto d_gidx = gidx;
+    auto d_ridx = ridx.Current();
+    auto d_gpair = gpair.Data();
+    auto row_stride = this->row_stride;
+    auto null_gidx_value = this->null_gidx_value;
+    auto n_elements = segment.Size() * row_stride;
+
+    const size_t smem_size = sizeof(GradientPairSumT) * null_gidx_value;
+    const int items_per_thread = 8;
+    const int block_threads = 256;
+    const int grid_size =
+        static_cast<int>(dh::DivRoundUp(n_elements,
+                                        items_per_thread * block_threads));
+    if (grid_size <= 0) {
+      return;
+    }
+    dh::safe_cuda(cudaSetDevice(device_idx));
+    sharedMemHistKernel<<<grid_size, block_threads, smem_size>>>
+        (row_stride, d_ridx, d_gidx, null_gidx_value, d_node_hist, d_gpair,
+         segment_begin, n_elements);
+  }
+
+  void BuildHist(int nidx) {
+    if (can_use_smem_atomics) {
+      BuildHistUsingSharedMem(nidx);
+    } else {
+      BuildHistUsingGlobalMem(nidx);
+    }
+  }
+
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
                         int nidx_subtraction) {
     auto d_node_hist_parent = hist.GetHistPtr(nidx_parent);