Don't use mask in array interface. (#5730)

This commit is contained in:
Jiaming Yuan 2020-06-01 12:17:24 +08:00 committed by GitHub
parent 267c1ed784
commit d19cec70f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 17 additions and 13 deletions

View File

@ -231,7 +231,8 @@ class ArrayInterfaceHandler {
class ArrayInterface { class ArrayInterface {
public: public:
ArrayInterface() = default; ArrayInterface() = default;
explicit ArrayInterface(std::map<std::string, Json> const& column) { explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column); ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column); data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null"; CHECK(data) << "Column is null";
@ -239,6 +240,7 @@ class ArrayInterface {
num_rows = shape.first; num_rows = shape.first;
num_cols = shape.second; num_cols = shape.second;
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask; common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask); size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
@ -249,6 +251,11 @@ class ArrayInterface {
<< "Shape of bit mask doesn't match data shape. " << "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting."; << "XGBoost doesn't support internal broadcasting.";
} }
} else {
CHECK(column.find("mask") == column.cend())
<< "Masked array is not yet supported.";
}
auto typestr = get<String const>(column.at("typestr")); auto typestr = get<String const>(column.at("typestr"));
type[0] = typestr.at(0); type[0] = typestr.at(0);
type[1] = typestr.at(1); type[1] = typestr.at(1);

View File

@ -177,10 +177,7 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
__device__ COOTuple GetElement(size_t idx) const { __device__ COOTuple GetElement(size_t idx) const {
size_t column_idx = idx % array_interface_.num_cols; size_t column_idx = idx % array_interface_.num_cols;
size_t row_idx = idx / array_interface_.num_cols; size_t row_idx = idx / array_interface_.num_cols;
float value = array_interface_.valid.Data() == nullptr || float value = array_interface_.GetElement(idx);
array_interface_.valid.Check(row_idx)
? array_interface_.GetElement(idx)
: std::numeric_limits<float>::quiet_NaN();
return {row_idx, column_idx, value}; return {row_idx, column_idx, value};
} }
@ -193,7 +190,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
explicit CupyAdapter(std::string cuda_interface_str) { explicit CupyAdapter(std::string cuda_interface_str) {
Json json_array_interface = Json json_array_interface =
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()}); Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
array_interface_ = ArrayInterface(get<Object const>(json_array_interface)); array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data); device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
CHECK_NE(device_idx_, -1); CHECK_NE(device_idx_, -1);
batch_ = CupyAdapterBatch(array_interface_); batch_ = CupyAdapterBatch(array_interface_);