Don't use mask in array interface. (#5730)

This commit is contained in:
Jiaming Yuan
2020-06-01 12:17:24 +08:00
committed by GitHub
parent 267c1ed784
commit d19cec70f1
2 changed files with 17 additions and 13 deletions

View File

@@ -231,7 +231,8 @@ class ArrayInterfaceHandler {
class ArrayInterface {
public:
ArrayInterface() = default;
explicit ArrayInterface(std::map<std::string, Json> const& column) {
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
@@ -239,16 +240,22 @@ class ArrayInterface {
num_rows = shape.first;
num_cols = shape.second;
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
valid = RBitField8(s_mask);
valid = RBitField8(s_mask);
if (s_mask.data()) {
CHECK_EQ(n_bits, num_rows)
<< "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting.";
if (s_mask.data()) {
CHECK_EQ(n_bits, num_rows)
<< "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting.";
}
} else {
CHECK(column.find("mask") == column.cend())
<< "Masked array is not yet supported.";
}
auto typestr = get<String const>(column.at("typestr"));
type[0] = typestr.at(0);
type[1] = typestr.at(1);