From 6e2c5be83e29820ca32b82945f6ee7807ed07c8b Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Fri, 10 Mar 2023 04:36:04 +0100 Subject: [PATCH] finish array_interface.cu --- src/data/array_interface.cu | 5 +++++ src/data/array_interface.h | 6 +++--- src/data/array_interface.hip | 4 ++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/data/array_interface.cu b/src/data/array_interface.cu index b1a80251e..875a10606 100644 --- a/src/data/array_interface.cu +++ b/src/data/array_interface.cu @@ -31,6 +31,8 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { if (!ptr) { return false; } + +#if defined(XGBOOST_USE_CUDA) cudaPointerAttributes attr; auto err = cudaPointerGetAttributes(&attr, ptr); // reset error @@ -48,6 +50,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) { return true; } return true; +#elif defined(XGBOOST_USE_HIP) + return false; +#endif } else { // other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. return false; diff --git a/src/data/array_interface.h b/src/data/array_interface.h index 997bc4788..2a078ed60 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -458,11 +458,11 @@ class ArrayInterface { CHECK(sizeof(long double) == 16) << "128-bit floating point is not supported on current platform."; } else if (typestr[1] == 'f' && typestr[2] == '2') { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP) +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__) type = T::kF2; #else LOG(FATAL) << "Half type is not supported."; -#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP) +#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__) } else if (typestr[1] == 'f' && typestr[2] == '4') { type = T::kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { @@ -508,7 +508,7 @@ class ArrayInterface { return func(reinterpret_cast(data)); case T::kF8: return func(reinterpret_cast(data)); -#ifdef __CUDA_ARCH__ +#if defined(__CUDA_ARCH__ ) || defined(__HIP_PLATFORM_AMD__) case T::kF16: { // CUDA device code doesn't support long double. SPAN_CHECK(false); diff --git a/src/data/array_interface.hip b/src/data/array_interface.hip index e69de29bb..b90160d91 100644 --- a/src/data/array_interface.hip +++ b/src/data/array_interface.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "array_interface.cu" +#endif