Retry switching to per-thread default stream (#9416)

This commit is contained in:
Rong Ou
2023-07-25 16:09:12 -07:00
committed by GitHub
parent 54579da4d7
commit 7579905e18
9 changed files with 37 additions and 36 deletions

View File

@@ -480,7 +480,7 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
cub::CachingDeviceAllocator& GetGlobalCachingAllocator() {
// Configure allocator with maximum cached bin size of ~1GB and no limit on
// maximum cached bytes
static cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
thread_local cub::CachingDeviceAllocator *allocator = new cub::CachingDeviceAllocator(2, 9, 29);
return *allocator;
}
pointer allocate(size_t n) { // NOLINT
@@ -1176,7 +1176,13 @@ inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT
dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream}));
}
inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; }
inline CUDAStreamView DefaultStream() {
#ifdef CUDA_API_PER_THREAD_DEFAULT_STREAM
return CUDAStreamView{cudaStreamPerThread};
#else
return CUDAStreamView{cudaStreamLegacy};
#endif
}
class CUDAStream {
cudaStream_t stream_;

View File

@@ -134,12 +134,12 @@ void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan<BatchIt> batch_iter
CHECK(!force_use_u64);
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::uint32_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
kernel, batch_iter, is_valid, out_column_size);
} else {
auto kernel = GetColumnSizeSharedMemKernel<kBlockThreads, std::size_t, BatchIt>;
auto grid_size = EstimateGridSize<kBlockThreads>(device, kernel, required_shared_memory);
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}(
dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory}(
kernel, batch_iter, is_valid, out_column_size);
}
} else {