Support half type from cupy. (#8487)
This commit is contained in:
parent
addaa63732
commit
157e98edf7
@ -327,8 +327,8 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx},
|
||||
offset_{offset} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, TypedIndex<size_t, 1>{indices_}(offset_ + idx), values_(offset_ + idx)};
|
||||
COOTuple GetElement(std::size_t idx) const {
|
||||
return {ridx_, TypedIndex<std::size_t, 1>{indices_}(offset_ + idx), values_(offset_ + idx)};
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
|
||||
@ -96,7 +96,7 @@ struct ArrayInterfaceErrors {
|
||||
*/
|
||||
class ArrayInterfaceHandler {
|
||||
public:
|
||||
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
template <typename PtrType>
|
||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||
@ -300,6 +300,12 @@ class ArrayInterfaceHandler {
|
||||
template <typename T, typename E = void>
|
||||
struct ToDType;
|
||||
// float
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
template <>
|
||||
struct ToDType<__half> {
|
||||
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2;
|
||||
};
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
template <>
|
||||
struct ToDType<float> {
|
||||
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4;
|
||||
@ -444,6 +450,12 @@ class ArrayInterface {
|
||||
type = T::kF16;
|
||||
CHECK(sizeof(long double) == 16)
|
||||
<< "128-bit floating point is not supported on current platform.";
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '2') {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
type = T::kF2;
|
||||
#else
|
||||
LOG(FATAL) << "Half type is not supported.";
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = T::kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
@ -477,6 +489,14 @@ class ArrayInterface {
|
||||
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
|
||||
using T = ArrayInterfaceHandler::Type;
|
||||
switch (type) {
|
||||
case T::kF2: {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
return func(reinterpret_cast<__half const *>(data));
|
||||
#else
|
||||
SPAN_CHECK(false);
|
||||
return func(reinterpret_cast<float const *>(data));
|
||||
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
}
|
||||
case T::kF4:
|
||||
return func(reinterpret_cast<float const *>(data));
|
||||
case T::kF8:
|
||||
@ -521,8 +541,18 @@ class ArrayInterface {
|
||||
XGBOOST_DEVICE T operator()(Index &&...index) const {
|
||||
static_assert(sizeof...(index) <= D, "Invalid index.");
|
||||
return this->DispatchCall([=](auto const *p_values) -> T {
|
||||
size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
|
||||
std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
|
||||
// No operator defined for half -> size_t
|
||||
using Type = std::conditional_t<
|
||||
std::is_same<__half,
|
||||
std::remove_cv_t<std::remove_pointer_t<decltype(p_values)>>>::value &&
|
||||
std::is_same<std::size_t, std::remove_cv_t<T>>::value,
|
||||
unsigned long long, T>; // NOLINT
|
||||
return static_cast<T>(static_cast<Type>(p_values[offset]));
|
||||
#else
|
||||
return static_cast<T>(p_values[offset]);
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@ -42,6 +42,8 @@ def dmatrix_from_cupy(input_type, DMatrixT, missing=np.NAN):
|
||||
def _test_from_cupy(DMatrixT):
|
||||
'''Test constructing DMatrix from cupy'''
|
||||
import cupy as cp
|
||||
|
||||
dmatrix_from_cupy(np.float16, DMatrixT, np.NAN)
|
||||
dmatrix_from_cupy(np.float32, DMatrixT, np.NAN)
|
||||
dmatrix_from_cupy(np.float64, DMatrixT, np.NAN)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user