From 0ed5d3c849bed2198ca0d5582064fe02f63b59b7 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Thu, 9 Mar 2023 21:28:37 +0100 Subject: [PATCH] finished histogram.cu --- src/common/bitfield.h | 6 ++++- src/common/compressed_iterator.h | 4 +++- src/data/ellpack_page.cuh | 6 +++++ src/tree/gpu_hist/histogram.cu | 34 +++++++++++++++++++++++++++ src/tree/gpu_hist/histogram.hip | 4 ++++ src/tree/gpu_hist/row_partitioner.cuh | 17 ++++++++++++++ 6 files changed, 69 insertions(+), 2 deletions(-) diff --git a/src/common/bitfield.h b/src/common/bitfield.h index 0c726f70f..3aef1cb36 100644 --- a/src/common/bitfield.h +++ b/src/common/bitfield.h @@ -13,10 +13,14 @@ #include #include -#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) +#if defined(__CUDACC__) #include #include #include "device_helpers.cuh" +#elif defined(__HIP_PLATFORM_AMD__) +#include +#include +#include "device_helpers.hip.h" #endif // defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) #include "xgboost/span.h" diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index 9e7b7b22a..eee08c488 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -11,8 +11,10 @@ #include "common.h" -#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) +#if defined(__CUDACC__) #include "device_helpers.cuh" +#elif defined(__HIP_PLATFORM_AMD__) +#include "device_helpers.hip.h" #endif // __CUDACC__ || __HIP_PLATFORM_AMD__ namespace xgboost { diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index faf44b3b6..807ee0ea6 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -8,7 +8,13 @@ #include #include "../common/compressed_iterator.h" + +#if defined(XGBOOST_USE_CUDA) #include "../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../common/device_helpers.hip.h" +#endif + #include "../common/hist_util.h" #include "../common/categorical.h" #include diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 489c8d6f7..985b52c8f 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -9,7 +9,13 @@ #include #include "../../common/deterministic.cuh" + +#if defined(XGBOOST_USE_CUDA) #include "../../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../../common/device_helpers.hip.h" +#endif + #include "../../data/ellpack_page.cuh" #include "histogram.cuh" #include "row_partitioner.cuh" @@ -59,8 +65,14 @@ GradientQuantiser::GradientQuantiser(common::Span gpair) { thrust::device_ptr gpair_beg{gpair.data()}; auto beg = thrust::make_transform_iterator(gpair_beg, Clip()); +#if defined(XGBOOST_USE_CUDA) Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, beg + gpair.size(), Pair{}, thrust::plus{}); +#elif defined(XGBOOST_USE_HIP) + Pair p = + dh::Reduce(thrust::hip::par(alloc), beg, beg + gpair.size(), Pair{}, thrust::plus{}); +#endif + // Treat pair as array of 4 primitive types to allreduce using ReduceT = typename decltype(p.first)::ValueT; static_assert(sizeof(Pair) == sizeof(ReduceT) * 4, "Expected to reduce four elements."); @@ -258,7 +270,13 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& bool force_global_memory) { // decide whether to use shared memory int device = 0; + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaGetDevice(&device)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipGetDevice(&device)); +#endif + // opt into maximum shared memory for the kernel if necessary size_t max_shared_memory = dh::MaxSharedMemoryOptin(device); @@ -273,16 +291,28 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& auto runit = [&, kMinItemsPerBlock = kItemsPerTile](auto kernel) { if (shared) { +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_memory)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipFuncSetAttribute((const void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, + max_shared_memory)); +#endif } // determine the launch configuration int num_groups = feature_groups.NumGroups(); int n_mps = 0; + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); int n_blocks_per_mp = 0; dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel, +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipDeviceGetAttribute(&n_mps, hipDeviceAttributeMultiprocessorCount, device)); + int n_blocks_per_mp = 0; + dh::safe_cuda(hipOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel, +#endif kBlockThreads, smem_size)); // This gives the number of blocks to keep the device occupied // Use this as the maximum number of blocks @@ -311,7 +341,11 @@ void BuildGradientHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& runit(SharedMemHistKernel); } +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaGetLastError()); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipGetLastError()); +#endif } } // namespace tree diff --git a/src/tree/gpu_hist/histogram.hip b/src/tree/gpu_hist/histogram.hip index e69de29bb..d505b3fd3 100644 --- a/src/tree/gpu_hist/histogram.hip +++ b/src/tree/gpu_hist/histogram.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "histogram.cu" +#endif diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 8a9fc53d8..acacc40e8 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -7,7 +7,12 @@ #include #include +#if defined(XGBOOST_USE_CUDA) #include "../../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../../common/device_helpers.hip.h" +#endif + #include "xgboost/base.h" #include "xgboost/context.h" #include "xgboost/task.h" @@ -140,13 +145,25 @@ void SortPositionBatch(common::Span> d_batch_info, }); size_t temp_bytes = 0; if (tmp->empty()) { +#if defined(XGBOOST_USE_CUDA) cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp(), total_rows, stream); +#elif defined(XGBOOST_USE_HIP) + rocprim::inclusive_scan(nullptr, temp_bytes, input_iterator, discard_write_iterator, + total_rows, IndexFlagOp(), stream); +#endif + tmp->resize(temp_bytes); } temp_bytes = tmp->size(); + +#if defined(XGBOOST_USE_CUDA) cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator, discard_write_iterator, IndexFlagOp(), total_rows, stream); +#elif defined(XGBOOST_USE_HIP) + rocprim::inclusive_scan(tmp->data().get(), temp_bytes, input_iterator, discard_write_iterator, + total_rows, IndexFlagOp(), stream); +#endif constexpr int kBlockSize = 256;