parent
9372370dda
commit
db14e3feb7
@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
|
|||||||
template <typename PtrType>
|
template <typename PtrType>
|
||||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||||
auto data_it = obj.find("data");
|
auto data_it = obj.find("data");
|
||||||
if (data_it == obj.cend()) {
|
if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
|
||||||
LOG(FATAL) << "Empty data passed in.";
|
LOG(FATAL) << "Empty data passed in.";
|
||||||
}
|
}
|
||||||
auto p_data = reinterpret_cast<PtrType>(
|
auto p_data = reinterpret_cast<PtrType>(
|
||||||
@ -111,7 +111,7 @@ class ArrayInterfaceHandler {
|
|||||||
|
|
||||||
static void Validate(Object::Map const &array) {
|
static void Validate(Object::Map const &array) {
|
||||||
auto version_it = array.find("version");
|
auto version_it = array.find("version");
|
||||||
if (version_it == array.cend()) {
|
if (version_it == array.cend() || IsA<Null>(version_it->second)) {
|
||||||
LOG(FATAL) << "Missing `version' field for array interface";
|
LOG(FATAL) << "Missing `version' field for array interface";
|
||||||
}
|
}
|
||||||
if (get<Integer const>(version_it->second) > 3) {
|
if (get<Integer const>(version_it->second) > 3) {
|
||||||
@ -119,17 +119,19 @@ class ArrayInterfaceHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto typestr_it = array.find("typestr");
|
auto typestr_it = array.find("typestr");
|
||||||
if (typestr_it == array.cend()) {
|
if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
|
||||||
LOG(FATAL) << "Missing `typestr' field for array interface";
|
LOG(FATAL) << "Missing `typestr' field for array interface";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto typestr = get<String const>(typestr_it->second);
|
auto typestr = get<String const>(typestr_it->second);
|
||||||
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
|
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
|
||||||
|
|
||||||
if (array.find("shape") == array.cend()) {
|
auto shape_it = array.find("shape");
|
||||||
|
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
|
||||||
LOG(FATAL) << "Missing `shape' field for array interface";
|
LOG(FATAL) << "Missing `shape' field for array interface";
|
||||||
}
|
}
|
||||||
if (array.find("data") == array.cend()) {
|
auto data_it = array.find("data");
|
||||||
|
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
|
||||||
LOG(FATAL) << "Missing `data' field for array interface";
|
LOG(FATAL) << "Missing `data' field for array interface";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
|
|||||||
static size_t ExtractMask(Object::Map const &column,
|
static size_t ExtractMask(Object::Map const &column,
|
||||||
common::Span<RBitField8::value_type> *p_out) {
|
common::Span<RBitField8::value_type> *p_out) {
|
||||||
auto &s_mask = *p_out;
|
auto &s_mask = *p_out;
|
||||||
if (column.find("mask") != column.cend()) {
|
auto const &mask_it = column.find("mask");
|
||||||
auto const &j_mask = get<Object const>(column.at("mask"));
|
if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
|
||||||
|
auto const &j_mask = get<Object const>(mask_it->second);
|
||||||
Validate(j_mask);
|
Validate(j_mask);
|
||||||
|
|
||||||
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
|
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
|
||||||
@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
|
|||||||
// assume 1 byte alignment.
|
// assume 1 byte alignment.
|
||||||
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
||||||
|
|
||||||
if (j_mask.find("strides") != j_mask.cend()) {
|
auto strides_it = j_mask.find("strides");
|
||||||
auto strides = get<Array const>(column.at("strides"));
|
if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
|
||||||
|
auto strides = get<Array const>(strides_it->second);
|
||||||
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
|
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
|
||||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
|
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
|
||||||
}
|
}
|
||||||
@ -401,7 +405,9 @@ class ArrayInterface {
|
|||||||
<< "XGBoost doesn't support internal broadcasting.";
|
<< "XGBoost doesn't support internal broadcasting.";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
|
auto mask_it = array.find("mask");
|
||||||
|
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
|
||||||
|
<< "Masked array is not yet supported.";
|
||||||
}
|
}
|
||||||
|
|
||||||
auto stream_it = array.find("stream");
|
auto stream_it = array.find("stream");
|
||||||
|
|||||||
@ -33,8 +33,7 @@ TEST(ArrayInterface, Error) {
|
|||||||
Json column { Object() };
|
Json column { Object() };
|
||||||
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
|
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
|
||||||
column["shape"] = Array(j_shape);
|
column["shape"] = Array(j_shape);
|
||||||
std::vector<Json> j_data {
|
std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
|
||||||
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
|
|
||||||
Json(Boolean(false))};
|
Json(Boolean(false))};
|
||||||
|
|
||||||
auto const& column_obj = get<Object>(column);
|
auto const& column_obj = get<Object>(column);
|
||||||
@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
|
|||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
|
||||||
column["version"] = 3;
|
column["version"] = 3;
|
||||||
// missing data
|
// missing data
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
|
||||||
|
dmlc::Error);
|
||||||
|
// null data
|
||||||
|
column["data"] = Null{};
|
||||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
|
||||||
dmlc::Error);
|
dmlc::Error);
|
||||||
column["data"] = j_data;
|
column["data"] = j_data;
|
||||||
@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
|
|||||||
Json(Boolean(false))};
|
Json(Boolean(false))};
|
||||||
column["data"] = j_data;
|
column["data"] = j_data;
|
||||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
|
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
|
||||||
|
// null data in mask
|
||||||
|
column["mask"] = Object{};
|
||||||
|
column["mask"]["data"] = Null{};
|
||||||
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
|
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ArrayInterface, GetElement) {
|
TEST(ArrayInterface, GetElement) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user