Support v3 cuda array interface. (#6776)

This commit is contained in:
Jiaming Yuan
2021-03-25 09:58:09 +08:00
committed by GitHub
parent bcc0277338
commit 794fd6a46b
3 changed files with 98 additions and 14 deletions

View File

@@ -18,6 +18,7 @@
#include "xgboost/logging.h"
#include "xgboost/span.h"
#include "../common/bitfield.h"
#include "../common/common.h"
namespace xgboost {
// Common errors in parsing columnar format.
@@ -41,7 +42,7 @@ struct ArrayInterfaceErrors {
return str.c_str();
}
static char const* Version() {
return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
return "Only version <= 3 of `__cuda_array_interface__' are supported.";
}
static char const* OfType(std::string const& type) {
static std::string str;
@@ -119,9 +120,18 @@ class ArrayInterfaceHandler {
}
static void Validate(std::map<std::string, Json> const& array) {
if (array.find("version") == array.cend()) {
auto version_it = array.find("version");
if (version_it == array.cend()) {
LOG(FATAL) << "Missing `version' field for array interface";
}
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
// is cuda, check the version.
if (get<Integer const>(version_it->second) > 3) {
LOG(FATAL) << ArrayInterfaceErrors::Version();
}
}
if (array.find("typestr") == array.cend()) {
LOG(FATAL) << "Missing `typestr' field for array interface";
}
@@ -233,25 +243,31 @@ class ArrayInterfaceHandler {
}
return p_data;
}
static void SyncCudaStream(int64_t stream);
};
#if !defined(XGBOOST_USE_CUDA)
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
// A view over __array_interface__
class ArrayInterface {
void Initialize(std::map<std::string, Json> const &column,
void Initialize(std::map<std::string, Json> const &array,
bool allow_mask = true) {
ArrayInterfaceHandler::Validate(column);
auto typestr = get<String const>(column.at("typestr"));
ArrayInterfaceHandler::Validate(array);
auto typestr = get<String const>(array.at("typestr"));
this->AssignType(StringView{typestr});
auto shape = ArrayInterfaceHandler::ExtractShape(column);
num_rows = shape.first;
num_cols = shape.second;
data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);
std::tie(num_rows, num_cols) = ArrayInterfaceHandler::ExtractShape(array);
data = ArrayInterfaceHandler::ExtractData(
array, StringView{typestr}, std::make_pair(num_rows, num_cols));
if (allow_mask) {
common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
valid = RBitField8(s_mask);
@@ -261,12 +277,18 @@ class ArrayInterface {
<< "XGBoost doesn't support internal broadcasting.";
}
} else {
CHECK(column.find("mask") == column.cend())
CHECK(array.find("mask") == array.cend())
<< "Masked array is not yet supported.";
}
ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
typestr[2] - '0');
auto stream_it = array.find("stream");
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
int64_t stream = get<Integer const>(stream_it->second);
ArrayInterfaceHandler::SyncCudaStream(stream);
}
}
public:
@@ -377,7 +399,6 @@ class ArrayInterface {
bst_feature_t num_cols;
size_t strides[2]{0, 0};
void* data;
Type type;
};