Don't use mask in array interface. (#5730)
This commit is contained in:
parent
267c1ed784
commit
d19cec70f1
@ -231,7 +231,8 @@ class ArrayInterfaceHandler {
|
|||||||
class ArrayInterface {
|
class ArrayInterface {
|
||||||
public:
|
public:
|
||||||
ArrayInterface() = default;
|
ArrayInterface() = default;
|
||||||
explicit ArrayInterface(std::map<std::string, Json> const& column) {
|
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||||
|
bool allow_mask = true) {
|
||||||
ArrayInterfaceHandler::Validate(column);
|
ArrayInterfaceHandler::Validate(column);
|
||||||
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
|
||||||
CHECK(data) << "Column is null";
|
CHECK(data) << "Column is null";
|
||||||
@ -239,6 +240,7 @@ class ArrayInterface {
|
|||||||
num_rows = shape.first;
|
num_rows = shape.first;
|
||||||
num_cols = shape.second;
|
num_cols = shape.second;
|
||||||
|
|
||||||
|
if (allow_mask) {
|
||||||
common::Span<RBitField8::value_type> s_mask;
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
||||||
|
|
||||||
@ -249,6 +251,11 @@ class ArrayInterface {
|
|||||||
<< "Shape of bit mask doesn't match data shape. "
|
<< "Shape of bit mask doesn't match data shape. "
|
||||||
<< "XGBoost doesn't support internal broadcasting.";
|
<< "XGBoost doesn't support internal broadcasting.";
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
CHECK(column.find("mask") == column.cend())
|
||||||
|
<< "Masked array is not yet supported.";
|
||||||
|
}
|
||||||
|
|
||||||
auto typestr = get<String const>(column.at("typestr"));
|
auto typestr = get<String const>(column.at("typestr"));
|
||||||
type[0] = typestr.at(0);
|
type[0] = typestr.at(0);
|
||||||
type[1] = typestr.at(1);
|
type[1] = typestr.at(1);
|
||||||
|
|||||||
@ -177,10 +177,7 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
|
|||||||
__device__ COOTuple GetElement(size_t idx) const {
|
__device__ COOTuple GetElement(size_t idx) const {
|
||||||
size_t column_idx = idx % array_interface_.num_cols;
|
size_t column_idx = idx % array_interface_.num_cols;
|
||||||
size_t row_idx = idx / array_interface_.num_cols;
|
size_t row_idx = idx / array_interface_.num_cols;
|
||||||
float value = array_interface_.valid.Data() == nullptr ||
|
float value = array_interface_.GetElement(idx);
|
||||||
array_interface_.valid.Check(row_idx)
|
|
||||||
? array_interface_.GetElement(idx)
|
|
||||||
: std::numeric_limits<float>::quiet_NaN();
|
|
||||||
return {row_idx, column_idx, value};
|
return {row_idx, column_idx, value};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -193,7 +190,7 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
|||||||
explicit CupyAdapter(std::string cuda_interface_str) {
|
explicit CupyAdapter(std::string cuda_interface_str) {
|
||||||
Json json_array_interface =
|
Json json_array_interface =
|
||||||
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
|
Json::Load({cuda_interface_str.c_str(), cuda_interface_str.size()});
|
||||||
array_interface_ = ArrayInterface(get<Object const>(json_array_interface));
|
array_interface_ = ArrayInterface(get<Object const>(json_array_interface), false);
|
||||||
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
|
device_idx_ = dh::CudaGetPointerDevice(array_interface_.data);
|
||||||
CHECK_NE(device_idx_, -1);
|
CHECK_NE(device_idx_, -1);
|
||||||
batch_ = CupyAdapterBatch(array_interface_);
|
batch_ = CupyAdapterBatch(array_interface_);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user