Support v3 cuda array interface. (#6776)
This commit is contained in:
parent
bcc0277338
commit
794fd6a46b
21
src/data/array_interface.cu
Normal file
21
src/data/array_interface.cu
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 by Contributors
|
||||||
|
*/
|
||||||
|
#include "array_interface.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
||||||
|
switch (stream) {
|
||||||
|
case 0:
|
||||||
|
LOG(FATAL) << "Invalid stream ID in array interface: " << stream;
|
||||||
|
case 1:
|
||||||
|
// default legacy stream
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
// default per-thread stream
|
||||||
|
default:
|
||||||
|
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace xgboost
|
||||||
@ -18,6 +18,7 @@
|
|||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/span.h"
|
#include "xgboost/span.h"
|
||||||
#include "../common/bitfield.h"
|
#include "../common/bitfield.h"
|
||||||
|
#include "../common/common.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// Common errors in parsing columnar format.
|
// Common errors in parsing columnar format.
|
||||||
@ -41,7 +42,7 @@ struct ArrayInterfaceErrors {
|
|||||||
return str.c_str();
|
return str.c_str();
|
||||||
}
|
}
|
||||||
static char const* Version() {
|
static char const* Version() {
|
||||||
return "Only version 1 and 2 of `__cuda_array_interface__' are supported.";
|
return "Only version <= 3 of `__cuda_array_interface__' are supported.";
|
||||||
}
|
}
|
||||||
static char const* OfType(std::string const& type) {
|
static char const* OfType(std::string const& type) {
|
||||||
static std::string str;
|
static std::string str;
|
||||||
@ -119,9 +120,18 @@ class ArrayInterfaceHandler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void Validate(std::map<std::string, Json> const& array) {
|
static void Validate(std::map<std::string, Json> const& array) {
|
||||||
if (array.find("version") == array.cend()) {
|
auto version_it = array.find("version");
|
||||||
|
if (version_it == array.cend()) {
|
||||||
LOG(FATAL) << "Missing `version' field for array interface";
|
LOG(FATAL) << "Missing `version' field for array interface";
|
||||||
}
|
}
|
||||||
|
auto stream_it = array.find("stream");
|
||||||
|
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
|
||||||
|
// is cuda, check the version.
|
||||||
|
if (get<Integer const>(version_it->second) > 3) {
|
||||||
|
LOG(FATAL) << ArrayInterfaceErrors::Version();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (array.find("typestr") == array.cend()) {
|
if (array.find("typestr") == array.cend()) {
|
||||||
LOG(FATAL) << "Missing `typestr' field for array interface";
|
LOG(FATAL) << "Missing `typestr' field for array interface";
|
||||||
}
|
}
|
||||||
@ -233,25 +243,31 @@ class ArrayInterfaceHandler {
|
|||||||
}
|
}
|
||||||
return p_data;
|
return p_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void SyncCudaStream(int64_t stream);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#if !defined(XGBOOST_USE_CUDA)
|
||||||
|
inline void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
}
|
||||||
|
#endif // !defined(XGBOOST_USE_CUDA)
|
||||||
|
|
||||||
// A view over __array_interface__
|
// A view over __array_interface__
|
||||||
class ArrayInterface {
|
class ArrayInterface {
|
||||||
void Initialize(std::map<std::string, Json> const &column,
|
void Initialize(std::map<std::string, Json> const &array,
|
||||||
bool allow_mask = true) {
|
bool allow_mask = true) {
|
||||||
ArrayInterfaceHandler::Validate(column);
|
ArrayInterfaceHandler::Validate(array);
|
||||||
auto typestr = get<String const>(column.at("typestr"));
|
auto typestr = get<String const>(array.at("typestr"));
|
||||||
this->AssignType(StringView{typestr});
|
this->AssignType(StringView{typestr});
|
||||||
|
|
||||||
auto shape = ArrayInterfaceHandler::ExtractShape(column);
|
std::tie(num_rows, num_cols) = ArrayInterfaceHandler::ExtractShape(array);
|
||||||
num_rows = shape.first;
|
data = ArrayInterfaceHandler::ExtractData(
|
||||||
num_cols = shape.second;
|
array, StringView{typestr}, std::make_pair(num_rows, num_cols));
|
||||||
|
|
||||||
data = ArrayInterfaceHandler::ExtractData(column, StringView{typestr}, shape);
|
|
||||||
|
|
||||||
if (allow_mask) {
|
if (allow_mask) {
|
||||||
common::Span<RBitField8::value_type> s_mask;
|
common::Span<RBitField8::value_type> s_mask;
|
||||||
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
|
size_t n_bits = ArrayInterfaceHandler::ExtractMask(array, &s_mask);
|
||||||
|
|
||||||
valid = RBitField8(s_mask);
|
valid = RBitField8(s_mask);
|
||||||
|
|
||||||
@ -261,12 +277,18 @@ class ArrayInterface {
|
|||||||
<< "XGBoost doesn't support internal broadcasting.";
|
<< "XGBoost doesn't support internal broadcasting.";
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CHECK(column.find("mask") == column.cend())
|
CHECK(array.find("mask") == array.cend())
|
||||||
<< "Masked array is not yet supported.";
|
<< "Masked array is not yet supported.";
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayInterfaceHandler::ExtractStride(column, strides, num_rows, num_cols,
|
ArrayInterfaceHandler::ExtractStride(array, strides, num_rows, num_cols,
|
||||||
typestr[2] - '0');
|
typestr[2] - '0');
|
||||||
|
|
||||||
|
auto stream_it = array.find("stream");
|
||||||
|
if (stream_it != array.cend() && !IsA<Null>(stream_it->second)) {
|
||||||
|
int64_t stream = get<Integer const>(stream_it->second);
|
||||||
|
ArrayInterfaceHandler::SyncCudaStream(stream);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -377,7 +399,6 @@ class ArrayInterface {
|
|||||||
bst_feature_t num_cols;
|
bst_feature_t num_cols;
|
||||||
size_t strides[2]{0, 0};
|
size_t strides[2]{0, 0};
|
||||||
void* data;
|
void* data;
|
||||||
|
|
||||||
Type type;
|
Type type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
42
tests/cpp/data/test_array_interface.cu
Normal file
42
tests/cpp/data/test_array_interface.cu
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2021 by Contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/host_device_vector.h>
|
||||||
|
#include "../helpers.h"
|
||||||
|
#include "../../../src/data/array_interface.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
__global__ void SleepForTest(uint64_t *out, uint64_t duration) {
|
||||||
|
auto start = clock64();
|
||||||
|
auto t = 0;
|
||||||
|
while (t < duration) {
|
||||||
|
t = clock64() - start;
|
||||||
|
}
|
||||||
|
out[0] = t;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ArrayInterface, Stream) {
|
||||||
|
size_t constexpr kRows = 10, kCols = 10;
|
||||||
|
HostDeviceVector<float> storage;
|
||||||
|
auto arr_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
|
||||||
|
|
||||||
|
cudaStream_t stream;
|
||||||
|
cudaStreamCreate(&stream);
|
||||||
|
|
||||||
|
auto j_arr =Json::Load(StringView{arr_str});
|
||||||
|
j_arr["stream"] = Integer(reinterpret_cast<int64_t>(stream));
|
||||||
|
Json::Dump(j_arr, &arr_str);
|
||||||
|
|
||||||
|
dh::caching_device_vector<uint64_t> out(1, 0);
|
||||||
|
uint64_t dur = 1e9;
|
||||||
|
dh::LaunchKernel{1, 1, 0, stream}(SleepForTest, out.data().get(), dur);
|
||||||
|
ArrayInterface arr(arr_str);
|
||||||
|
|
||||||
|
auto t = out[0];
|
||||||
|
CHECK_GE(t, dur);
|
||||||
|
|
||||||
|
cudaStreamDestroy(stream);
|
||||||
|
}
|
||||||
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user