Extend array interface to handle ndarray. (#7434)

* Extend array interface to handle ndarray.

The `ArrayInterface` class is extended to support multi-dim array inputs. Previously this
class handles only 2-dim (vector is also matrix).  This PR specifies the expected
dimension at compile-time and the array interface can perform various checks automatically
for input data. Also, adapters like CSR are more rigorous about their input.  Lastly, row
vector and column vector are handled without intervention from the caller.
This commit is contained in:
Jiaming Yuan
2021-11-16 09:52:15 +08:00
committed by GitHub
parent e27f543deb
commit 55ee272ea8
18 changed files with 654 additions and 456 deletions

View File

@@ -13,24 +13,23 @@
#include <utility>
#include <vector>
#include "../common/bitfield.h"
#include "../common/common.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "xgboost/logging.h"
#include "xgboost/span.h"
#include "../common/bitfield.h"
#include "../common/common.h"
namespace xgboost {
// Common errors in parsing columnar format.
struct ArrayInterfaceErrors {
static char const* Contigious() {
return "Memory should be contigious.";
}
static char const* TypestrFormat() {
static char const *Contiguous() { return "Memory should be contiguous."; }
static char const *TypestrFormat() {
return "`typestr' should be of format <endian><type><size of type in bytes>.";
}
static char const* Dimension(int32_t d) {
static char const *Dimension(int32_t d) {
static std::string str;
str.clear();
str += "Only ";
@@ -38,11 +37,11 @@ struct ArrayInterfaceErrors {
str += " dimensional array is valid.";
return str.c_str();
}
static char const* Version() {
return "Only version <= 3 of "
"`__cuda_array_interface__/__array_interface__' are supported.";
static char const *Version() {
return "Only version <= 3 of `__cuda_array_interface__' and `__array_interface__' are "
"supported.";
}
static char const* OfType(std::string const& type) {
static char const *OfType(std::string const &type) {
static std::string str;
str.clear();
str += " should be of ";
@@ -92,49 +91,39 @@ struct ArrayInterfaceErrors {
}
};
// TODO(trivialfis): Abstract this into a class that accept a json
// object and turn it into an array (for cupy and numba).
/**
* Utilities for consuming array interface.
*/
class ArrayInterfaceHandler {
public:
template <typename T>
static constexpr char TypeChar() {
return
(std::is_floating_point<T>::value ? 'f' :
(std::is_integral<T>::value ?
(std::is_signed<T>::value ? 'i' : 'u') : '\0'));
}
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
template <typename PtrType>
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
if (obj.find("data") == obj.cend()) {
static PtrType GetPtrFromArrayData(std::map<std::string, Json> const &obj) {
auto data_it = obj.find("data");
if (data_it == obj.cend()) {
LOG(FATAL) << "Empty data passed in.";
}
auto p_data = reinterpret_cast<PtrType>(static_cast<size_t>(
get<Integer const>(
get<Array const>(
obj.at("data"))
.at(0))));
auto p_data = reinterpret_cast<PtrType>(
static_cast<size_t>(get<Integer const>(get<Array const>(data_it->second).at(0))));
return p_data;
}
static void Validate(std::map<std::string, Json> const& array) {
static void Validate(std::map<std::string, Json> const &array) {
auto version_it = array.find("version");
if (version_it == array.cend()) {
LOG(FATAL) << "Missing `version' field for array interface";
}
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
// is cuda, check the version.
if (get<Integer const>(version_it->second) > 3) {
LOG(FATAL) << ArrayInterfaceErrors::Version();
}
if (get<Integer const>(version_it->second) > 3) {
LOG(FATAL) << ArrayInterfaceErrors::Version();
}
if (array.find("typestr") == array.cend()) {
auto typestr_it = array.find("typestr");
if (typestr_it == array.cend()) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}
auto typestr = get<String const>(array.at("typestr"));
auto typestr = get<String const>(typestr_it->second);
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
if (array.find("shape") == array.cend()) {
@@ -149,12 +138,12 @@ class ArrayInterfaceHandler {
// Mask object is also an array interface, but with different requirements.
static size_t ExtractMask(std::map<std::string, Json> const &column,
common::Span<RBitField8::value_type> *p_out) {
auto& s_mask = *p_out;
auto &s_mask = *p_out;
if (column.find("mask") != column.cend()) {
auto const& j_mask = get<Object const>(column.at("mask"));
auto const &j_mask = get<Object const>(column.at("mask"));
Validate(j_mask);
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ArrayInterfaceErrors::Dimension(1);
@@ -186,8 +175,8 @@ class ArrayInterfaceHandler {
if (j_mask.find("strides") != j_mask.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contigious();
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
}
s_mask = {p_mask, span_size};
@@ -195,77 +184,212 @@ class ArrayInterfaceHandler {
}
return 0;
}
static std::pair<bst_row_t, bst_feature_t> ExtractShape(
std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
auto typestr = get<String const>(column.at("typestr"));
if (j_shape.size() == 1) {
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
} else {
CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported.";
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
}
}
static void ExtractStride(std::map<std::string, Json> const &column,
size_t *stride_r, size_t *stride_c, size_t rows,
size_t cols, size_t itemsize) {
auto strides_it = column.find("strides");
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
// default strides
*stride_r = cols;
*stride_c = 1;
} else {
// strides specified by the array interface
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
*stride_r = get<Integer const>(j_strides[0]) / itemsize;
size_t n = 1;
if (j_strides.size() == 2) {
n = get<Integer const>(j_strides[1]) / itemsize;
/**
* \brief Handle vector inputs. For higher dimension, we require strictly correct shape.
*/
template <int32_t D>
static void HandleRowVector(std::vector<size_t> const &shape, std::vector<size_t> *p_out) {
auto &out = *p_out;
if (shape.size() == 2 && D == 1) {
auto m = shape[0];
auto n = shape[1];
CHECK(m == 1 || n == 1);
if (m == 1) {
// keep the number of columns
out[0] = out[1];
out.resize(1);
} else if (n == 1) {
// keep the number of rows.
out.resize(1);
}
*stride_c = n;
// when both m and n are 1, above logic keeps the column.
// when neither m nor n is 1, caller should throw an error about Dimension.
}
auto valid = rows * (*stride_r) + cols * (*stride_c) >= (rows * cols);
CHECK(valid) << "Invalid strides in array."
<< " strides: (" << (*stride_r) << "," << (*stride_c)
<< "), shape: (" << rows << ", " << cols << ")";
}
static void* ExtractData(std::map<std::string, Json> const &column,
std::pair<size_t, size_t> shape) {
Validate(column);
void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
template <int32_t D>
static void ExtractShape(std::map<std::string, Json> const &array, size_t (&out_shape)[D]) {
auto const &j_shape = get<Array const>(array.at("shape"));
std::vector<size_t> shape_arr(j_shape.size(), 0);
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
[](Json in) { return get<Integer const>(in); });
// handle column vector vs. row vector
HandleRowVector<D>(shape_arr, &shape_arr);
// Copy shape.
size_t i;
for (i = 0; i < shape_arr.size(); ++i) {
CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D);
out_shape[i] = shape_arr[i];
}
// Fill the remaining dimensions
std::fill(out_shape + i, out_shape + D, 1);
}
/**
* \brief Extracts the optiona `strides' field and returns whether the array is c-contiguous.
*/
template <int32_t D>
static bool ExtractStride(std::map<std::string, Json> const &array, size_t itemsize,
size_t (&shape)[D], size_t (&stride)[D]) {
auto strides_it = array.find("strides");
// No stride is provided
if (strides_it == array.cend() || IsA<Null>(strides_it->second)) {
// No stride is provided, we can calculate it from shape.
linalg::detail::CalcStride(shape, stride);
// Quote:
//
// strides: Either None to indicate a C-style contiguous array or a Tuple of
// strides which provides the number of bytes
return true;
}
// Get shape, we need to make changes to handle row vector, so some duplicated code
// from `ExtractShape` for copying out the shape.
auto const &j_shape = get<Array const>(array.at("shape"));
std::vector<size_t> shape_arr(j_shape.size(), 0);
std::transform(j_shape.cbegin(), j_shape.cend(), shape_arr.begin(),
[](Json in) { return get<Integer const>(in); });
// Get stride
auto const &j_strides = get<Array const>(strides_it->second);
CHECK_EQ(j_strides.size(), j_shape.size()) << "stride and shape don't match.";
std::vector<size_t> stride_arr(j_strides.size(), 0);
std::transform(j_strides.cbegin(), j_strides.cend(), stride_arr.begin(),
[](Json in) { return get<Integer const>(in); });
// Handle column vector vs. row vector
HandleRowVector<D>(shape_arr, &stride_arr);
size_t i;
for (i = 0; i < stride_arr.size(); ++i) {
// If one of the dim has shape 0 then total size is 0, stride is meaningless, but we
// set it to 0 here just to be consistent
CHECK_LT(i, D) << ArrayInterfaceErrors::Dimension(D);
// We use number of items instead of number of bytes
stride[i] = stride_arr[i] / itemsize;
}
std::fill(stride + i, stride + D, 1);
// If the stride can be calculated from shape then it's contiguous.
size_t stride_tmp[D];
linalg::detail::CalcStride(shape, stride_tmp);
return std::equal(stride_tmp, stride_tmp + D, stride);
}
static void *ExtractData(std::map<std::string, Json> const &array, size_t size) {
Validate(array);
void *p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void *>(array);
if (!p_data) {
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
CHECK_EQ(size, 0) << "Empty data with non-zero shape.";
}
return p_data;
}
/**
* \brief Whether the ptr is allocated by CUDA.
*/
static bool IsCudaPtr(void const *ptr);
/**
* \brief Sync the CUDA stream.
*/
static void SyncCudaStream(int64_t stream);
};
/**
* Dispatch compile time type to runtime type.
*/
template <typename T, typename E = void>
struct ToDType;
// float
template <>
struct ToDType<float> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4;
};
template <>
struct ToDType<double> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF8;
};
template <typename T>
struct ToDType<T,
std::enable_if_t<std::is_same<T, long double>::value && sizeof(long double) == 16>> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF16;
};
// uint
template <>
struct ToDType<uint8_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU1;
};
template <>
struct ToDType<uint16_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU2;
};
template <>
struct ToDType<uint32_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU4;
};
template <>
struct ToDType<uint64_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kU8;
};
// int
template <>
struct ToDType<int8_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI1;
};
template <>
struct ToDType<int16_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI2;
};
template <>
struct ToDType<int32_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI4;
};
template <>
struct ToDType<int64_t> {
static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kI8;
};
#if !defined(XGBOOST_USE_CUDA)
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
common::AssertGPUSupport();
}
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) { common::AssertGPUSupport(); }
inline bool ArrayInterfaceHandler::IsCudaPtr(void const *ptr) { return false; }
#endif // !defined(XGBOOST_USE_CUDA)
// A view over __array_interface__
/**
* \brief A type erased view over __array_interface__ protocol defined by numpy
*
* <a href="https://numpy.org/doc/stable/reference/arrays.interface.html">numpy</a>.
*
* \tparam D The number of maximum dimension.
* User input array must have dim <= D for all non-trivial dimensions. During
* construction, the ctor can automatically remove those trivial dimensions.
*
* \tparam allow_mask Whether masked array is accepted.
*
* Currently this only supported for 1-dim vector, which is used by cuDF column
* (apache arrow format). For general masked array, as the time of writting, only
* numpy has the proper support even though it's in the __cuda_array_interface__
* protocol defined by numba.
*/
template <int32_t D, bool allow_mask = (D == 1)>
class ArrayInterface {
void Initialize(std::map<std::string, Json> const &array,
bool allow_mask = true) {
static_assert(D > 0, "Invalid dimension for array interface.");
/**
* \brief Initialize the object, by extracting shape, stride and type.
*
* The function also perform some basic validation for input array. Lastly it will
* also remove trivial dimensions like converting a matrix with shape (n_samples, 1)
* to a vector of size n_samples. For for inputs like weights, this should be a 1
* dimension column vector even though user might provide a matrix.
*/
void Initialize(std::map<std::string, Json> const &array) {
ArrayInterfaceHandler::Validate(array);
auto typestr = get<String const>(array.at("typestr"));
this->AssignType(StringView{typestr});
ArrayInterfaceHandler::ExtractShape(array, shape);
size_t itemsize = typestr[2] - '0';
is_contiguous = ArrayInterfaceHandler::ExtractStride(array, itemsize, shape, strides);
n = linalg::detail::CalcSize(shape);
std::tie(num_rows, num_cols) = ArrayInterfaceHandler::ExtractShape(array);
data = ArrayInterfaceHandler::ExtractData(
array, std::make_pair(num_rows, num_cols));
data = ArrayInterfaceHandler::ExtractData(array, n);
static_assert(allow_mask ? D == 1 : D >= 1, "Masked ndarray is not supported.");
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
@@ -273,18 +397,13 @@ class ArrayInterface {
valid = RBitField8(s_mask);
if (s_mask.data()) {
CHECK_EQ(n_bits, num_rows)
<< "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting.";
CHECK_EQ(n_bits, n) << "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting.";
}
} else {
CHECK(array.find("mask") == array.cend())
<< "Masked array is not yet supported.";
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
}
ArrayInterfaceHandler::ExtractStride(array, &stride_row, &stride_col,
num_rows, num_cols, typestr[2] - '0');
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
int64_t stream = get<Integer const>(stream_it->second);
@@ -292,151 +411,147 @@ class ArrayInterface {
}
}
public:
enum Type : std::int8_t { kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
public:
ArrayInterface() = default;
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
explicit ArrayInterface(std::map<std::string, Json> const &array) { this->Initialize(array); }
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
explicit ArrayInterface(StringView str, bool allow_mask = true) {
auto jinterface = Json::Load(str);
if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask);
explicit ArrayInterface(Json const &array) {
if (IsA<Object>(array)) {
this->Initialize(get<Object const>(array));
return;
}
if (IsA<Array>(jinterface)) {
CHECK_EQ(get<Array const>(jinterface).size(), 1)
if (IsA<Array>(array)) {
CHECK_EQ(get<Array const>(array).size(), 1)
<< "Column: " << ArrayInterfaceErrors::Dimension(1);
this->Initialize(get<Object const>(get<Array const>(jinterface)[0]), allow_mask);
this->Initialize(get<Object const>(get<Array const>(array)[0]));
return;
}
}
void AsColumnVector() {
CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix.";
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
num_cols = 1;
explicit ArrayInterface(std::string const &str) : ArrayInterface{StringView{str}} {}
stride_row = std::max(stride_row, stride_col);
stride_col = 1;
}
explicit ArrayInterface(StringView str) : ArrayInterface<D>{Json::Load(str)} {}
void AssignType(StringView typestr) {
if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' &&
typestr[3] == '6') {
type = kF16;
using T = ArrayInterfaceHandler::Type;
if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
type = T::kF16;
CHECK(sizeof(long double) == 16)
<< "128-bit floating point is not supported on current platform.";
} else if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4;
type = T::kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
type = kF8;
type = T::kF8;
} else if (typestr[1] == 'i' && typestr[2] == '1') {
type = kI1;
type = T::kI1;
} else if (typestr[1] == 'i' && typestr[2] == '2') {
type = kI2;
type = T::kI2;
} else if (typestr[1] == 'i' && typestr[2] == '4') {
type = kI4;
type = T::kI4;
} else if (typestr[1] == 'i' && typestr[2] == '8') {
type = kI8;
type = T::kI8;
} else if (typestr[1] == 'u' && typestr[2] == '1') {
type = kU1;
type = T::kU1;
} else if (typestr[1] == 'u' && typestr[2] == '2') {
type = kU2;
type = T::kU2;
} else if (typestr[1] == 'u' && typestr[2] == '4') {
type = kU4;
type = T::kU4;
} else if (typestr[1] == 'u' && typestr[2] == '8') {
type = kU8;
type = T::kU8;
} else {
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
return;
}
}
XGBOOST_DEVICE size_t Shape(size_t i) const { return shape[i]; }
XGBOOST_DEVICE size_t Stride(size_t i) const { return strides[i]; }
template <typename Fn>
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
XGBOOST_HOST_DEV_INLINE constexpr decltype(auto) DispatchCall(Fn func) const {
using T = ArrayInterfaceHandler::Type;
switch (type) {
case kF4:
return func(reinterpret_cast<float *>(data));
case kF8:
return func(reinterpret_cast<double *>(data));
case T::kF4:
return func(reinterpret_cast<float const *>(data));
case T::kF8:
return func(reinterpret_cast<double const *>(data));
#ifdef __CUDA_ARCH__
case kF16: {
// CUDA device code doesn't support long double.
SPAN_CHECK(false);
return func(reinterpret_cast<double *>(data));
}
case T::kF16: {
// CUDA device code doesn't support long double.
SPAN_CHECK(false);
return func(reinterpret_cast<double const *>(data));
}
#else
case kF16:
return func(reinterpret_cast<long double *>(data));
case T::kF16:
return func(reinterpret_cast<long double const *>(data));
#endif
case kI1:
return func(reinterpret_cast<int8_t *>(data));
case kI2:
return func(reinterpret_cast<int16_t *>(data));
case kI4:
return func(reinterpret_cast<int32_t *>(data));
case kI8:
return func(reinterpret_cast<int64_t *>(data));
case kU1:
return func(reinterpret_cast<uint8_t *>(data));
case kU2:
return func(reinterpret_cast<uint16_t *>(data));
case kU4:
return func(reinterpret_cast<uint32_t *>(data));
case kU8:
return func(reinterpret_cast<uint64_t *>(data));
case T::kI1:
return func(reinterpret_cast<int8_t const *>(data));
case T::kI2:
return func(reinterpret_cast<int16_t const *>(data));
case T::kI4:
return func(reinterpret_cast<int32_t const *>(data));
case T::kI8:
return func(reinterpret_cast<int64_t const *>(data));
case T::kU1:
return func(reinterpret_cast<uint8_t const *>(data));
case T::kU2:
return func(reinterpret_cast<uint16_t const *>(data));
case T::kU4:
return func(reinterpret_cast<uint32_t const *>(data));
case T::kU8:
return func(reinterpret_cast<uint64_t const *>(data));
}
SPAN_CHECK(false);
return func(reinterpret_cast<uint64_t *>(data));
return func(reinterpret_cast<uint64_t const *>(data));
}
XGBOOST_DEVICE size_t ElementSize() {
return this->DispatchCall([](auto* p_values) {
return sizeof(std::remove_pointer_t<decltype(p_values)>);
XGBOOST_DEVICE size_t constexpr ElementSize() {
return this->DispatchCall(
[](auto *p_values) { return sizeof(std::remove_pointer_t<decltype(p_values)>); });
}
template <typename T = float, typename... Index>
XGBOOST_DEVICE T operator()(Index &&...index) const {
static_assert(sizeof...(index) <= D, "Invalid index.");
return this->DispatchCall([=](auto const *p_values) -> T {
size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...);
return static_cast<T>(p_values[offset]);
});
}
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
return this->DispatchCall(
[=](auto *p_values) -> T { return p_values[stride_row * r + stride_col * c]; });
}
// Used only by columnar format.
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
size_t stride_row{0};
size_t stride_col{0};
void* data;
Type type;
// Array stride
size_t strides[D]{0};
// Array shape
size_t shape[D]{0};
// Type earsed pointer referencing the data.
void const *data{nullptr};
// Total number of items
size_t n{0};
// Whether the memory is c-contiguous
bool is_contiguous{false};
// RTTI, initialized to the f16 to avoid masking potential bugs in initialization.
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
};
template <typename T> std::string MakeArrayInterface(T const *data, size_t n) {
Json arr{Object{}};
arr["data"] = Array(std::vector<Json>{
Json{Integer{reinterpret_cast<int64_t>(data)}}, Json{Boolean{false}}});
arr["shape"] = Array{std::vector<Json>{Json{Integer{n}}, Json{Integer{1}}}};
std::string typestr;
if (DMLC_LITTLE_ENDIAN) {
typestr.push_back('<');
} else {
typestr.push_back('>');
/**
* \brief Helper for type casting.
*/
template <typename T, int32_t D>
struct TypedIndex {
ArrayInterface<D> const &array;
template <typename... I>
XGBOOST_DEVICE T operator()(I &&...ind) const {
static_assert(sizeof...(ind) <= D, "Invalid index.");
return array.template operator()<T>(ind...);
}
typestr.push_back(ArrayInterfaceHandler::TypeChar<T>());
typestr += std::to_string(sizeof(T));
arr["typestr"] = typestr;
arr["version"] = 3;
std::string str;
Json::Dump(arr, &str);
return str;
};
template <int32_t D>
inline void CheckArrayInterface(StringView key, ArrayInterface<D> const &array) {
CHECK(!array.valid.Data()) << "Meta info " << key << " should be dense, found validity mask";
}
} // namespace xgboost
#endif // XGBOOST_DATA_ARRAY_INTERFACE_H_