diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index b41cb3632..4286a6c1c 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -515,28 +515,38 @@ __global__ void CompressBinEllpackKernel(
 }
 
 template <typename GradientSumT>
-__global__ void SharedMemHistKernel(ELLPackMatrix matrix, const bst_uint* d_ridx,
+__global__ void SharedMemHistKernel(ELLPackMatrix matrix,
+                                    const bst_uint* d_ridx,
                                     GradientSumT* d_node_hist,
                                     const GradientPair* d_gpair,
-                                    size_t segment_begin, size_t n_elements) {
+                                    size_t segment_begin, size_t n_elements,
+                                    bool use_shared_memory_histograms) {
   extern __shared__ char smem[];
-  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem); // NOLINT
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    smem_arr[i] = GradientSumT();
+  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
+  if (use_shared_memory_histograms) {
+    dh::BlockFill(smem_arr, matrix.BinCount(), GradientSumT());
+    __syncthreads();
   }
-  __syncthreads();
   for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
     int ridx = d_ridx[idx / matrix.row_stride + segment_begin];
-    int gidx = matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
+    int gidx =
+        matrix.gidx_iter[ridx * matrix.row_stride + idx % matrix.row_stride];
     if (gidx != matrix.null_gidx_value) {
-      AtomicAddGpair(smem_arr + gidx, d_gpair[ridx]);
+      // If we are not using shared memory, accumulate the values directly into
+      // global memory
+      GradientSumT* atomic_add_ptr =
+          use_shared_memory_histograms ? smem_arr : d_node_hist;
+      AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
     }
   }
-  __syncthreads();
-  for (auto i :
-       dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
-    AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+
+  if (use_shared_memory_histograms) {
+    // Write shared memory back to global memory
+    __syncthreads();
+    for (auto i :
+         dh::BlockStrideRange(static_cast<size_t>(0), matrix.BinCount())) {
+      AtomicAddGpair(d_node_hist + i, smem_arr[i]);
+    }
   }
 }
 
@@ -620,16 +630,6 @@ __forceinline__ __device__ void CountLeft(int64_t* d_count, int val,
 #endif
 }
 
-template <typename GradientSumT>
-struct DeviceShard;
-
-template <typename GradientSumT>
-struct GPUHistBuilderBase {
- public:
-  virtual void Build(DeviceShard<GradientSumT>* shard, int idx) = 0;
-  virtual ~GPUHistBuilderBase() = default;
-};
-
 // Manage memory for a single GPU
 template <typename GradientSumT>
 struct DeviceShard {
@@ -679,6 +679,7 @@ struct DeviceShard {
 
   TrainParam param;
   bool prediction_cache_initialised;
+  bool use_shared_memory_histograms {false};
 
   dh::CubMemory temp_memory;
   dh::PinnedMemory pinned_memory;
@@ -689,8 +690,6 @@ struct DeviceShard {
   std::vector<ValueConstraint> node_value_constraints;
   common::ColumnSampler column_sampler;
 
-  std::unique_ptr<GPUHistBuilderBase<GradientSumT>> hist_builder;
-
   using ExpandQueue =
       std::priority_queue<ExpandEntry, std::vector<ExpandEntry>,
                           std::function<bool(ExpandEntry, ExpandEntry)>>;
@@ -870,7 +869,27 @@ struct DeviceShard {
 
   void BuildHist(int nidx) {
     hist.AllocateHistogram(nidx);
-    hist_builder->Build(this, nidx);
+    auto segment = ridx_segments[nidx];
+    auto d_node_hist = hist.GetNodeHistogram(nidx);
+    auto d_ridx = ridx.Current();
+    auto d_gpair = gpair.data();
+
+    auto n_elements = segment.Size() * ellpack_matrix.row_stride;
+
+    const size_t smem_size =
+        use_shared_memory_histograms
+            ? sizeof(GradientSumT) * ellpack_matrix.BinCount()
+            : 0;
+    const int items_per_thread = 8;
+    const int block_threads = 256;
+    const int grid_size = static_cast<int>(
+        dh::DivRoundUp(n_elements, items_per_thread * block_threads));
+    if (grid_size <= 0) {
+      return;
+    }
+    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
+        ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair, segment.begin,
+        n_elements, use_shared_memory_histograms);
   }
 
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
@@ -1084,7 +1103,6 @@
       this->AllReduceHist(subtraction_trick_nidx, reducer);
     }
   }
-
   void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
     RegTree& tree = *p_tree;
 
@@ -1209,55 +1227,6 @@
   }
 };
 
-template <typename GradientSumT>
-struct SharedMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    auto segment = shard->ridx_segments[nidx];
-    auto segment_begin = segment.begin;
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx);
-    auto d_ridx = shard->ridx.Current();
-    auto d_gpair = shard->gpair.data();
-
-    auto n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-
-    const size_t smem_size = sizeof(GradientSumT) * shard->ellpack_matrix.BinCount();
-    const int items_per_thread = 8;
-    const int block_threads = 256;
-    const int grid_size =
-        static_cast<int>(dh::DivRoundUp(n_elements,
-                                        items_per_thread * block_threads));
-    if (grid_size <= 0) {
-      return;
-    }
-    SharedMemHistKernel<<<grid_size, block_threads, smem_size>>>(
-        shard->ellpack_matrix, d_ridx, d_node_hist.data(), d_gpair,
-        segment_begin, n_elements);
-  }
-};
-
-template <typename GradientSumT>
-struct GlobalMemHistBuilder : public GPUHistBuilderBase<GradientSumT> {
-  void Build(DeviceShard<GradientSumT>* shard, int nidx) override {
-    Segment segment = shard->ridx_segments[nidx];
-    auto d_node_hist = shard->hist.GetNodeHistogram(nidx).data();
-    bst_uint* d_ridx = shard->ridx.Current();
-    GradientPair* d_gpair = shard->gpair.data();
-
-    size_t const n_elements = segment.Size() * shard->ellpack_matrix.row_stride;
-    auto d_matrix = shard->ellpack_matrix;
-
-    dh::LaunchN(shard->device_id, n_elements, [=] __device__(size_t idx) {
-      int ridx = d_ridx[(idx / d_matrix.row_stride) + segment.begin];
-      // lookup the index (bin) of histogram.
-      int gidx = d_matrix.gidx_iter[ridx * d_matrix.row_stride + idx % d_matrix.row_stride];
-
-      if (gidx != d_matrix.null_gidx_value) {
-        AtomicAddGpair(d_node_hist + gidx, d_gpair[ridx]);
-      }
-    });
-  }
-};
-
 template <typename GradientSumT>
 inline void DeviceShard<GradientSumT>::InitCompressedData(
     const common::HistCutMatrix& hmat, const SparsePage& row_batch, bool is_dense) {
@@ -1317,9 +1286,7 @@ inline void DeviceShard<GradientSumT>::InitCompressedData(
   auto histogram_size = sizeof(GradientSumT) * hmat.row_ptr.back();
   auto max_smem = dh::MaxSharedMemory(device_id);
   if (histogram_size <= max_smem) {
-    hist_builder.reset(new SharedMemHistBuilder<GradientSumT>);
-  } else {
-    hist_builder.reset(new GlobalMemHistBuilder<GradientSumT>);
+    use_shared_memory_histograms = true;
   }
 
   // Init histogram
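Note (editorial, not part of the patch): the hunks above fold the former SharedMemHistBuilder/GlobalMemHistBuilder pair into a single kernel that branches on use_shared_memory_histograms. The standalone CUDA sketch below illustrates the same pattern with made-up names (HistKernelSketch, kBins) and a plain float-per-bin histogram standing in for the templated GradientSumT machinery: when the flag is set, each block zeroes a shared-memory histogram, accumulates into it with shared-memory atomics, and flushes it to global memory once per block; when it is not, the same loop issues its atomics directly against global memory.

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

__global__ void HistKernelSketch(const int* bins, const float* grad, size_t n,
                                 float* d_hist, int n_bins,
                                 bool use_shared_memory) {
  extern __shared__ float s_hist[];
  if (use_shared_memory) {
    // Zero the block-local histogram before any thread accumulates into it.
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) s_hist[i] = 0.0f;
    __syncthreads();
  }
  for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n;
       idx += gridDim.x * blockDim.x) {
    // Route the atomic to shared or global memory through the same code path.
    float* out = use_shared_memory ? s_hist : d_hist;
    atomicAdd(out + bins[idx], grad[idx]);
  }
  if (use_shared_memory) {
    // One flush per block: far fewer global atomics than one per element.
    __syncthreads();
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) {
      atomicAdd(d_hist + i, s_hist[i]);
    }
  }
}

int main() {
  const int kBins = 16;
  const size_t kN = 1 << 20;
  std::vector<int> h_bins(kN);
  std::vector<float> h_grad(kN, 1.0f);
  for (size_t i = 0; i < kN; ++i) h_bins[i] = static_cast<int>(i % kBins);

  int* d_bins; float* d_grad; float* d_hist;
  cudaMalloc(&d_bins, kN * sizeof(int));
  cudaMalloc(&d_grad, kN * sizeof(float));
  cudaMalloc(&d_hist, kBins * sizeof(float));
  cudaMemcpy(d_bins, h_bins.data(), kN * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad.data(), kN * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemset(d_hist, 0, kBins * sizeof(float));

  // Shared-memory variant: the histogram size is passed as dynamic shared memory.
  HistKernelSketch<<<256, 256, kBins * sizeof(float)>>>(d_bins, d_grad, kN,
                                                        d_hist, kBins, true);
  std::vector<float> h_hist(kBins);
  cudaMemcpy(h_hist.data(), d_hist, kBins * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("bin 0: %f (expect %zu)\n", h_hist[0], kN / kBins);
  cudaFree(d_bins); cudaFree(d_grad); cudaFree(d_hist);
  return 0;
}

The payoff of the shared-memory path is that a histogram which fits in shared memory costs one global atomic per bin per block instead of one per matrix element; the code above falls back to direct global atomics only when that is not possible.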
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index e46e8f543..2d120ecc5 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -162,7 +162,7 @@ std::vector<GradientPairPrecise> GetHostHistGpair() {
 }
 
 template <typename GradientSumT>
-void TestBuildHist(GPUHistBuilderBase<GradientSumT>& builder) {
+void TestBuildHist(bool use_shared_memory_histograms) {
   int const kNRows = 16, kNCols = 8;
 
   TrainParam param;
@@ -170,7 +170,6 @@
   param.max_leaves = 0;
 
   DeviceShard<GradientSumT> shard(0, 0, 0, kNRows, param, kNCols);
-
   BuildGidx(&shard, kNRows, kNCols);
 
   xgboost::SimpleLCG gen;
@@ -202,7 +201,8 @@
       thrust::device_pointer_cast(shard.ridx.Current()),
       thrust::device_pointer_cast(shard.ridx.Current() + shard.ridx.Size()));
 
-  builder.Build(&shard, 0);
+  shard.use_shared_memory_histograms = use_shared_memory_histograms;
+  shard.BuildHist(0);
 
   DeviceHistogram<GradientSumT> d_hist = shard.hist;
   auto node_histogram = d_hist.GetNodeHistogram(0);
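Note (editorial, not part of the patch): the test sets shard.use_shared_memory_histograms by hand so that both code paths are exercised regardless of the device it runs on, whereas in production the flag is derived in InitCompressedData from dh::MaxSharedMemory(device_id). A minimal standalone sketch of that capacity check using plain CUDA runtime calls (the bin count and the two-doubles-per-bin entry size are made-up stand-ins) might look like:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  // Sketch of the decision behind use_shared_memory_histograms: the shared
  // memory path is only viable if the whole node histogram fits into the
  // per-block shared memory reported by the device.
  int device = 0;
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  size_t max_smem = prop.sharedMemPerBlock;

  size_t n_bins = 4096;                                 // hypothetical bin count
  size_t histogram_size = n_bins * 2 * sizeof(double);  // gradient + hessian per bin
  bool use_shared_memory_histograms = histogram_size <= max_smem;

  printf("max shared memory per block: %zu bytes\n", max_smem);
  printf("histogram size: %zu bytes -> shared memory path: %s\n",
         histogram_size, use_shared_memory_histograms ? "yes" : "no");
  return 0;
}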
@@ -224,17 +224,13 @@
 }
 
 TEST(GpuHist, BuildHistGlobalMem) {
-  GlobalMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist<GradientPairPrecise>(double_builder);
-  GlobalMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist<GradientPair>(float_builder);
+  TestBuildHist<GradientPairPrecise>(false);
+  TestBuildHist<GradientPair>(false);
 }
 
 TEST(GpuHist, BuildHistSharedMem) {
-  SharedMemHistBuilder<GradientPairPrecise> double_builder;
-  TestBuildHist<GradientPairPrecise>(double_builder);
-  SharedMemHistBuilder<GradientPair> float_builder;
-  TestBuildHist<GradientPair>(float_builder);
+  TestBuildHist<GradientPairPrecise>(true);
+  TestBuildHist<GradientPair>(true);
 }
 
 common::HistCutMatrix GetHostCutMatrix () {
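Note (editorial, not part of the patch): both histogram paths rely on AtomicAddGpair, which is defined elsewhere in the code base and is not shown in this diff. A typical implementation issues one scalar atomic add per gradient-pair component. For double-precision sums on devices older than compute capability 6.0, where atomicAdd(double*, double) does not exist, the scalar add is commonly emulated with an atomicCAS loop; the sketch below (hypothetical names AtomicAddDouble, GradientPairSketch, AtomicAddGpairSketch) shows that pattern end to end:

#include <cstdio>
#include <cuda_runtime.h>

// Emulated double-precision atomic add via compare-and-swap, for devices
// that lack a native atomicAdd(double*, double).
__device__ double AtomicAddDouble(double* address, double val) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}

struct GradientPairSketch {  // hypothetical stand-in for a double-precision pair
  double grad;
  double hess;
};

__device__ void AtomicAddGpairSketch(GradientPairSketch* dst,
                                     const GradientPairSketch& src) {
  // Two independent atomics; the pair as a whole is not updated atomically,
  // which is fine for pure summation since each component is order-independent.
  AtomicAddDouble(&dst->grad, src.grad);
  AtomicAddDouble(&dst->hess, src.hess);
}

__global__ void TestKernel(GradientPairSketch* out) {
  AtomicAddGpairSketch(out, {1.0, 0.5});
}

int main() {
  GradientPairSketch* d_out;
  cudaMalloc(&d_out, sizeof(GradientPairSketch));
  cudaMemset(d_out, 0, sizeof(GradientPairSketch));
  TestKernel<<<4, 64>>>(d_out);  // 256 threads, each adds (1.0, 0.5)
  GradientPairSketch h_out;
  cudaMemcpy(&h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("grad=%f hess=%f (expect 256, 128)\n", h_out.grad, h_out.hess);
  cudaFree(d_out);
  return 0;
}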