Support for all primitive types from array. (#7003)
* Change C API name. * Test for all primitive types from array. * Add native support for CPU 128 float. * Convert boolean and float16 in Python. * Fix dask version for now.
This commit is contained in:
@@ -42,7 +42,8 @@ struct ArrayInterfaceErrors {
|
||||
return str.c_str();
|
||||
}
|
||||
static char const* Version() {
|
||||
return "Only version <= 3 of `__cuda_array_interface__' are supported.";
|
||||
return "Only version <= 3 of "
|
||||
"`__cuda_array_interface__/__array_interface__' are supported.";
|
||||
}
|
||||
static char const* OfType(std::string const& type) {
|
||||
static std::string str;
|
||||
@@ -81,7 +82,7 @@ struct ArrayInterfaceErrors {
|
||||
return "Other";
|
||||
default:
|
||||
LOG(FATAL) << "Invalid type code: " << c << " in `typestr' of input array."
|
||||
<< "\nPlease verify the `__cuda_array_interface__' "
|
||||
<< "\nPlease verify the `__cuda_array_interface__/__array_interface__' "
|
||||
<< "of your input data complies to: "
|
||||
<< "https://docs.scipy.org/doc/numpy/reference/arrays.interface.html"
|
||||
<< "\nOr open an issue.";
|
||||
@@ -90,7 +91,7 @@ struct ArrayInterfaceErrors {
|
||||
}
|
||||
|
||||
static std::string UnSupportedType(StringView typestr) {
|
||||
return TypeStr(typestr[1]) + " is not supported.";
|
||||
return TypeStr(typestr[1]) + "-" + typestr[2] + " is not supported.";
|
||||
}
|
||||
};
|
||||
|
||||
@@ -135,8 +136,9 @@ class ArrayInterfaceHandler {
|
||||
if (array.find("typestr") == array.cend()) {
|
||||
LOG(FATAL) << "Missing `typestr' field for array interface";
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(array.at("typestr"));
|
||||
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
|
||||
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
|
||||
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
|
||||
|
||||
if (array.find("shape") == array.cend()) {
|
||||
@@ -295,7 +297,7 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
public:
|
||||
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
@@ -331,7 +333,12 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' &&
|
||||
typestr[3] == '6') {
|
||||
type = kF16;
|
||||
CHECK(sizeof(long double) == 16)
|
||||
<< "128-bit floating point is not supported on current platform.";
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
type = kF8;
|
||||
@@ -364,6 +371,16 @@ class ArrayInterface {
|
||||
return func(reinterpret_cast<float *>(data));
|
||||
case kF8:
|
||||
return func(reinterpret_cast<double *>(data));
|
||||
#ifdef __CUDA_ARCH__
|
||||
case kF16: {
|
||||
// CUDA device code doesn't support long double.
|
||||
SPAN_CHECK(false);
|
||||
return func(reinterpret_cast<double *>(data));
|
||||
}
|
||||
#else
|
||||
case kF16:
|
||||
return func(reinterpret_cast<long double *>(data));
|
||||
#endif
|
||||
case kI1:
|
||||
return func(reinterpret_cast<int8_t *>(data));
|
||||
case kI2:
|
||||
|
||||
Reference in New Issue
Block a user