finish array_interface.cu
This commit is contained in:
@@ -458,11 +458,11 @@ class ArrayInterface {
|
||||
CHECK(sizeof(long double) == 16)
|
||||
<< "128-bit floating point is not supported on current platform.";
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '2') {
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||
type = T::kF2;
|
||||
#else
|
||||
LOG(FATAL) << "Half type is not supported.";
|
||||
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
|
||||
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = T::kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
@@ -508,7 +508,7 @@ class ArrayInterface {
|
||||
return func(reinterpret_cast<float const *>(data));
|
||||
case T::kF8:
|
||||
return func(reinterpret_cast<double const *>(data));
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if defined(__CUDA_ARCH__ ) || defined(__HIP_PLATFORM_AMD__)
|
||||
case T::kF16: {
|
||||
// CUDA device code doesn't support long double.
|
||||
SPAN_CHECK(false);
|
||||
|
||||
Reference in New Issue
Block a user