finish array_interface.cu

This commit is contained in:
amdsc21 2023-03-10 04:36:04 +01:00
parent 185dbce21f
commit 6e2c5be83e
3 changed files with 12 additions and 3 deletions

View File

@ -31,6 +31,8 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
if (!ptr) { if (!ptr) {
return false; return false;
} }
#if defined(XGBOOST_USE_CUDA)
cudaPointerAttributes attr; cudaPointerAttributes attr;
auto err = cudaPointerGetAttributes(&attr, ptr); auto err = cudaPointerGetAttributes(&attr, ptr);
// reset error // reset error
@ -48,6 +50,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
return true; return true;
} }
return true; return true;
#elif defined(XGBOOST_USE_HIP)
return false;
#endif
} else { } else {
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. // other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
return false; return false;

View File

@ -458,11 +458,11 @@ class ArrayInterface {
CHECK(sizeof(long double) == 16) CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform."; << "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '2') { } else if (typestr[1] == 'f' && typestr[2] == '2') {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP) #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
type = T::kF2; type = T::kF2;
#else #else
LOG(FATAL) << "Half type is not supported."; LOG(FATAL) << "Half type is not supported.";
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP) #endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
} else if (typestr[1] == 'f' && typestr[2] == '4') { } else if (typestr[1] == 'f' && typestr[2] == '4') {
type = T::kF4; type = T::kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') { } else if (typestr[1] == 'f' && typestr[2] == '8') {
@ -508,7 +508,7 @@ class ArrayInterface {
return func(reinterpret_cast<float const *>(data)); return func(reinterpret_cast<float const *>(data));
case T::kF8: case T::kF8:
return func(reinterpret_cast<double const *>(data)); return func(reinterpret_cast<double const *>(data));
#ifdef __CUDA_ARCH__ #if defined(__CUDA_ARCH__ ) || defined(__HIP_PLATFORM_AMD__)
case T::kF16: { case T::kF16: {
// CUDA device code doesn't support long double. // CUDA device code doesn't support long double.
SPAN_CHECK(false); SPAN_CHECK(false);

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "array_interface.cu"
#endif