Support half type from cupy. (#8487)

This commit is contained in:
Jiaming Yuan
2022-11-30 17:56:42 +08:00
committed by GitHub
parent addaa63732
commit 157e98edf7
3 changed files with 36 additions and 4 deletions

View File

@@ -96,7 +96,7 @@ struct ArrayInterfaceErrors {
*/
class ArrayInterfaceHandler {
public:
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
// Element type tags stored as std::int8_t. For the floating-point tags the numeric suffix
// matches the item-size digit(s) of the typestr ("f2" -> kF2, "f4" -> kF4); the kI*/kU* tags
// presumably follow the same scheme for signed/unsigned integers -- confirm against the parser.
enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
@@ -300,6 +300,12 @@ class ArrayInterfaceHandler {
template <typename T, typename E = void>
struct ToDType;
// float
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// Maps the half-precision `__half` element type to the kF2 type tag.
// NOTE(review): `__CUDA_ARCH__` is only defined while device code is being compiled, so this
// specialisation is presumably invisible to host-side compilation passes -- confirm that no
// host code needs ToDType<__half>.
template <>
struct ToDType<__half> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2;
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
template <>
struct ToDType<float> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4;
@@ -444,6 +450,12 @@ class ArrayInterface {
type = T::kF16;
CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '2') {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
type = T::kF2;
#else
LOG(FATAL) << "Half type is not supported.";
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
} else if (typestr[1] == 'f' && typestr[2] == '4') {
type = T::kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
@@ -477,6 +489,14 @@ class ArrayInterface {
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
using T = ArrayInterfaceHandler::Type;
switch (type) {
case T::kF2: {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
return func(reinterpret_cast<__half const *>(data));
#else
SPAN_CHECK(false);
return func(reinterpret_cast<float const *>(data));
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
}
case T::kF4:
return func(reinterpret_cast<float const *>(data));
case T::kF8:
@@ -521,8 +541,18 @@ class ArrayInterface {
XGBOOST_DEVICE T operator()(Index &&...index) const {
// Reads the element addressed by `index...` and returns it converted to T.
// At most D indices may be supplied; this is checked at compile time.
static_assert(sizeof...(index) <= D, "Invalid index.");
// DispatchCall invokes the lambda with the data pointer cast to the stored element type,
// so `p_values` has a different pointee type per dispatch branch.
return this->DispatchCall([=](auto const *p_values) -> T {
// Element offset computed from the strides and the supplied indices.
size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
// No operator defined for half -> size_t
// When the element is `__half` and the requested T is std::size_t, convert through
// unsigned long long first; for every other combination `Type` is simply T and the
// double cast below collapses to a single static_cast<T>.
using Type = std::conditional_t<
std::is_same<__half,
std::remove_cv_t<std::remove_pointer_t<decltype(p_values)>>>::value &&
std::is_same<std::size_t, std::remove_cv_t<T>>::value,
unsigned long long, T>; // NOLINT
return static_cast<T>(static_cast<Type>(p_values[offset]));
#else
// No half-precision handling in this configuration: convert the element directly.
return static_cast<T>(p_values[offset]);
#endif
});
}