diff --git a/src/data/adapter.h b/src/data/adapter.h index 5633f1605..34e918cd2 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -327,8 +327,8 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo { : indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx}, offset_{offset} {} - COOTuple GetElement(size_t idx) const { - return {ridx_, TypedIndex{indices_}(offset_ + idx), values_(offset_ + idx)}; + COOTuple GetElement(std::size_t idx) const { + return {ridx_, TypedIndex{indices_}(offset_ + idx), values_(offset_ + idx)}; } size_t Size() const { diff --git a/src/data/array_interface.h b/src/data/array_interface.h index e75510806..d8aa504df 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -96,7 +96,7 @@ struct ArrayInterfaceErrors { */ class ArrayInterfaceHandler { public: - enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; + enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 }; template static PtrType GetPtrFromArrayData(Object::Map const &obj) { @@ -300,6 +300,12 @@ class ArrayInterfaceHandler { template struct ToDType; // float +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +template <> +struct ToDType<__half> { + static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2; +}; +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 template <> struct ToDType { static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4; @@ -444,6 +450,12 @@ class ArrayInterface { type = T::kF16; CHECK(sizeof(long double) == 16) << "128-bit floating point is not supported on current platform."; + } else if (typestr[1] == 'f' && typestr[2] == '2') { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + type = T::kF2; +#else + LOG(FATAL) << "Half type is not supported."; +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 } else if (typestr[1] == 'f' && typestr[2] == '4') { type = T::kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { @@ -477,6 +489,14 @@ class ArrayInterface { XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const { using T = ArrayInterfaceHandler::Type; switch (type) { + case T::kF2: { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + return func(reinterpret_cast<__half const *>(data)); +#else + SPAN_CHECK(false); + return func(reinterpret_cast(data)); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + } case T::kF4: return func(reinterpret_cast(data)); case T::kF8: @@ -521,8 +541,18 @@ class ArrayInterface { XGBOOST_DEVICE T operator()(Index &&...index) const { static_assert(sizeof...(index) <= D, "Invalid index."); return this->DispatchCall([=](auto const *p_values) -> T { - size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...); + std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 + // No operator defined for half -> size_t + using Type = std::conditional_t< + std::is_same<__half, + std::remove_cv_t>>::value && + std::is_same>::value, + unsigned long long, T>; // NOLINT + return static_cast(static_cast(p_values[offset])); +#else return static_cast(p_values[offset]); +#endif }); } diff --git a/tests/python-gpu/test_from_cupy.py b/tests/python-gpu/test_from_cupy.py index 841ab7d34..77592747e 100644 --- a/tests/python-gpu/test_from_cupy.py +++ b/tests/python-gpu/test_from_cupy.py @@ -42,6 +42,8 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN): def _test_from_cupy(DMatrixT): '''Test constructing DMatrix from cupy''' import cupy as cp + + dmatrix_from_cupy(np.float16, DMatrixT, np.NAN) dmatrix_from_cupy(np.float32, DMatrixT, np.NAN) dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)