Support column major array. (#6765)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#ifndef XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cinttypes>
|
||||
#include <map>
|
||||
#include <string>
|
||||
@@ -40,7 +41,7 @@ struct ArrayInterfaceErrors {
|
||||
return str.c_str();
|
||||
}
|
||||
static char const* Version() {
|
||||
return "Only version 1 of `__cuda_array_interface__' is supported.";
|
||||
return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
|
||||
}
|
||||
static char const* OfType(std::string const& type) {
|
||||
static std::string str;
|
||||
@@ -191,43 +192,46 @@ class ArrayInterfaceHandler {
|
||||
std::map<std::string, Json> const& column) {
|
||||
auto j_shape = get<Array const>(column.at("shape"));
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
if (column.find("strides") != column.cend()) {
|
||||
if (!IsA<Null>(column.at("strides"))) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
CHECK_EQ(strides.size(), j_shape.size())
|
||||
<< ArrayInterfaceErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
|
||||
<< ArrayInterfaceErrors::Contigious();
|
||||
}
|
||||
}
|
||||
|
||||
if (j_shape.size() == 1) {
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))), 1};
|
||||
} else {
|
||||
CHECK_EQ(j_shape.size(), 2)
|
||||
<< "Only 1D or 2-D arrays currently supported.";
|
||||
CHECK_EQ(j_shape.size(), 2) << "Only 1-D and 2-D arrays are supported.";
|
||||
return {static_cast<bst_row_t>(get<Integer const>(j_shape.at(0))),
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
static void ExtractStride(std::map<std::string, Json> const &column,
|
||||
size_t strides[2], size_t rows, size_t cols, size_t itemsize) {
|
||||
auto strides_it = column.find("strides");
|
||||
if (strides_it == column.cend() || IsA<Null>(strides_it->second)) {
|
||||
// default strides
|
||||
strides[0] = cols;
|
||||
strides[1] = 1;
|
||||
} else {
|
||||
// strides specified by the array interface
|
||||
auto const &j_strides = get<Array const>(strides_it->second);
|
||||
CHECK_LE(j_strides.size(), 2) << ArrayInterfaceErrors::Dimension(2);
|
||||
strides[0] = get<Integer const>(j_strides[0]) / itemsize;
|
||||
size_t n = 1;
|
||||
if (j_strides.size() == 2) {
|
||||
n = get<Integer const>(j_strides[1]) / itemsize;
|
||||
}
|
||||
strides[1] = n;
|
||||
}
|
||||
auto valid = (rows - 1) * strides[0] + (cols - 1) * strides[1] == (rows * cols) - 1;
|
||||
CHECK(valid) << "Invalid strides in array.";
|
||||
}
|
||||
|
||||
static void* ExtractData(std::map<std::string, Json> const &column,
|
||||
StringView typestr,
|
||||
std::pair<size_t, size_t> shape) {
|
||||
Validate(column);
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
CHECK_EQ(typestr.at(1), TypeChar<T>())
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
|
||||
<< "Input data type and typestr mismatch. typestr: " << typestr;
|
||||
|
||||
auto shape = ExtractShape(column);
|
||||
|
||||
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
|
||||
void* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
if (!p_data) {
|
||||
CHECK_EQ(shape.first * shape.second, 0) << "Empty data with non-zero shape.";
|
||||
}
|
||||
return common::Span<T>{p_data, shape.first * shape.second};
|
||||
return p_data;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -236,11 +240,15 @@ class ArrayInterface {
|
||||
void Initialize(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
ArrayInterfaceHandler::Validate(column);
|
||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
this->AssignType(StringView{typestr});
|
||||
|
||||
auto shape = ArrayInterfaceHandler::ExtractShape(column);
|
||||
num_rows = shape.first;
|
||||
num_cols = shape.second;
|
||||
|
||||
data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);
|
||||
|
||||
if (allow_mask) {
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
||||
@@ -257,8 +265,8 @@ class ArrayInterface {
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
this->AssignType(StringView{typestr});
|
||||
ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
|
||||
typestr[2] - '0');
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -288,6 +296,15 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
void AsColumnVector() {
|
||||
CHECK(num_rows == 1 || num_cols == 1) << "Array should be a vector instead of matrix.";
|
||||
num_rows = std::max(num_rows, static_cast<size_t>(num_cols));
|
||||
num_cols = 1;
|
||||
|
||||
strides[0] = std::max(strides[0], strides[1]);
|
||||
strides[1] = 1;
|
||||
}
|
||||
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
@@ -320,95 +337,45 @@ class ArrayInterface {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return func(reinterpret_cast<float *>(data));
|
||||
break;
|
||||
case kF8:
|
||||
return func(reinterpret_cast<double *>(data));
|
||||
break;
|
||||
case kI1:
|
||||
return func(reinterpret_cast<int8_t *>(data));
|
||||
break;
|
||||
case kI2:
|
||||
return func(reinterpret_cast<int16_t *>(data));
|
||||
break;
|
||||
case kI4:
|
||||
return func(reinterpret_cast<int32_t *>(data));
|
||||
break;
|
||||
case kI8:
|
||||
return func(reinterpret_cast<int64_t *>(data));
|
||||
break;
|
||||
case kU1:
|
||||
return func(reinterpret_cast<uint8_t *>(data));
|
||||
break;
|
||||
case kU2:
|
||||
return func(reinterpret_cast<uint16_t *>(data));
|
||||
break;
|
||||
case kU4:
|
||||
return func(reinterpret_cast<uint32_t *>(data));
|
||||
break;
|
||||
case kU8:
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
break;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return func(reinterpret_cast<uint64_t *>(data));
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
|
||||
void* p_values{nullptr};
|
||||
this->DispatchCall([&p_values, offset](auto *ptr) {
|
||||
p_values = ptr + offset;
|
||||
});
|
||||
|
||||
ArrayInterface ret = *this;
|
||||
ret.data = p_values;
|
||||
return ret;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
|
||||
size_t offset = idx * num_cols;
|
||||
auto ret = this->SliceOffset(offset);
|
||||
ret.num_rows = 1;
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
case kF8:
|
||||
return reinterpret_cast<double*>(data)[idx];
|
||||
case kI1:
|
||||
return reinterpret_cast<int8_t*>(data)[idx];
|
||||
case kI2:
|
||||
return reinterpret_cast<int16_t*>(data)[idx];
|
||||
case kI4:
|
||||
return reinterpret_cast<int32_t*>(data)[idx];
|
||||
case kI8:
|
||||
return reinterpret_cast<int64_t*>(data)[idx];
|
||||
case kU1:
|
||||
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||
case kU2:
|
||||
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||
case kU4:
|
||||
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||
case kU8:
|
||||
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE size_t ElementSize() {
|
||||
return this->DispatchCall([](auto* p_values) {
|
||||
return sizeof(std::remove_pointer_t<decltype(p_values)>);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t r, size_t c) const {
|
||||
return this->DispatchCall(
|
||||
[=](auto *p_values) -> T { return p_values[strides[0] * r + strides[1] * c]; });
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
size_t strides[2]{0, 0};
|
||||
void* data;
|
||||
|
||||
Type type;
|
||||
|
||||
Reference in New Issue
Block a user