support HIP for half in coll

This commit is contained in:
Hui Liu 2023-11-02 10:53:12 -07:00
parent 3af5dfd546
commit 51efb7442e

View File

@ -25,6 +25,8 @@ template <typename T>
bool constexpr IsFloatingPointV() { bool constexpr IsFloatingPointV() {
#if defined(XGBOOST_USE_CUDA) #if defined(XGBOOST_USE_CUDA)
return std::is_floating_point_v<T> || std::is_same_v<T, __half>; return std::is_floating_point_v<T> || std::is_same_v<T, __half>;
#elif defined(XGBOOST_USE_HIP) /* hack for HIP/Clang */
return std::is_floating_point_v<T> || (sizeof(T) == sizeof(unsigned short));
#else #else
return std::is_floating_point_v<T>; return std::is_floating_point_v<T>;
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA)