Use dynamic types for array interface columns instead of templates (#5108)

This commit is contained in:
Rory Mitchell 2019-12-21 16:08:10 +13:00 committed by GitHub
parent b915788708
commit 3d04a8cc97
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 149 additions and 109 deletions

View File

@ -14,21 +14,9 @@
#include "xgboost/json.h" #include "xgboost/json.h"
#include "xgboost/logging.h" #include "xgboost/logging.h"
#include "xgboost/span.h" #include "xgboost/span.h"
#include "../common/bitfield.h" #include "../common/bitfield.h"
namespace xgboost { namespace xgboost {
// A view over __array_interface__
//
// Non-owning description of one column: a typed span over the elements plus a
// bit field. T is the element type the column is viewed as.
template <typename T>
struct Columnar {
// Aliases for the mask storage unit and the index type; neither is referenced
// in the visible code.
using mask_type = unsigned char;
using index_type = int32_t;
// View of the column's elements.
common::Span<T> data;
// Bit field consulted by the kernels in this file: a null valid.Data() is
// treated as "no mask", otherwise valid.Check(i) is tested before element i
// is used.
RBitField8 valid;
// Element count; ExtractArray assigns it from data.size().
int32_t size;
};
// Common errors in parsing columnar format. // Common errors in parsing columnar format.
struct ColumnarErrors { struct ColumnarErrors {
static char const* Contigious() { static char const* Contigious() {
@ -97,8 +85,8 @@ struct ColumnarErrors {
} }
} }
static std::string UnSupportedType(std::string const& typestr) { static std::string UnSupportedType(const char (&typestr)[3]) {
return TypeStr(typestr.at(1)) + " is not supported."; return TypeStr(typestr[1]) + " is not supported.";
} }
}; };
@ -200,6 +188,19 @@ class ArrayInterfaceHandler {
return 0; return 0;
} }
static size_t ExtractLength(std::map<std::string, Json> const& column) {
auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
auto typestr = get<String const>(column.at("typestr"));
if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), typestr.at(2) - '0')
<< ColumnarErrors::Contigious();
}
return static_cast<size_t>(get<Integer const>(j_shape.at(0)));
}
template <typename T> template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) { static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column); Validate(column);
@ -210,70 +211,102 @@ class ArrayInterfaceHandler {
CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48)) CHECK_EQ(typestr.at(2), static_cast<char>(sizeof(T) + 48))
<< "Input data type and typestr mismatch. typestr: " << typestr; << "Input data type and typestr mismatch. typestr: " << typestr;
auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
if (column.find("strides") != column.cend()) { auto length = ExtractLength(column);
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious();
}
auto length = static_cast<size_t>(get<Integer const>(j_shape.at(0)));
T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column); T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData<T*>(column);
return common::Span<T>{p_data, length}; return common::Span<T>{p_data, length};
} }
};
template <typename T> // A view over __array_interface__
static Columnar<T> ExtractArray(std::map<std::string, Json> const& column) { class Columnar {
common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) }; using mask_type = unsigned char;
using index_type = int32_t;
Columnar<T> foreign_col; public:
foreign_col.data = s_data; explicit Columnar(std::map<std::string, Json> const& column) {
foreign_col.size = s_data.size(); ArrayInterfaceHandler::Validate(column);
data = ArrayInterfaceHandler::GetPtrFromArrayData<void*>(column);
CHECK(data) << "Column is null";
size = ArrayInterfaceHandler::ExtractLength(column);
common::Span<RBitField8::value_type> s_mask; common::Span<RBitField8::value_type> s_mask;
size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask); size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask);
foreign_col.valid = RBitField8(s_mask); valid = RBitField8(s_mask);
if (s_mask.data()) { if (s_mask.data()) {
CHECK_EQ(n_bits, foreign_col.data.size()) CHECK_EQ(n_bits, size)
<< "Shape of bit mask doesn't match data shape. " << "Shape of bit mask doesn't match data shape. "
<< "XGBoost doesn't support internal broadcasting."; << "XGBoost doesn't support internal broadcasting.";
} }
auto typestr = get<String const>(column.at("typestr"));
return foreign_col; type[0] = typestr.at(0);
type[1] = typestr.at(1);
type[2] = typestr.at(2);
this->CheckType();
} }
// Verify that the stored typestr names an element type this view can read:
// 'f' with width 4 or 8, or 'i'/'u' with width 1, 2, 4 or 8. Anything else is
// reported through LOG(FATAL). type[0] is not inspected here.
void CheckType() const {
  char const kind = type[1];
  char const width = type[2];

  bool supported = false;
  switch (kind) {
    case 'f':
      supported = (width == '4' || width == '8');
      break;
    case 'i':
    case 'u':
      supported = (width == '1' || width == '2' || width == '4' || width == '8');
      break;
    default:
      break;
  }

  if (!supported) {
    LOG(FATAL) << ColumnarErrors::UnSupportedType(type);
  }
}
// Read element `idx` of the column and return it converted to float.
//
// The untyped `data` pointer is interpreted according to the kind (type[1])
// and width (type[2]) characters of the typestr. No bounds check is made on
// `idx`. An unsupported kind/width combination trips SPAN_CHECK(false) and
// falls back to returning 0.
XGBOOST_DEVICE float GetElement(size_t idx) const {
  char const kind = type[1];
  char const width = type[2];

  switch (kind) {
    case 'f':
      switch (width) {
        case '4': return static_cast<float const*>(data)[idx];
        case '8': return static_cast<float>(static_cast<double const*>(data)[idx]);
        default: break;
      }
      break;
    case 'i':
      switch (width) {
        case '1': return static_cast<float>(static_cast<int8_t const*>(data)[idx]);
        case '2': return static_cast<float>(static_cast<int16_t const*>(data)[idx]);
        case '4': return static_cast<float>(static_cast<int32_t const*>(data)[idx]);
        case '8': return static_cast<float>(static_cast<int64_t const*>(data)[idx]);
        default: break;
      }
      break;
    case 'u':
      switch (width) {
        case '1': return static_cast<float>(static_cast<uint8_t const*>(data)[idx]);
        case '2': return static_cast<float>(static_cast<uint16_t const*>(data)[idx]);
        case '4': return static_cast<float>(static_cast<uint32_t const*>(data)[idx]);
        case '8': return static_cast<float>(static_cast<uint64_t const*>(data)[idx]);
        default: break;
      }
      break;
    default:
      break;
  }

  // No supported type matched.
  SPAN_CHECK(false);
  return 0;
}
RBitField8 valid;
int32_t size;
void* data;
char type[3];
}; };
// Call `dispatched_func<T>(args...)` with T selected from a 3-character
// typestr: typestr[1] is the kind ('f', 'i' or 'u') and typestr[2] the item
// size digit. A typestr whose size is not 3 fails the CHECK; an unsupported
// kind/size pair is reported through LOG(FATAL).
//
// The body is wrapped in do { } while (0) so that `DISPATCH_TYPE(...);` is a
// single statement and stays correct inside an unbraced if/else. The typestr
// argument is parenthesised at every use; note it is still evaluated several
// times, so pass a plain variable.
#define DISPATCH_TYPE(dispatched_func, typestr, ...)                          \
  do {                                                                         \
    CHECK_EQ((typestr).size(), 3) << ColumnarErrors::TypestrFormat();          \
    if ((typestr).at(1) == 'f' && (typestr).at(2) == '4') {                    \
      dispatched_func<float>(__VA_ARGS__);                                     \
    } else if ((typestr).at(1) == 'f' && (typestr).at(2) == '8') {             \
      dispatched_func<double>(__VA_ARGS__);                                    \
    } else if ((typestr).at(1) == 'i' && (typestr).at(2) == '1') {             \
      dispatched_func<int8_t>(__VA_ARGS__);                                    \
    } else if ((typestr).at(1) == 'i' && (typestr).at(2) == '2') {             \
      dispatched_func<int16_t>(__VA_ARGS__);                                   \
    } else if ((typestr).at(1) == 'i' && (typestr).at(2) == '4') {             \
      dispatched_func<int32_t>(__VA_ARGS__);                                   \
    } else if ((typestr).at(1) == 'i' && (typestr).at(2) == '8') {             \
      dispatched_func<int64_t>(__VA_ARGS__);                                   \
    } else if ((typestr).at(1) == 'u' && (typestr).at(2) == '1') {             \
      dispatched_func<uint8_t>(__VA_ARGS__);                                   \
    } else if ((typestr).at(1) == 'u' && (typestr).at(2) == '2') {             \
      dispatched_func<uint16_t>(__VA_ARGS__);                                  \
    } else if ((typestr).at(1) == 'u' && (typestr).at(2) == '4') {             \
      dispatched_func<uint32_t>(__VA_ARGS__);                                  \
    } else if ((typestr).at(1) == 'u' && (typestr).at(2) == '8') {             \
      dispatched_func<uint64_t>(__VA_ARGS__);                                  \
    } else {                                                                   \
      LOG(FATAL) << ColumnarErrors::UnSupportedType(typestr);                  \
    }                                                                          \
  } while (0)
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_DATA_COLUMNAR_H_ #endif // XGBOOST_DATA_COLUMNAR_H_

View File

@ -12,7 +12,6 @@
namespace xgboost { namespace xgboost {
template <typename T>
void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) { void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<float>* out) {
auto SetDeviceToPtr = [](void* ptr) { auto SetDeviceToPtr = [](void* ptr) {
cudaPointerAttributes attr; cudaPointerAttributes attr;
@ -21,17 +20,17 @@ void CopyInfoImpl(std::map<std::string, Json> const& column, HostDeviceVector<fl
dh::safe_cuda(cudaSetDevice(ptr_device)); dh::safe_cuda(cudaSetDevice(ptr_device));
return ptr_device; return ptr_device;
}; };
Columnar foreign_column(column);
auto ptr_device = SetDeviceToPtr(foreign_column.data);
common::Span<T> s_data { ArrayInterfaceHandler::ExtractData<T>(column) };
auto ptr_device = SetDeviceToPtr(s_data.data());
thrust::device_ptr<T> p_src {s_data.data()};
auto length = s_data.size();
out->SetDevice(ptr_device); out->SetDevice(ptr_device);
out->Resize(length); out->Resize(foreign_column.size);
auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); auto p_dst = thrust::device_pointer_cast(out->DevicePointer());
thrust::copy(p_src, p_src + length, p_dst);
dh::LaunchN(ptr_device, foreign_column.size, [=] __device__(size_t idx) {
p_dst[idx] = foreign_column.GetElement(idx);
});
} }
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
@ -46,14 +45,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} }
auto const& typestr = get<String const>(j_arr_obj.at("typestr")); auto const& typestr = get<String const>(j_arr_obj.at("typestr"));
if (key == "root_index") { if (key == "label") {
LOG(FATAL) << "root index for columnar data is not supported."; CopyInfoImpl(j_arr_obj, &labels_);
} else if (key == "label") {
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &labels_);
} else if (key == "weight") { } else if (key == "weight") {
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &weights_); CopyInfoImpl(j_arr_obj, &weights_);
} else if (key == "base_margin") { } else if (key == "base_margin") {
DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &base_margin_); CopyInfoImpl(j_arr_obj, &base_margin_);
} else if (key == "group") { } else if (key == "group") {
// Ranking is not performed on device. // Ranking is not performed on device.
auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj); auto s_data = ArrayInterfaceHandler::ExtractData<uint32_t>(j_arr_obj);

View File

@ -26,8 +26,7 @@
namespace xgboost { namespace xgboost {
namespace data { namespace data {
template <typename T> __global__ void CountValidKernel(Columnar const column,
__global__ void CountValidKernel(Columnar<T> const column,
bool has_missing, float missing, bool has_missing, float missing,
int32_t* flag, common::Span<bst_row_t> offsets) { int32_t* flag, common::Span<bst_row_t> offsets) {
auto const tid = threadIdx.x + blockDim.x * blockIdx.x; auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
@ -40,18 +39,18 @@ __global__ void CountValidKernel(Columnar<T> const column,
if (!has_missing) { if (!has_missing) {
if ((mask.Data() == nullptr || mask.Check(tid)) && if ((mask.Data() == nullptr || mask.Check(tid)) &&
!common::CheckNAN(column.data[tid])) { !common::CheckNAN(column.GetElement(tid))) {
offsets[tid+1] += 1; offsets[tid+1] += 1;
} }
} else if (missing_is_nan) { } else if (missing_is_nan) {
if (!common::CheckNAN(column.data[tid])) { if (!common::CheckNAN(column.GetElement(tid))) {
offsets[tid+1] += 1; offsets[tid+1] += 1;
} }
} else { } else {
if (!common::CloseTo(column.data[tid], missing)) { if (!common::CloseTo(column.GetElement(tid), missing)) {
offsets[tid+1] += 1; offsets[tid+1] += 1;
} }
if (common::CheckNAN(column.data[tid])) { if (common::CheckNAN(column.GetElement(tid))) {
*flag = 1; *flag = 1;
} }
} }
@ -67,8 +66,7 @@ __device__ void AssignValue(T fvalue, int32_t colid,
out_offsets[tid] += 1; out_offsets[tid] += 1;
} }
template <typename T> __global__ void CreateCSRKernel(Columnar const column,
__global__ void CreateCSRKernel(Columnar<T> const column,
int32_t colid, bool has_missing, float missing, int32_t colid, bool has_missing, float missing,
common::Span<bst_row_t> offsets, common::Span<Entry> out_data) { common::Span<bst_row_t> offsets, common::Span<Entry> out_data) {
auto const tid = threadIdx.x + blockDim.x * blockIdx.x; auto const tid = threadIdx.x + blockDim.x * blockIdx.x;
@ -79,23 +77,22 @@ __global__ void CreateCSRKernel(Columnar<T> const column,
if (!has_missing) { if (!has_missing) {
// no missing value is specified // no missing value is specified
if ((column.valid.Data() == nullptr || column.valid.Check(tid)) && if ((column.valid.Data() == nullptr || column.valid.Check(tid)) &&
!common::CheckNAN(column.data[tid])) { !common::CheckNAN(column.GetElement(tid))) {
AssignValue(column.data[tid], colid, offsets, out_data); AssignValue(column.GetElement(tid), colid, offsets, out_data);
} }
} else if (missing_is_nan) { } else if (missing_is_nan) {
// specified missing value, but it's NaN // specified missing value, but it's NaN
if (!common::CheckNAN(column.data[tid])) { if (!common::CheckNAN(column.GetElement(tid))) {
AssignValue(column.data[tid], colid, offsets, out_data); AssignValue(column.GetElement(tid), colid, offsets, out_data);
} }
} else { } else {
// specified missing value, and it's not NaN // specified missing value, and it's not NaN
if (!common::CloseTo(column.data[tid], missing)) { if (!common::CloseTo(column.GetElement(tid), missing)) {
AssignValue(column.data[tid], colid, offsets, out_data); AssignValue(column.GetElement(tid), colid, offsets, out_data);
} }
} }
} }
template <typename T>
void CountValid(std::vector<Json> const& j_columns, uint32_t column_id, void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
bool has_missing, float missing, bool has_missing, float missing,
HostDeviceVector<bst_row_t>* out_offset, HostDeviceVector<bst_row_t>* out_offset,
@ -104,10 +101,10 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
uint32_t constexpr kThreads = 256; uint32_t constexpr kThreads = 256;
auto const& j_column = j_columns[column_id]; auto const& j_column = j_columns[column_id];
auto const& column_obj = get<Object const>(j_column); auto const& column_obj = get<Object const>(j_column);
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj); Columnar foreign_column(column_obj);
uint32_t const n_rows = foreign_column.size; uint32_t const n_rows = foreign_column.size;
auto ptr = foreign_column.data.data(); auto ptr = foreign_column.data;
int32_t device = dh::CudaGetPointerDevice(ptr); int32_t device = dh::CudaGetPointerDevice(ptr);
CHECK_NE(device, -1); CHECK_NE(device, -1);
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
@ -125,24 +122,23 @@ void CountValid(std::vector<Json> const& j_columns, uint32_t column_id,
uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads); uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads);
dh::LaunchKernel {kBlocks, kThreads} ( dh::LaunchKernel {kBlocks, kThreads} (
CountValidKernel<T>, CountValidKernel,
foreign_column, foreign_column,
has_missing, missing, has_missing, missing,
out_d_flag->data().get(), s_offsets); out_d_flag->data().get(), s_offsets);
*out_n_rows = n_rows; *out_n_rows = n_rows;
} }
template <typename T>
void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows, void CreateCSR(std::vector<Json> const& j_columns, uint32_t column_id, uint32_t n_rows,
bool has_missing, float missing, bool has_missing, float missing,
dh::device_vector<bst_row_t>* tmp_offset, common::Span<Entry> s_data) { dh::device_vector<bst_row_t>* tmp_offset, common::Span<Entry> s_data) {
uint32_t constexpr kThreads = 256; uint32_t constexpr kThreads = 256;
auto const& j_column = j_columns[column_id]; auto const& j_column = j_columns[column_id];
auto const& column_obj = get<Object const>(j_column); auto const& column_obj = get<Object const>(j_column);
Columnar<T> foreign_column = ArrayInterfaceHandler::ExtractArray<T>(column_obj); Columnar foreign_column(column_obj);
uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads); uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
dh::LaunchKernel {kBlocks, kThreads} ( dh::LaunchKernel {kBlocks, kThreads} (
CreateCSRKernel<T>, CreateCSRKernel,
foreign_column, column_id, has_missing, missing, foreign_column, column_id, has_missing, missing,
dh::ToSpan(*tmp_offset), s_data); dh::ToSpan(*tmp_offset), s_data);
} }
@ -159,9 +155,8 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
} }
uint32_t n_rows {0}; uint32_t n_rows {0};
for (size_t i = 0; i < n_cols; ++i) { for (size_t i = 0; i < n_cols; ++i) {
auto const& typestr = get<String const>(columns[i]["typestr"]); CountValid(columns, i, has_missing, missing, &(this->page_.offset), &d_flag,
DISPATCH_TYPE(CountValid, typestr, &n_rows);
columns, i, has_missing, missing, &(this->page_.offset), &d_flag, &n_rows);
} }
// don't pay for what you don't use. // don't pay for what you don't use.
if (!common::CheckNAN(missing)) { if (!common::CheckNAN(missing)) {
@ -197,9 +192,7 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector<Json> const& columns,
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads); int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
for (size_t i = 0; i < n_cols; ++i) { for (size_t i = 0; i < n_cols; ++i) {
auto const& typestr = get<String const>(columns[i]["typestr"]); CreateCSR(columns, i, n_rows, has_missing, missing, &tmp_offset, s_data);
DISPATCH_TYPE(CreateCSR, typestr, columns, i, n_rows,
has_missing, missing, &tmp_offset, s_data);
} }
} }

View File

@ -23,21 +23,21 @@ TEST(ArrayInterfaceHandler, Error) {
auto const& column_obj = get<Object>(column); auto const& column_obj = get<Object>(column);
// missing version // missing version
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error); EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
column["version"] = Integer(static_cast<Integer::Int>(1)); column["version"] = Integer(static_cast<Integer::Int>(1));
// missing data // missing data
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error); EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
column["data"] = j_data; column["data"] = j_data;
// missing typestr // missing typestr
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error); EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
column["typestr"] = String("<f4"); column["typestr"] = String("<f4");
// nullptr is not valid // nullptr is not valid
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error); EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
thrust::device_vector<float> d_data(kRows); thrust::device_vector<float> d_data(kRows);
j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))), j_data = {Json(Integer(reinterpret_cast<Integer::Int>(d_data.data().get()))),
Json(Boolean(false))}; Json(Boolean(false))};
column["data"] = j_data; column["data"] = j_data;
EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj)); EXPECT_NO_THROW(Columnar c(column_obj));
std::vector<Json> j_mask_shape {Json(Integer(static_cast<Integer::Int>(kRows - 1)))}; std::vector<Json> j_mask_shape {Json(Integer(static_cast<Integer::Int>(kRows - 1)))};
column["mask"] = Object(); column["mask"] = Object();
@ -46,7 +46,7 @@ TEST(ArrayInterfaceHandler, Error) {
column["mask"]["typestr"] = String("<i1"); column["mask"]["typestr"] = String("<i1");
column["mask"]["version"] = Integer(static_cast<Integer::Int>(1)); column["mask"]["version"] = Integer(static_cast<Integer::Int>(1));
// shape of mask and data doesn't match. // shape of mask and data doesn't match.
EXPECT_THROW(ArrayInterfaceHandler::ExtractArray<float>(column_obj), dmlc::Error); EXPECT_THROW(Columnar c(column_obj), dmlc::Error);
} }
template <typename T> template <typename T>
@ -75,6 +75,23 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows,
return column; return column;
} }
void TestGetElement() {
thrust::device_vector<float> data;
auto j_column = GenerateDenseColumn("<f4", 3, &data);
auto const& column_obj = get<Object const>(j_column);
Columnar foreign_column(column_obj);
EXPECT_NO_THROW({
dh::LaunchN(0, 1, [=] __device__(size_t idx) {
KERNEL_CHECK(foreign_column.GetElement(0) == 0.0f);
KERNEL_CHECK(foreign_column.GetElement(1) == 2.0f);
KERNEL_CHECK(foreign_column.GetElement(2) == 4.0f);
});
});
}
TEST(Columnar, GetElement) { TestGetElement(); }
void TestDenseColumn(std::unique_ptr<data::SimpleCSRSource> const& source, void TestDenseColumn(std::unique_ptr<data::SimpleCSRSource> const& source,
size_t n_rows, size_t n_cols) { size_t n_rows, size_t n_cols) {
auto const& data = source->page_.data.HostVector(); auto const& data = source->page_.data.HostVector();