finish array_interface.cu
This commit is contained in:
parent
185dbce21f
commit
6e2c5be83e
@ -31,6 +31,8 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
if (!ptr) {
|
if (!ptr) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
cudaPointerAttributes attr;
|
cudaPointerAttributes attr;
|
||||||
auto err = cudaPointerGetAttributes(&attr, ptr);
|
auto err = cudaPointerGetAttributes(&attr, ptr);
|
||||||
// reset error
|
// reset error
|
||||||
@ -48,6 +50,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
} else {
|
} else {
|
||||||
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -458,11 +458,11 @@ class ArrayInterface {
|
|||||||
CHECK(sizeof(long double) == 16)
|
CHECK(sizeof(long double) == 16)
|
||||||
<< "128-bit floating point is not supported on current platform.";
|
<< "128-bit floating point is not supported on current platform.";
|
||||||
} else if (typestr[1] == 'f' && typestr[2] == '2') {
|
} else if (typestr[1] == 'f' && typestr[2] == '2') {
|
||||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
|
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||||
type = T::kF2;
|
type = T::kF2;
|
||||||
#else
|
#else
|
||||||
LOG(FATAL) << "Half type is not supported.";
|
LOG(FATAL) << "Half type is not supported.";
|
||||||
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
|
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||||
} else if (typestr[1] == 'f' && typestr[2] == '4') {
|
} else if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||||
type = T::kF4;
|
type = T::kF4;
|
||||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||||
@ -508,7 +508,7 @@ class ArrayInterface {
|
|||||||
return func(reinterpret_cast<float const *>(data));
|
return func(reinterpret_cast<float const *>(data));
|
||||||
case T::kF8:
|
case T::kF8:
|
||||||
return func(reinterpret_cast<double const *>(data));
|
return func(reinterpret_cast<double const *>(data));
|
||||||
#ifdef __CUDA_ARCH__
|
#if defined(__CUDA_ARCH__ ) || defined(__HIP_PLATFORM_AMD__)
|
||||||
case T::kF16: {
|
case T::kF16: {
|
||||||
// CUDA device code doesn't support long double.
|
// CUDA device code doesn't support long double.
|
||||||
SPAN_CHECK(false);
|
SPAN_CHECK(false);
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include "array_interface.cu"
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user