Support null value in CUDA array interface. (#8486) (#8499)

This commit is contained in:
Philip Hyunsu Cho 2022-11-30 11:44:54 -08:00 committed by GitHub
parent 9372370dda
commit db14e3feb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 13 deletions

View File

@@ -101,7 +101,7 @@ class ArrayInterfaceHandler {
template <typename PtrType> template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) { static PtrType GetPtrFromArrayData(Object::Map const &obj) {
auto data_it = obj.find("data"); auto data_it = obj.find("data");
if (data_it == obj.cend()) { if (data_it == obj.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Empty data passed in."; LOG(FATAL) << "Empty data passed in.";
} }
auto p_data = reinterpret_cast<PtrType>( auto p_data = reinterpret_cast<PtrType>(
@@ -111,7 +111,7 @@ class ArrayInterfaceHandler {
static void Validate(Object::Map const &array) { static void Validate(Object::Map const &array) {
auto version_it = array.find("version"); auto version_it = array.find("version");
if (version_it == array.cend()) { if (version_it == array.cend() || IsA<Null>(version_it->second)) {
LOG(FATAL) << "Missing `version' field for array interface"; LOG(FATAL) << "Missing `version' field for array interface";
} }
if (get<Integer const>(version_it->second) > 3) { if (get<Integer const>(version_it->second) > 3) {
@@ -119,17 +119,19 @@ class ArrayInterfaceHandler {
} }
auto typestr_it = array.find("typestr"); auto typestr_it = array.find("typestr");
if (typestr_it == array.cend()) { if (typestr_it == array.cend() || IsA<Null>(typestr_it->second)) {
LOG(FATAL) << "Missing `typestr' field for array interface"; LOG(FATAL) << "Missing `typestr' field for array interface";
} }
auto typestr = get<String const>(typestr_it->second); auto typestr = get<String const>(typestr_it->second);
CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat(); CHECK(typestr.size() == 3 || typestr.size() == 4) << ArrayInterfaceErrors::TypestrFormat();
if (array.find("shape") == array.cend()) { auto shape_it = array.find("shape");
if (shape_it == array.cend() || IsA<Null>(shape_it->second)) {
LOG(FATAL) << "Missing `shape' field for array interface"; LOG(FATAL) << "Missing `shape' field for array interface";
} }
if (array.find("data") == array.cend()) { auto data_it = array.find("data");
if (data_it == array.cend() || IsA<Null>(data_it->second)) {
LOG(FATAL) << "Missing `data' field for array interface"; LOG(FATAL) << "Missing `data' field for array interface";
} }
} }
@@ -139,8 +141,9 @@ class ArrayInterfaceHandler {
static size_t ExtractMask(Object::Map const &column, static size_t ExtractMask(Object::Map const &column,
common::Span<RBitField8::value_type> *p_out) { common::Span<RBitField8::value_type> *p_out) {
auto &s_mask = *p_out; auto &s_mask = *p_out;
if (column.find("mask") != column.cend()) { auto const &mask_it = column.find("mask");
auto const &j_mask = get<Object const>(column.at("mask")); if (mask_it != column.cend() && !IsA<Null>(mask_it->second)) {
auto const &j_mask = get<Object const>(mask_it->second);
Validate(j_mask); Validate(j_mask);
auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask); auto p_mask = GetPtrFromArrayData<RBitField8::value_type *>(j_mask);
@@ -173,8 +176,9 @@ class ArrayInterfaceHandler {
// assume 1 byte alignment. // assume 1 byte alignment.
size_t const span_size = RBitField8::ComputeStorageSize(n_bits); size_t const span_size = RBitField8::ComputeStorageSize(n_bits);
if (j_mask.find("strides") != j_mask.cend()) { auto strides_it = j_mask.find("strides");
auto strides = get<Array const>(column.at("strides")); if (strides_it != j_mask.cend() && !IsA<Null>(strides_it->second)) {
auto strides = get<Array const>(strides_it->second);
CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1); CHECK_EQ(strides.size(), 1) << ArrayInterfaceErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous(); CHECK_EQ(get<Integer>(strides.at(0)), type_length) << ArrayInterfaceErrors::Contiguous();
} }
@@ -401,7 +405,9 @@ class ArrayInterface {
<< "XGBoost doesn't support internal broadcasting."; << "XGBoost doesn't support internal broadcasting.";
} }
} else { } else {
CHECK(array.find("mask") == array.cend()) << "Masked array is not yet supported."; auto mask_it = array.find("mask");
CHECK(mask_it == array.cend() || IsA<Null>(mask_it->second))
<< "Masked array is not yet supported.";
} }
auto stream_it = array.find("stream"); auto stream_it = array.find("stream");

View File

@@ -33,8 +33,7 @@ TEST(ArrayInterface, Error) {
Json column { Object() }; Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))}; std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape); column["shape"] = Array(j_shape);
std::vector<Json> j_data { std::vector<Json> j_data{Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Integer(reinterpret_cast<Integer::Int>(nullptr))),
Json(Boolean(false))}; Json(Boolean(false))};
auto const& column_obj = get<Object>(column); auto const& column_obj = get<Object>(column);
@@ -45,6 +44,10 @@ TEST(ArrayInterface, Error) {
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error); EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), dmlc::Error);
column["version"] = 3; column["version"] = 3;
// missing data // missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error);
// null data
column["data"] = Null{};
EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n), EXPECT_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n),
dmlc::Error); dmlc::Error);
column["data"] = j_data; column["data"] = j_data;
@@ -63,6 +66,11 @@ TEST(ArrayInterface, Error) {
Json(Boolean(false))}; Json(Boolean(false))};
column["data"] = j_data; column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n)); EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractData(column_obj, n));
// null data in mask
column["mask"] = Object{};
column["mask"]["data"] = Null{};
common::Span<RBitField8::value_type> s_mask;
EXPECT_THROW(ArrayInterfaceHandler::ExtractMask(column_obj, &s_mask), dmlc::Error);
} }
TEST(ArrayInterface, GetElement) { TEST(ArrayInterface, GetElement) {