parent
9372370dda
commit
db14e3feb7
@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
|
||||
template <typename PtrType>
|
||||
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
|
||||
auto data_it = obj.find("data");
|
||||
if (data_it == obj.cend()) {
|
||||
if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
|
||||
LOG(FATAL) << "Empty data passed in.";
|
||||
}
|
||||
auto p_data = reinterpret_cast<PtrType>(
|
||||
@ -111,7 +111,7 @@ class ArrayInterfaceHandler {
|
||||
|
||||
static void Validate(Object::Map const &array) {
|
||||
auto version_it = array.find("version");
|
||||
if (version_it == array.cend()) {
|
||||
if (version_it == array.cend() || IsA<Null>(version_it->second)) {
|
||||
LOG(FATAL) << "Missing `version' field for array interface";
|
||||
}
|
||||
if (get<Integer const>(version_it->second) > 3) {
|
||||
@ -119,17 +119,19 @@ class ArrayInterfaceHandler {
|
||||
}
|
||||
|
||||
auto typestr_it = array.find("typestr");
|
||||
if (typestr_it == array.cend()) {
|
||||
if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
|
||||
LOG(FATAL) << "Missing `typestr' field for array interface";
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(typestr_it->second);
|
||||
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
|
||||
|
||||
if (array.find("shape") == array.cend()) {
|
||||
auto shape_it = array.find("shape");
|
||||
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
|
||||
LOG(FATAL) << "Missing `shape' field for array interface";
|
||||
}
|
||||
if (array.find("data") == array.cend()) {
|
||||
auto data_it = array.find("data");
|
||||
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
|
||||
LOG(FATAL) << "Missing `data' field for array interface";
|
||||
}
|
||||
}
|
||||
@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
|
||||
static size_t ExtractMask(Object::Map const &column,
|
||||
common::Span<RBitField8::value_type> *p_out) {
|
||||
auto &s_mask = *p_out;
|
||||
if (column.find("mask") != column.cend()) {
|
||||
auto const &j_mask = get<Object const>(column.at("mask"));
|
||||
auto const &mask_it = column.find("mask");
|
||||
if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
|
||||
auto const &j_mask = get<Object const>(mask_it->second);
|
||||
Validate(j_mask);
|
||||
|
||||
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
|
||||
@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
|
||||
// assume 1 byte alignment.
|
||||
size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
|
||||
|
||||
if (j_mask.find("strides") != j_mask.cend()) {
|
||||
auto strides = get<Array const>(column.at("strides"));
|
||||
auto strides_it = j_mask.find("strides");
|
||||
if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
|
||||
auto strides = get<Array const>(strides_it->second);
|
||||
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
|
||||
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
|
||||
}
|
||||
@ -401,7 +405,9 @@ class ArrayInterface {
|
||||
<< "XGBoost doesn't support internal broadcasting.";
|
||||
}
|
||||
} else {
|
||||
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported.";
|
||||
auto mask_it = array.find("mask");
|
||||
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
|
||||
<< "Masked array is not yet supported.";
|
||||
}
|
||||
|
||||
auto stream_it = array.find("stream");
|
||||
|
||||
@ -33,9 +33,8 @@ TEST(ArrayInterface, Error) {
|
||||
Json column { Object() };
|
||||
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
|
||||
column["shape"] = Array(j_shape);
|
||||
std::vector<Json> j_data {
|
||||
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
|
||||
Json(Boolean(false))};
|
||||
std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
|
||||
Json(Boolean(false))};
|
||||
|
||||
auto const& column_obj = get<Object>(column);
|
||||
std::string typestr{"<f4"};
|
||||
@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
|
||||
column["version"] = 3;
|
||||
// missing data
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
|
||||
dmlc::Error);
|
||||
// null data
|
||||
column["data"] = Null{};
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
|
||||
dmlc::Error);
|
||||
column["data"] = j_data;
|
||||
@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
|
||||
Json(Boolean(false))};
|
||||
column["data"] = j_data;
|
||||
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
|
||||
// null data in mask
|
||||
column["mask"] = Object{};
|
||||
column["mask"]["data"] = Null{};
|
||||
common::Span<RBitField8::value_type> s_mask;
|
||||
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
|
||||
}
|
||||
|
||||
TEST(ArrayInterface, GetElement) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user