|
|
|
|
@@ -35,7 +35,7 @@ struct ColumnarErrors {
|
|
|
|
|
return "Memory should be contigious.";
|
|
|
|
|
}
|
|
|
|
|
static char const* TypestrFormat() {
|
|
|
|
|
return "`typestr` should be of format <endian><type><size>.";
|
|
|
|
|
return "`typestr' should be of format <endian><type><size of type in bytes>.";
|
|
|
|
|
}
|
|
|
|
|
// Not supported in Apache Arrow.
|
|
|
|
|
static char const* BigEndian() {
|
|
|
|
|
@@ -50,7 +50,7 @@ struct ColumnarErrors {
|
|
|
|
|
return str.c_str();
|
|
|
|
|
}
|
|
|
|
|
static char const* Version() {
|
|
|
|
|
return "Only version 1 of __cuda_array_interface__ is being supported.";
|
|
|
|
|
return "Only version 1 of `__cuda_array_interface__' is supported.";
|
|
|
|
|
}
|
|
|
|
|
static char const* ofType(std::string const& type) {
|
|
|
|
|
static std::string str;
|
|
|
|
|
@@ -60,22 +60,6 @@ struct ColumnarErrors {
|
|
|
|
|
str += " type.";
|
|
|
|
|
return str.c_str();
|
|
|
|
|
}
|
|
|
|
|
static std::string UnknownTypeStr(std::string const& typestr) {
|
|
|
|
|
return "typestr from array interface: " + typestr + " is not supported.";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// TODO(trivialfis): Abstract this into a class that accept a json
|
|
|
|
|
// object and turn it into an array (for cupy and numba).
|
|
|
|
|
class ArrayInterfaceHandler {
|
|
|
|
|
public:
|
|
|
|
|
template <typename T>
|
|
|
|
|
static constexpr char TypeChar() {
|
|
|
|
|
return
|
|
|
|
|
(std::is_floating_point<T>::value ? 'f' :
|
|
|
|
|
(std::is_integral<T>::value ?
|
|
|
|
|
(std::is_signed<T>::value ? 'i' : 'u') : '\0'));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string TypeStr(char c) {
|
|
|
|
|
switch (c) {
|
|
|
|
|
@@ -89,12 +73,47 @@ class ArrayInterfaceHandler {
|
|
|
|
|
return "Unsigned integer";
|
|
|
|
|
case 'f':
|
|
|
|
|
return "Floating point";
|
|
|
|
|
case 'c':
|
|
|
|
|
return "Complex floating point";
|
|
|
|
|
case 'm':
|
|
|
|
|
return "Timedelta";
|
|
|
|
|
case 'M':
|
|
|
|
|
return "Datetime";
|
|
|
|
|
case 'O':
|
|
|
|
|
return "Object";
|
|
|
|
|
case 'S':
|
|
|
|
|
return "String";
|
|
|
|
|
case 'U':
|
|
|
|
|
return "Unicode";
|
|
|
|
|
case 'V':
|
|
|
|
|
return "Other";
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Invalid type code: " << c << " in typestr of input array interface.";
|
|
|
|
|
LOG(FATAL) << "Invalid type code: " << c << " in `typestr' of input array."
|
|
|
|
|
<< "\nPlease verify the `__cuda_array_interface__' "
|
|
|
|
|
<< "of your input data complies to: "
|
|
|
|
|
<< "https://docs.scipy.org/doc/numpy/reference/arrays.interface.html"
|
|
|
|
|
<< "\nOr open an issue.";
|
|
|
|
|
return "";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string UnSupportedType(std::string const& typestr) {
|
|
|
|
|
return TypeStr(typestr.at(1)) + " is not supported.";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// TODO(trivialfis): Abstract this into a class that accept a json
|
|
|
|
|
// object and turn it into an array (for cupy and numba).
|
|
|
|
|
class ArrayInterfaceHandler {
|
|
|
|
|
public:
|
|
|
|
|
template <typename T>
|
|
|
|
|
static constexpr char TypeChar() {
|
|
|
|
|
return
|
|
|
|
|
(std::is_floating_point<T>::value ? 'f' :
|
|
|
|
|
(std::is_integral<T>::value ?
|
|
|
|
|
(std::is_signed<T>::value ? 'i' : 'u') : '\0'));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename PtrType>
|
|
|
|
|
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
|
|
|
|
|
if (obj.find("data") == obj.cend()) {
|
|
|
|
|
@@ -110,30 +129,30 @@ class ArrayInterfaceHandler {
|
|
|
|
|
|
|
|
|
|
static void Validate(std::map<std::string, Json> const& array) {
|
|
|
|
|
if (array.find("version") == array.cend()) {
|
|
|
|
|
LOG(FATAL) << "Missing version field for array interface";
|
|
|
|
|
LOG(FATAL) << "Missing `version' field for array interface";
|
|
|
|
|
}
|
|
|
|
|
auto version = get<Integer const>(array.at("version"));
|
|
|
|
|
CHECK_EQ(version, 1) << ColumnarErrors::Version();
|
|
|
|
|
|
|
|
|
|
if (array.find("typestr") == array.cend()) {
|
|
|
|
|
LOG(FATAL) << "Missing typestr field for array interface";
|
|
|
|
|
LOG(FATAL) << "Missing `typestr' field for array interface";
|
|
|
|
|
}
|
|
|
|
|
auto typestr = get<String const>(array.at("typestr"));
|
|
|
|
|
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
|
|
|
|
|
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
|
|
|
|
|
|
|
|
|
|
if (array.find("shape") == array.cend()) {
|
|
|
|
|
LOG(FATAL) << "Missing shape field for array interface";
|
|
|
|
|
LOG(FATAL) << "Missing `shape' field for array interface";
|
|
|
|
|
}
|
|
|
|
|
if (array.find("data") == array.cend()) {
|
|
|
|
|
LOG(FATAL) << "Missing data field for array interface";
|
|
|
|
|
LOG(FATAL) << "Missing `data' field for array interface";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find null mask (validity mask) field
|
|
|
|
|
// Mask object is also an array interface, but with different requirements.
|
|
|
|
|
static void ExtractMask(std::map<std::string, Json> const& column,
|
|
|
|
|
common::Span<RBitField8::value_type>* p_out) {
|
|
|
|
|
static size_t ExtractMask(std::map<std::string, Json> const &column,
|
|
|
|
|
common::Span<RBitField8::value_type> *p_out) {
|
|
|
|
|
auto& s_mask = *p_out;
|
|
|
|
|
if (column.find("mask") != column.cend()) {
|
|
|
|
|
auto const& j_mask = get<Object const>(column.at("mask"));
|
|
|
|
|
@@ -143,24 +162,42 @@ class ArrayInterfaceHandler {
|
|
|
|
|
|
|
|
|
|
auto j_shape = get<Array const>(j_mask.at("shape"));
|
|
|
|
|
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
|
|
|
|
|
CHECK_EQ(get<Integer>(j_shape.front()) % 8, 0) <<
|
|
|
|
|
"Length of validity mask must be a multiple of 8 bytes.";
|
|
|
|
|
int64_t size = get<Integer>(j_shape.at(0)) *
|
|
|
|
|
sizeof(unsigned char) / sizeof(RBitField8::value_type);
|
|
|
|
|
auto typestr = get<String const>(j_mask.at("typestr"));
|
|
|
|
|
// For now this is just 1, we can support different size of interger in mask.
|
|
|
|
|
int64_t const type_length = typestr.at(2) - 48;
|
|
|
|
|
/*
|
|
|
|
|
* shape represents how many bits is in the mask. (This is a grey area, don't be
|
|
|
|
|
* suprised if it suddently represents something else when supporting a new
|
|
|
|
|
* implementation). Quoting from numpy array interface:
|
|
|
|
|
*
|
|
|
|
|
* The shape of this object should be "broadcastable" to the shape of the original
|
|
|
|
|
* array.
|
|
|
|
|
*
|
|
|
|
|
* And that's the only requirement.
|
|
|
|
|
*/
|
|
|
|
|
int64_t const n_bits = get<Integer>(j_shape.at(0));
|
|
|
|
|
// The size of span required to cover all bits. Here with 8 bits bitfield, we
|
|
|
|
|
// assume 1 byte alignment.
|
|
|
|
|
int64_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
|
|
|
|
|
|
|
|
|
if (j_mask.find("strides") != j_mask.cend()) {
|
|
|
|
|
auto strides = get<Array const>(column.at("strides"));
|
|
|
|
|
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
|
|
|
|
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ColumnarErrors::Contigious();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (typestr.at(1) == 't') {
|
|
|
|
|
CHECK_EQ(typestr.at(2), '1') << "There can be only 1 bit in each entry of bitfield.";
|
|
|
|
|
CHECK_EQ(typestr.at(2), '1') << "mask with bitfield type should be of 1 byte per bitfield.";
|
|
|
|
|
} else if (typestr.at(1) == 'i') {
|
|
|
|
|
CHECK_EQ(typestr.at(2), '1') << "mask with integer type should be of 1 byte per integer.";
|
|
|
|
|
CHECK_EQ(typestr.at(2), '1') << "mask with integer type should be of 1 byte per integer.";
|
|
|
|
|
} else {
|
|
|
|
|
LOG(FATAL) << "mask must be of integer type or bit field type.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// For now this is just 1
|
|
|
|
|
int64_t const type_length = typestr.at(2) - 48;
|
|
|
|
|
s_mask = {p_mask, size / type_length};
|
|
|
|
|
s_mask = {p_mask, span_size};
|
|
|
|
|
return n_bits;
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
@@ -178,8 +215,8 @@ class ArrayInterfaceHandler {
|
|
|
|
|
|
|
|
|
|
if (column.find("strides") != column.cend()) {
|
|
|
|
|
auto strides = get<Array const>(column.at("strides"));
|
|
|
|
|
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
|
|
|
|
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
|
|
|
|
|
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
|
|
|
|
|
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto length = get<Integer const>(j_shape.at(0));
|
|
|
|
|
@@ -197,15 +234,22 @@ class ArrayInterfaceHandler {
|
|
|
|
|
foreign_col.size = s_data.size();
|
|
|
|
|
|
|
|
|
|
common::Span<RBitField8::value_type> s_mask;
|
|
|
|
|
ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
|
|
|
|
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
|
|
|
|
|
|
|
|
|
foreign_col.valid = RBitField8(s_mask);
|
|
|
|
|
|
|
|
|
|
if (s_mask.data()) {
|
|
|
|
|
CHECK_EQ(n_bits, foreign_col.data.size())
|
|
|
|
|
<< "Shape of bit mask doesn't match data shape. "
|
|
|
|
|
<< "XGBoost doesn't support internal broadcasting.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return foreign_col;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define DISPATCH_TYPE(__dispatched_func, __typestr, ...) { \
|
|
|
|
|
CHECK_EQ(__typestr.size(), 3) << ColumnarErrors::TypestrFormat(); \
|
|
|
|
|
if (__typestr.at(1) == 'f' && __typestr.at(2) == '4') { \
|
|
|
|
|
__dispatched_func<float>(__VA_ARGS__); \
|
|
|
|
|
} else if (__typestr.at(1) == 'f' && __typestr.at(2) == '8') { \
|
|
|
|
|
@@ -227,7 +271,7 @@ class ArrayInterfaceHandler {
|
|
|
|
|
} else if (__typestr.at(1) == 'u' && __typestr.at(2) == '8') { \
|
|
|
|
|
__dispatched_func<uint64_t>(__VA_ARGS__); \
|
|
|
|
|
} else { \
|
|
|
|
|
LOG(FATAL) << ColumnarErrors::UnknownTypeStr(__typestr); \
|
|
|
|
|
LOG(FATAL) << ColumnarErrors::UnSupportedType(__typestr); \
|
|
|
|
|
} \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|