add HIP flags

This commit is contained in:
amdsc21
2023-03-08 01:33:38 +01:00
parent 6b7be96373
commit f5f800c80d
7 changed files with 20 additions and 20 deletions

View File

@@ -302,12 +302,12 @@ class ArrayInterfaceHandler {
template <typename T, typename E = void>
struct ToDType;
// float
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
template <>
struct ToDType<__half> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2;
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
template <>
struct ToDType<float> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4;
@@ -356,10 +356,10 @@ struct ToDType<int64_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI8;
};
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t) { common::AssertGPUSupport(); }
inline bool ArrayInterfaceHandler::IsCudaPtr(void const *) { return false; }
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
/**
* \brief A type erased view over __array_interface__ protocol defined by numpy
@@ -458,11 +458,11 @@ class ArrayInterface {
CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '2') {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
type = T::kF2;
#else
LOG(FATAL) << "Half type is not supported.";
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(XGBOOST_USE_HIP)
} else if (typestr[1] == 'f' && typestr[2] == '4') {
type = T::kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
@@ -497,12 +497,12 @@ class ArrayInterface {
using T = ArrayInterfaceHandler::Type;
switch (type) {
case T::kF2: {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
return func(reinterpret_cast<__half const *>(data));
#else
SPAN_CHECK(false);
return func(reinterpret_cast<float const *>(data));
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
}
case T::kF4:
return func(reinterpret_cast<float const *>(data));
@@ -555,7 +555,7 @@ class ArrayInterface {
static_assert(sizeof...(index) <= D, "Invalid index.");
return this->DispatchCall([=](auto const *p_values) -> T {
std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600) || defined(__HIP_PLATFORM_AMD__)
// No operator defined for half -> size_t
using Type = std::conditional_t<
std::is_same<__half,