Support for all primitive types from array. (#7003)

* Change C API name.
* Test for all primitive types from array.
* Add native support for CPU 128 float.
* Convert boolean and float16 in Python.

* Fix dask version for now.
This commit is contained in:
Jiaming Yuan
2021-06-01 08:34:48 +08:00
committed by GitHub
parent 816b789bf0
commit ee4f51a631
8 changed files with 154 additions and 24 deletions

View File

@@ -42,7 +42,8 @@ struct ArrayInterfaceErrors {
return str.c_str();
}
static char const* Version() {
return "Only version <= 3 of `__cuda_array_interface__' are supported.";
return "Only version <= 3 of "
"`__cuda_array_interface__/__array_interface__' are supported.";
}
static char const* OfType(std::string const& type) {
static std::string str;
@@ -81,7 +82,7 @@ struct ArrayInterfaceErrors {
return "Other";
default:
LOG(FATAL) << "Invalid type code: " << c << " in `typestr' of input array."
<< "\nPlease verify the `__cuda_array_interface__' "
<< "\nPlease verify the `__cuda_array_interface__/__array_interface__' "
<< "of your input data complies to: "
<< "https://docs.scipy.org/doc/numpy/reference/arrays.interface.html"
<< "\nOr open an issue.";
@@ -90,7 +91,7 @@ struct ArrayInterfaceErrors {
}
static std::string UnSupportedType(StringView typestr) {
return TypeStr(typestr[1]) + " is not supported.";
return TypeStr(typestr[1]) + "-" + typestr[2] + " is not supported.";
}
};
@@ -135,8 +136,9 @@ class ArrayInterfaceHandler {
if (array.find("typestr") == array.cend()) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}
auto typestr = get<String const>(array.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
if (array.find("shape") == array.cend()) {
@@ -295,7 +297,7 @@ class ArrayInterface {
}
public:
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
public:
ArrayInterface() = default;
@@ -331,7 +333,12 @@ class ArrayInterface {
}
void AssignType(StringView typestr) {
if (typestr[1] == 'f' && typestr[2] == '4') {
if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' &&
typestr[3] == '6') {
type = kF16;
CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
type = kF8;
@@ -364,6 +371,16 @@ class ArrayInterface {
return func(reinterpret_cast<float *>(data));
case kF8:
return func(reinterpret_cast<double *>(data));
#ifdef __CUDA_ARCH__
case kF16: {
// CUDA device code doesn't support long double.
SPAN_CHECK(false);
return func(reinterpret_cast<double *>(data));
}
#else
case kF16:
return func(reinterpret_cast<long double *>(data));
#endif
case kI1:
return func(reinterpret_cast<int8_t *>(data));
case kI2: