For histograms, opting into maximum shared memory available per block. (#5491)
This commit is contained in:
@@ -150,21 +150,37 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
                             common::Span<GradientSumT> histogram,
-                            GradientSumT rounding, bool shared) {
-  const size_t smem_size =
-      shared
-      ? sizeof(GradientSumT) * matrix.NumBins()
-      : 0;
-  auto n_elements = d_ridx.size() * matrix.row_stride;
+                            GradientSumT rounding) {
+  // decide whether to use shared memory
+  int device = 0;
+  dh::safe_cuda(cudaGetDevice(&device));
+  int max_shared_memory = dh::MaxSharedMemoryOptin(device);
+  size_t smem_size = sizeof(GradientSumT) * matrix.NumBins();
+  bool shared = smem_size <= max_shared_memory;
+  smem_size = shared ? smem_size : 0;
 
-  uint32_t items_per_thread = 8;
-  uint32_t block_threads = 256;
-  auto grid_size = static_cast<uint32_t>(
-      common::DivRoundUp(n_elements, items_per_thread * block_threads));
+  // opt into maximum shared memory for the kernel if necessary
+  auto kernel = SharedMemHistKernel<GradientSumT>;
+  if (shared) {
+    dh::safe_cuda(cudaFuncSetAttribute
+                  (kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                   max_shared_memory));
+  }
+
+  // determine the launch configuration
+  unsigned block_threads = shared ? 1024 : 256;
+  int n_mps = 0;
+  dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
+  int n_blocks_per_mp = 0;
+  dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor
+                (&n_blocks_per_mp, kernel, block_threads, smem_size));
+  unsigned grid_size = n_blocks_per_mp * n_mps;
+
+  auto n_elements = d_ridx.size() * matrix.row_stride;
   dh::LaunchKernel {grid_size, block_threads, smem_size} (
-      SharedMemHistKernel<GradientSumT>,
-      matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
+      kernel, matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
       rounding, shared);
   dh::safe_cuda(cudaGetLastError());
 }
 
 template void BuildGradientHistogram<GradientPair>(
@@ -172,13 +188,14 @@ template void BuildGradientHistogram<GradientPair>(
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPair> histogram,
-    GradientPair rounding, bool shared);
+    GradientPair rounding);
 
 template void BuildGradientHistogram<GradientPairPrecise>(
     EllpackDeviceAccessor const& matrix,
     common::Span<GradientPair const> gpair,
     common::Span<const uint32_t> ridx,
     common::Span<GradientPairPrecise> histogram,
-    GradientPairPrecise rounding, bool shared);
+    GradientPairPrecise rounding);
 
 }  // namespace tree
 }  // namespace xgboost
@@ -22,7 +22,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> ridx,
                             common::Span<GradientSumT> histogram,
-                            GradientSumT rounding, bool shared);
+                            GradientSumT rounding);
 }  // namespace tree
 }  // namespace xgboost
Reference in New Issue
Block a user