finish array_interface.cu

This commit is contained in:
amdsc21 2023-03-10 04:36:04 +01:00
parent 185dbce21f
commit 6e2c5be83e
3 changed files with 12 additions and 3 deletions

View File

@ -31,6 +31,8 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
if (!ptr) {
return false;
}
#if defined(XGBOOST_USE_CUDA)
cudaPointerAttributes attr;
auto err = cudaPointerGetAttributes(&attr, ptr);
// reset error
@ -48,6 +50,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
return true;
}
return true;
#elif defined(XGBOOST_USE_HIP)
return false;
#endif
} else {
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
return false;

View File

@ -458,11 +458,11 @@ class ArrayInterface {
CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '2') {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
type = T::kF2;
#else
LOG(FATAL) << "Half type is not supported.";
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
} else if (typestr[1] == 'f' && typestr[2] == '4') {
type = T::kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
@ -508,7 +508,7 @@ class ArrayInterface {
return func(reinterpret_cast<float const *>(data));
case T::kF8:
return func(reinterpret_cast<double const *>(data));
#ifdef __CUDA_ARCH__
#if defined(__CUDA_ARCH__ ) || defined(__HIP_PLATFORM_AMD__)
case T::kF16: {
// CUDA device code doesn't support long double.
SPAN_CHECK(false);

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "array_interface.cu"
#endif