From 4eb371b3f0f8866ed04663d36054c1f55e5218f9 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+amdsc21@users.noreply.github.com> Date: Mon, 30 Oct 2023 17:10:06 -0700 Subject: [PATCH] unify cuda to hip --- src/common/algorithm.cuh | 5 +- src/common/common.h | 19 +------ src/common/cuda_to_hip.h | 95 ++++++++++++++++++--------------- src/common/device_helpers.hip.h | 2 +- src/data/array_interface.cu | 21 +------- 5 files changed, 60 insertions(+), 82 deletions(-) diff --git a/src/common/algorithm.cuh b/src/common/algorithm.cuh index 2d80c06d8..b5ffac2c1 100644 --- a/src/common/algorithm.cuh +++ b/src/common/algorithm.cuh @@ -32,6 +32,7 @@ namespace xgboost { namespace common { namespace detail { + // Wrapper around cub sort to define is_decending template @@ -56,13 +57,13 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st end_bit, false, ctx->Stream(), debug_synchronous))); #elif defined(XGBOOST_USE_HIP) if (IS_DESCENDING) { - rocprim::segmented_radix_sort_pairs_desc(d_temp_storage, + rocprim::segmented_radix_sort_pairs_desc(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items, num_segments, d_begin_offsets, d_end_offsets, begin_bit, end_bit, ctx->Stream(), debug_synchronous); } else { - rocprim::segmented_radix_sort_pairs(d_temp_storage, + rocprim::segmented_radix_sort_pairs(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items, num_segments, d_begin_offsets, d_end_offsets, begin_bit, end_bit, ctx->Stream(), debug_synchronous); diff --git a/src/common/common.h b/src/common/common.h index 31fffb955..7cea0591f 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -26,6 +26,7 @@ #define WITH_CUDA() true #elif defined(__HIP_PLATFORM_AMD__) +#include "cuda_to_hip.h" #include #include @@ -38,7 +39,7 @@ #endif // defined(__CUDACC__) namespace dh { -#if defined(__CUDACC__) +#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__) /* * Error handling functions */ @@ -53,22 +54,6 @@ inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file, int line } return code; } - -#elif defined(__HIP_PLATFORM_AMD__) -/* - * Error handling functions - */ -#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__) - -inline hipError_t ThrowOnCudaError(hipError_t code, const char *file, int line) -{ - if (code != hipSuccess) { - LOG(FATAL) << thrust::system_error(code, thrust::hip_category(), - std::string{file} + ": " + // NOLINT - std::to_string(line)).what(); - } - return code; -} #endif } // namespace dh diff --git a/src/common/cuda_to_hip.h b/src/common/cuda_to_hip.h index 2f9a5b4d1..202b31b1d 100644 --- a/src/common/cuda_to_hip.h +++ b/src/common/cuda_to_hip.h @@ -5,64 +5,75 @@ #if defined(XGBOOST_USE_HIP) -#define cudaSuccess hipSuccess -#define cudaError hipError_t -#define cudaGetLastError hipGetLastError -#define cudaPeekAtLastError hipPeekAtLastError +#define cudaSuccess hipSuccess +#define cudaError hipError_t +#define cudaError_t hipError_t +#define cudaGetLastError hipGetLastError +#define cudaPeekAtLastError hipPeekAtLastError +#define cudaErrorInvalidValue hipErrorInvalidValue -#define cudaStream_t hipStream_t -#define cudaStreamCreate hipStreamCreate -#define cudaStreamCreateWithFlags hipStreamCreateWithFlags -#define cudaStreamDestroy hipStreamDestroy -#define cudaStreamWaitEvent hipStreamWaitEvent -#define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamPerThread hipStreamPerThread -#define cudaStreamLegacy hipStreamLegacy +#define cudaStream_t hipStream_t +#define cudaStreamCreate hipStreamCreate +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamPerThread hipStreamPerThread -#define cudaEvent_t hipEvent_t -#define cudaEventCreate hipEventCreate -#define cudaEventCreateWithFlags hipEventCreateWithFlags -#define cudaEventDestroy hipEventDestroy +/* not compatible */ +#define cudaStreamLegacy hipStreamDefault +#define hipStreamLegacy hipStreamDefault -#define cudaGetDevice hipGetDevice -#define cudaSetDevice hipSetDevice -#define cudaGetDeviceCount hipGetDeviceCount -#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventCreate hipEventCreate +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDestroy hipEventDestroy -#define cudaGetDeviceProperties hipGetDeviceProperties -#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaGetDevice hipGetDevice +#define cudaSetDevice hipSetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaDeviceSynchronize hipDeviceSynchronize -#define cudaMallocHost hipMallocHost -#define cudaFreeHost hipFreeHost -#define cudaMalloc hipMalloc -#define cudaFree hipFree +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaDeviceGetAttribute hipDeviceGetAttribute -#define cudaMemcpy hipMemcpy -#define cudaMemcpyAsync hipMemcpyAsync -#define cudaMemcpyDefault hipMemcpyDefault -#define cudaMemcpyHostToDevice hipMemcpyHostToDevice -#define cudaMemcpyHostToHost hipMemcpyHostToHost -#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost -#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice -#define cudaMemsetAsync hipMemsetAsync -#define cudaMemset hipMemset +#define cudaMallocHost hipMallocHost +#define cudaFreeHost hipFreeHost +#define cudaMalloc hipMalloc +#define cudaFree hipFree -#define cudaPointerAttributes hipPointerAttribute_t -#define cudaPointerGetAttributes hipPointerGetAttributes +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDefault hipMemcpyDefault +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyHostToHost hipMemcpyHostToHost +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemset hipMemset -#define cudaMemGetInfo hipMemGetInfo -#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaPointerAttributes hipPointerAttribute_t +#define cudaPointerGetAttributes hipPointerGetAttributes -#define cudaDevAttrMultiProcessorCount hipDeviceAttributeMultiprocessorCount -#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +/* hipMemoryTypeUnregistered not supported */ +#define cudaMemoryTypeUnregistered hipMemoryTypeUnified +#define cudaMemoryTypeHost hipMemoryTypeHost +#define cudaMemoryTypeUnified hipMemoryTypeUnified + +#define cudaMemGetInfo hipMemGetInfo +#define cudaFuncSetAttribute hipFuncSetAttribute + +#define cudaDevAttrMultiProcessorCount hipDeviceAttributeMultiprocessorCount +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor namespace thrust { namespace hip { } + + namespace cuda = thrust::hip; } namespace thrust { - namespace cuda = thrust::hip; #define cuda_category hip_category } diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 9f55d6ef8..fcfe2bdd4 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -1109,7 +1109,7 @@ inline CUDAStreamView DefaultStream() { #ifdef HIP_API_PER_THREAD_DEFAULT_STREAM return CUDAStreamView{hipStreamPerThread}; #else - return CUDAStreamView{hipStreamDefault}; + return CUDAStreamView{hipStreamLegacy}; #endif } diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index b0004c300..b29987ff4 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -42,7 +42,7 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { return false; } -#if defined(XGBOOST_USE_CUDA) +#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP) cudaPointerAttributes attr; auto err = cudaPointerGetAttributes(&attr, ptr); // reset error @@ -64,25 +64,6 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { // other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. return false; } -#elif defined(XGBOOST_USE_HIP) - hipPointerAttribute_t attr; - auto err = hipPointerGetAttributes(&attr, ptr); - // reset error - CHECK_EQ(err, hipGetLastError()); - if (err == hipErrorInvalidValue) { - return false; - } else if (err == hipSuccess) { - switch (attr.memoryType) { - case hipMemoryTypeUnified: - case hipMemoryTypeHost: - return false; - default: - return true; - } - return true; - } else { - return false; - } #endif } } // namespace xgboost