diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 3024b589f..0a77812e3 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -100,7 +100,7 @@ inline size_t TotalMemory(int device_idx) {
 }
 
 /**
- * \fn  inline int max_shared_memory(int device_idx)
+ * \fn  inline int MaxSharedMemory(int device_idx)
  *
  * \brief Maximum shared memory per block on this device.
  *
@@ -113,6 +113,23 @@ inline size_t MaxSharedMemory(int device_idx) {
   return prop.sharedMemPerBlock;
 }
 
+/**
+ * \fn  inline int MaxSharedMemoryOptin(int device_idx)
+ *
+ * \brief Maximum dynamic shared memory per thread block on this device
+    that can be opted into when using cudaFuncSetAttribute().
+ *
+ * \param device_idx  Zero-based index of the device.
+ */
+
+inline size_t MaxSharedMemoryOptin(int device_idx) {
+  int max_shared_memory = 0;
+  dh::safe_cuda(cudaDeviceGetAttribute
+                (&max_shared_memory, cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                 device_idx));
+  return size_t(max_shared_memory);
+}
+
 inline void CheckComputeCapability() {
   for (int d_idx = 0; d_idx < xgboost::common::AllVisibleGPUs(); ++d_idx) {
     cudaDeviceProp prop;
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 0035fb214..edc3046d1 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -150,21 +150,37 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
                             common::Span<GradientSumT> histogram,
-                            GradientSumT rounding, bool shared) {
-  const size_t smem_size =
-      shared
-      ? sizeof(GradientSumT) * matrix.NumBins()
-      : 0;
-  auto n_elements = d_ridx.size() * matrix.row_stride;
+                            GradientSumT rounding) {
+  // decide whether to use shared memory
+  int device = 0;
+  dh::safe_cuda(cudaGetDevice(&device));
+  int max_shared_memory = dh::MaxSharedMemoryOptin(device);
+  size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
+  bool shared = smem_size <= max_shared_memory;
+  smem_size = shared ? smem_size : 0;
 
-  uint32_t items_per_thread = 8;
-  uint32_t block_threads = 256;
-  auto grid_size = static_cast<uint32_t>(
-      common::DivRoundUp(n_elements, items_per_thread * block_threads));
+  // opt into maximum shared memory for the kernel if necessary
+  auto kernel = SharedMemHistKernel<GradientSumT>;
+  if (shared) {
+    dh::safe_cuda(cudaFuncSetAttribute
+                  (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                   max_shared_memory));
+  }
+
+  // determine the launch configuration
+  unsigned block_threads = shared ? 1024 : 256;
+  int n_mps = 0;
+  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
+  int n_blocks_per_mp = 0;
+  dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
+                (&n_blocks_per_mp, kernel, block_threads, smem_size));
+  unsigned grid_size = n_blocks_per_mp * n_mps;
+
+  auto n_elements = d_ridx.size() * matrix.row_stride;
   dh::LaunchKernel {grid_size, block_threads, smem_size} (
-      SharedMemHistKernel<GradientSumT>,
-      matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
+      kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
       rounding, shared);
+  dh::safe_cuda(cudaGetLastError());
 }
 
 template void BuildGradientHistogram<GradientPair>(
@@ -172,13 +188,14 @@ template void BuildGradientHistogram<GradientPair>(
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPair> histogram,
-    GradientPair rounding, bool shared);
+    GradientPair rounding);
 
 template void BuildGradientHistogram<GradientPairPrecise>(
     EllpackDeviceAccessor const& matrix,
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPairPrecise> histogram,
-    GradientPairPrecise rounding, bool shared);
+    GradientPairPrecise rounding);
+
 }  // namespace tree
 }  // namespace xgboost
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index a7c923b61..d8673a8a5 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -22,7 +22,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> ridx,
                             common::Span<GradientSumT> histogram,
-                            GradientSumT rounding, bool shared);
+                            GradientSumT rounding);
 
 }  // namespace tree
 }  // namespace xgboost
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 986e43fe5..67848b1e2 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -422,7 +422,6 @@ struct GPUHistMakerDevice {
 
   TrainParam param;
   bool deterministic_histogram;
-  bool use_shared_memory_histograms {false};
 
   GradientSumT histogram_rounding;
 
@@ -596,7 +595,7 @@ struct GPUHistMakerDevice {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
     BuildGradientHistogram(page->GetDeviceAccessor(device_id), gpair, d_ridx, d_node_hist,
-                           histogram_rounding, use_shared_memory_histograms);
+                           histogram_rounding);
   }
 
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
@@ -910,15 +909,6 @@ inline void GPUHistMakerDevice<GradientSumT>::InitHistogram() {
   host_node_sum_gradients.resize(param.MaxNodes());
   node_sum_gradients.resize(param.MaxNodes());
 
-  // check if we can use shared memory for building histograms
-  // (assuming atleast we need 2 CTAs per SM to maintain decent latency
-  // hiding)
-  auto histogram_size = sizeof(GradientSumT) * page->Cuts().TotalBins();
-  auto max_smem = dh::MaxSharedMemory(device_id);
-  if (histogram_size <= max_smem) {
-    use_shared_memory_histograms = true;
-  }
-
   // Init histogram
   hist.Init(device_id, page->Cuts().TotalBins());
 }
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index ada6fa8a0..23fa5ebe8 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -27,7 +27,7 @@ void TestDeterminsticHistogram() {
 
     auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
     BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
-                           d_histogram, rounding, true);
+                           d_histogram, rounding);
 
     for (size_t i = 0; i < kRounds; ++i) {
       dh::device_vector<Gradient> new_histogram(kBins * kCols);
@@ -35,7 +35,7 @@ void TestDeterminsticHistogram() {
 
       auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
      BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
-                             d_histogram, rounding, true);
+                             d_histogram, rounding);
 
       for (size_t j = 0; j < new_histogram.size(); ++j) {
         ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(),
@@ -50,7 +50,7 @@ void TestDeterminsticHistogram() {
       gpair.SetDevice(0);
       dh::device_vector<Gradient> baseline(kBins * kCols);
       BuildGradientHistogram(page->GetDeviceAccessor(0), gpair.DeviceSpan(), ridx,
-                             dh::ToSpan(baseline), rounding, true);
+                             dh::ToSpan(baseline), rounding);
       for (size_t i = 0; i < baseline.size(); ++i) {
         EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(), ((Gradient)histogram[i]).GetGrad(),
                     ((Gradient)baseline[i]).GetGrad() * 1e-3);
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 32099ca77..2a42b528c 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -101,7 +101,6 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   maker.hist.AllocateHistogram(0);
   maker.gpair = gpair.DeviceSpan();
 
-  maker.use_shared_memory_histograms = use_shared_memory_histograms;
   maker.BuildHist(0);
   DeviceHistogram<GradientSumT> d_hist = maker.hist;
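
For reference, here is a minimal standalone sketch of the launch strategy the new BuildGradientHistogram uses: compare the histogram footprint against the opt-in shared memory limit, raise the kernel's dynamic shared memory cap with cudaFuncSetAttribute when it fits, and size the grid from cudaOccupancyMaxActiveBlocksPerMultiprocessor so a grid-stride loop covers all elements. The kernel (HistSketchKernel), its bin layout, and the launcher below are hypothetical simplifications rather than XGBoost code, and error checking is omitted.

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical histogram kernel: counts bin indices, accumulating in dynamic
// shared memory when available, otherwise directly in the global histogram.
__global__ void HistSketchKernel(const int* bins, size_t n, unsigned int* out,
                                 int n_bins, bool use_shared) {
  extern __shared__ unsigned int smem[];
  unsigned int* hist = use_shared ? smem : out;
  if (use_shared) {
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) { smem[i] = 0; }
    __syncthreads();
  }
  // Grid-stride loop: the grid is sized for occupancy, not for n.
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    atomicAdd(&hist[bins[i]], 1u);
  }
  if (use_shared) {
    __syncthreads();
    // Flush the per-block shared memory histogram to global memory.
    for (int i = threadIdx.x; i < n_bins; i += blockDim.x) {
      atomicAdd(&out[i], smem[i]);
    }
  }
}

void LaunchHistSketch(const int* bins, size_t n, unsigned int* out, int n_bins) {
  int device = 0;
  cudaGetDevice(&device);

  // The opt-in limit is typically larger than the default 48 KiB per block.
  int max_optin = 0;
  cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  size_t smem_size = sizeof(unsigned int) * n_bins;
  bool shared = smem_size <= static_cast<size_t>(max_optin);
  if (shared) {
    cudaFuncSetAttribute(HistSketchKernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         max_optin);
  } else {
    smem_size = 0;
  }

  // Launch exactly as many blocks as can be resident at once.
  unsigned block_threads = shared ? 1024 : 256;
  int n_mps = 0, n_blocks_per_mp = 0;
  cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, HistSketchKernel,
                                                block_threads, smem_size);
  unsigned grid_size = n_blocks_per_mp * n_mps;

  HistSketchKernel<<<grid_size, block_threads, smem_size>>>(bins, n, out, n_bins, shared);
  cudaDeviceSynchronize();
}

Launching one wave of resident blocks and iterating with a grid-stride loop keeps each block's shared memory histogram live for the whole kernel, so it is zeroed and flushed to global memory only once per block.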