diff --git a/src/data/columnar.h b/src/data/columnar.h index b47324377..843985985 100644 --- a/src/data/columnar.h +++ b/src/data/columnar.h @@ -14,21 +14,9 @@ #include "xgboost/json.h" #include "xgboost/logging.h" #include "xgboost/span.h" - #include "../common/bitfield.h" namespace xgboost { -// A view over __array_interface__ -template -struct Columnar { - using mask_type = unsigned char; - using index_type = int32_t; - - common::Span data; - RBitField8 valid; - int32_t size; -}; - // Common errors in parsing columnar format. struct ColumnarErrors { static char const* Contigious() { @@ -97,8 +85,8 @@ struct ColumnarErrors { } } - static std::string UnSupportedType(std::string const& typestr) { - return TypeStr(typestr.at(1)) + " is not supported."; + static std::string UnSupportedType(const char (&typestr)[3]) { + return TypeStr(typestr[1]) + " is not supported."; } }; @@ -200,6 +188,19 @@ class ArrayInterfaceHandler { return 0; } + static size_t ExtractLength(std::map const& column) { + auto j_shape = get(column.at("shape")); + CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1); + auto typestr = get(column.at("typestr")); + if (column.find("strides") != column.cend()) { + auto strides = get(column.at("strides")); + CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1); + CHECK_EQ(get(strides.at(0)), typestr.at(2) - '0') + << ColumnarErrors::Contigious(); + } + + return static_cast(get(j_shape.at(0))); + } template static common::Span ExtractData(std::map const& column) { Validate(column); @@ -210,70 +211,102 @@ class ArrayInterfaceHandler { CHECK_EQ(typestr.at(2), static_cast(sizeof(T) + 48)) << "Input data type and typestr mismatch. typestr: " << typestr; - auto j_shape = get(column.at("shape")); - CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1); - if (column.find("strides") != column.cend()) { - auto strides = get(column.at("strides")); - CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1); - CHECK_EQ(get(strides.at(0)), sizeof(T)) << ColumnarErrors::Contigious(); - } - - auto length = static_cast(get(j_shape.at(0))); + auto length = ExtractLength(column); T* p_data = ArrayInterfaceHandler::GetPtrFromArrayData(column); return common::Span{p_data, length}; } +}; - template - static Columnar ExtractArray(std::map const& column) { - common::Span s_data { ArrayInterfaceHandler::ExtractData(column) }; +// A view over __array_interface__ +class Columnar { + using mask_type = unsigned char; + using index_type = int32_t; - Columnar foreign_col; - foreign_col.data = s_data; - foreign_col.size = s_data.size(); + public: + explicit Columnar(std::map const& column) { + ArrayInterfaceHandler::Validate(column); + data = ArrayInterfaceHandler::GetPtrFromArrayData(column); + CHECK(data) << "Column is null"; + size = ArrayInterfaceHandler::ExtractLength(column); common::Span s_mask; size_t n_bits = ArrayInterfaceHandler::ExtractMask(column, &s_mask); - foreign_col.valid = RBitField8(s_mask); + valid = RBitField8(s_mask); if (s_mask.data()) { - CHECK_EQ(n_bits, foreign_col.data.size()) + CHECK_EQ(n_bits, size) << "Shape of bit mask doesn't match data shape. " << "XGBoost doesn't support internal broadcasting."; } - - return foreign_col; + auto typestr = get(column.at("typestr")); + type[0] = typestr.at(0); + type[1] = typestr.at(1); + type[2] = typestr.at(2); + this->CheckType(); } + + void CheckType() const { + if (type[1] == 'f' && type[2] == '4') { + return; + } else if (type[1] == 'f' && type[2] == '8') { + return; + } else if (type[1] == 'i' && type[2] == '1') { + return; + } else if (type[1] == 'i' && type[2] == '2') { + return; + } else if (type[1] == 'i' && type[2] == '4') { + return; + } else if (type[1] == 'i' && type[2] == '8') { + return; + } else if (type[1] == 'u' && type[2] == '1') { + return; + } else if (type[1] == 'u' && type[2] == '2') { + return; + } else if (type[1] == 'u' && type[2] == '4') { + return; + } else if (type[1] == 'u' && type[2] == '8') { + return; + } else { + LOG(FATAL) << ColumnarErrors::UnSupportedType(type); + return; + } + } + + XGBOOST_DEVICE float GetElement(size_t idx) const { + if (type[1] == 'f' && type[2] == '4') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'f' && type[2] == '8') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'i' && type[2] == '1') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'i' && type[2] == '2') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'i' && type[2] == '4') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'i' && type[2] == '8') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'u' && type[2] == '1') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'u' && type[2] == '2') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'u' && type[2] == '4') { + return reinterpret_cast(data)[idx]; + } else if (type[1] == 'u' && type[2] == '8') { + return reinterpret_cast(data)[idx]; + } else { + SPAN_CHECK(false); + return 0; + } + } + + RBitField8 valid; + int32_t size; + void* data; + char type[3]; }; -#define DISPATCH_TYPE(__dispatched_func, __typestr, ...) { \ - CHECK_EQ(__typestr.size(), 3) << ColumnarErrors::TypestrFormat(); \ - if (__typestr.at(1) == 'f' && __typestr.at(2) == '4') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'f' && __typestr.at(2) == '8') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'i' && __typestr.at(2) == '1') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'i' && __typestr.at(2) == '2') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'i' && __typestr.at(2) == '4') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'i' && __typestr.at(2) == '8') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'u' && __typestr.at(2) == '1') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'u' && __typestr.at(2) == '2') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'u' && __typestr.at(2) == '4') { \ - __dispatched_func(__VA_ARGS__); \ - } else if (__typestr.at(1) == 'u' && __typestr.at(2) == '8') { \ - __dispatched_func(__VA_ARGS__); \ - } else { \ - LOG(FATAL) << ColumnarErrors::UnSupportedType(__typestr); \ - } \ - } - -} // namespace xgboost +} // namespace xgboost #endif // XGBOOST_DATA_COLUMNAR_H_ diff --git a/src/data/data.cu b/src/data/data.cu index 4c1750700..d95a983d2 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -12,7 +12,6 @@ namespace xgboost { -template void CopyInfoImpl(std::map const& column, HostDeviceVector* out) { auto SetDeviceToPtr = [](void* ptr) { cudaPointerAttributes attr; @@ -21,17 +20,17 @@ void CopyInfoImpl(std::map const& column, HostDeviceVector s_data { ArrayInterfaceHandler::ExtractData(column) }; - auto ptr_device = SetDeviceToPtr(s_data.data()); - thrust::device_ptr p_src {s_data.data()}; - - auto length = s_data.size(); out->SetDevice(ptr_device); - out->Resize(length); + out->Resize(foreign_column.size); auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); - thrust::copy(p_src, p_src + length, p_dst); + + dh::LaunchN(ptr_device, foreign_column.size, [=] __device__(size_t idx) { + p_dst[idx] = foreign_column.GetElement(idx); + }); } void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { @@ -46,14 +45,12 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { } auto const& typestr = get(j_arr_obj.at("typestr")); - if (key == "root_index") { - LOG(FATAL) << "root index for columnar data is not supported."; - } else if (key == "label") { - DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &labels_); + if (key == "label") { + CopyInfoImpl(j_arr_obj, &labels_); } else if (key == "weight") { - DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &weights_); + CopyInfoImpl(j_arr_obj, &weights_); } else if (key == "base_margin") { - DISPATCH_TYPE(CopyInfoImpl, typestr, j_arr_obj, &base_margin_); + CopyInfoImpl(j_arr_obj, &base_margin_); } else if (key == "group") { // Ranking is not performed on device. auto s_data = ArrayInterfaceHandler::ExtractData(j_arr_obj); diff --git a/src/data/simple_csr_source.cu b/src/data/simple_csr_source.cu index af068a639..598238a9c 100644 --- a/src/data/simple_csr_source.cu +++ b/src/data/simple_csr_source.cu @@ -26,8 +26,7 @@ namespace xgboost { namespace data { -template -__global__ void CountValidKernel(Columnar const column, +__global__ void CountValidKernel(Columnar const column, bool has_missing, float missing, int32_t* flag, common::Span offsets) { auto const tid = threadIdx.x + blockDim.x * blockIdx.x; @@ -40,18 +39,18 @@ __global__ void CountValidKernel(Columnar const column, if (!has_missing) { if ((mask.Data() == nullptr || mask.Check(tid)) && - !common::CheckNAN(column.data[tid])) { + !common::CheckNAN(column.GetElement(tid))) { offsets[tid+1] += 1; } } else if (missing_is_nan) { - if (!common::CheckNAN(column.data[tid])) { + if (!common::CheckNAN(column.GetElement(tid))) { offsets[tid+1] += 1; } } else { - if (!common::CloseTo(column.data[tid], missing)) { + if (!common::CloseTo(column.GetElement(tid), missing)) { offsets[tid+1] += 1; } - if (common::CheckNAN(column.data[tid])) { + if (common::CheckNAN(column.GetElement(tid))) { *flag = 1; } } @@ -67,8 +66,7 @@ __device__ void AssignValue(T fvalue, int32_t colid, out_offsets[tid] += 1; } -template -__global__ void CreateCSRKernel(Columnar const column, +__global__ void CreateCSRKernel(Columnar const column, int32_t colid, bool has_missing, float missing, common::Span offsets, common::Span out_data) { auto const tid = threadIdx.x + blockDim.x * blockIdx.x; @@ -79,23 +77,22 @@ __global__ void CreateCSRKernel(Columnar const column, if (!has_missing) { // no missing value is specified if ((column.valid.Data() == nullptr || column.valid.Check(tid)) && - !common::CheckNAN(column.data[tid])) { - AssignValue(column.data[tid], colid, offsets, out_data); + !common::CheckNAN(column.GetElement(tid))) { + AssignValue(column.GetElement(tid), colid, offsets, out_data); } } else if (missing_is_nan) { // specified missing value, but it's NaN - if (!common::CheckNAN(column.data[tid])) { - AssignValue(column.data[tid], colid, offsets, out_data); + if (!common::CheckNAN(column.GetElement(tid))) { + AssignValue(column.GetElement(tid), colid, offsets, out_data); } } else { // specified missing value, and it's not NaN - if (!common::CloseTo(column.data[tid], missing)) { - AssignValue(column.data[tid], colid, offsets, out_data); + if (!common::CloseTo(column.GetElement(tid), missing)) { + AssignValue(column.GetElement(tid), colid, offsets, out_data); } } } -template void CountValid(std::vector const& j_columns, uint32_t column_id, bool has_missing, float missing, HostDeviceVector* out_offset, @@ -104,10 +101,10 @@ void CountValid(std::vector const& j_columns, uint32_t column_id, uint32_t constexpr kThreads = 256; auto const& j_column = j_columns[column_id]; auto const& column_obj = get(j_column); - Columnar foreign_column = ArrayInterfaceHandler::ExtractArray(column_obj); + Columnar foreign_column(column_obj); uint32_t const n_rows = foreign_column.size; - auto ptr = foreign_column.data.data(); + auto ptr = foreign_column.data; int32_t device = dh::CudaGetPointerDevice(ptr); CHECK_NE(device, -1); dh::safe_cuda(cudaSetDevice(device)); @@ -125,24 +122,23 @@ void CountValid(std::vector const& j_columns, uint32_t column_id, uint32_t const kBlocks = common::DivRoundUp(n_rows, kThreads); dh::LaunchKernel {kBlocks, kThreads} ( - CountValidKernel, + CountValidKernel, foreign_column, has_missing, missing, out_d_flag->data().get(), s_offsets); *out_n_rows = n_rows; } -template void CreateCSR(std::vector const& j_columns, uint32_t column_id, uint32_t n_rows, bool has_missing, float missing, dh::device_vector* tmp_offset, common::Span s_data) { uint32_t constexpr kThreads = 256; auto const& j_column = j_columns[column_id]; auto const& column_obj = get(j_column); - Columnar foreign_column = ArrayInterfaceHandler::ExtractArray(column_obj); + Columnar foreign_column(column_obj); uint32_t kBlocks = common::DivRoundUp(n_rows, kThreads); dh::LaunchKernel {kBlocks, kThreads} ( - CreateCSRKernel, + CreateCSRKernel, foreign_column, column_id, has_missing, missing, dh::ToSpan(*tmp_offset), s_data); } @@ -159,9 +155,8 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector const& columns, } uint32_t n_rows {0}; for (size_t i = 0; i < n_cols; ++i) { - auto const& typestr = get(columns[i]["typestr"]); - DISPATCH_TYPE(CountValid, typestr, - columns, i, has_missing, missing, &(this->page_.offset), &d_flag, &n_rows); + CountValid(columns, i, has_missing, missing, &(this->page_.offset), &d_flag, + &n_rows); } // don't pay for what you don't use. if (!common::CheckNAN(missing)) { @@ -197,9 +192,7 @@ void SimpleCSRSource::FromDeviceColumnar(std::vector const& columns, int32_t kBlocks = common::DivRoundUp(n_rows, kThreads); for (size_t i = 0; i < n_cols; ++i) { - auto const& typestr = get(columns[i]["typestr"]); - DISPATCH_TYPE(CreateCSR, typestr, columns, i, n_rows, - has_missing, missing, &tmp_offset, s_data); + CreateCSR(columns, i, n_rows, has_missing, missing, &tmp_offset, s_data); } } diff --git a/tests/cpp/data/test_simple_csr_source.cu b/tests/cpp/data/test_simple_csr_source.cu index 47bd19d04..838235b6a 100644 --- a/tests/cpp/data/test_simple_csr_source.cu +++ b/tests/cpp/data/test_simple_csr_source.cu @@ -23,21 +23,21 @@ TEST(ArrayInterfaceHandler, Error) { auto const& column_obj = get(column); // missing version - EXPECT_THROW(ArrayInterfaceHandler::ExtractArray(column_obj), dmlc::Error); + EXPECT_THROW(Columnar c(column_obj), dmlc::Error); column["version"] = Integer(static_cast(1)); // missing data - EXPECT_THROW(ArrayInterfaceHandler::ExtractArray(column_obj), dmlc::Error); + EXPECT_THROW(Columnar c(column_obj), dmlc::Error); column["data"] = j_data; // missing typestr - EXPECT_THROW(ArrayInterfaceHandler::ExtractArray(column_obj), dmlc::Error); + EXPECT_THROW(Columnar c(column_obj), dmlc::Error); column["typestr"] = String("(column_obj), dmlc::Error); + EXPECT_THROW(Columnar c(column_obj), dmlc::Error); thrust::device_vector d_data(kRows); j_data = {Json(Integer(reinterpret_cast(d_data.data().get()))), Json(Boolean(false))}; column["data"] = j_data; - EXPECT_NO_THROW(ArrayInterfaceHandler::ExtractArray(column_obj)); + EXPECT_NO_THROW(Columnar c(column_obj)); std::vector j_mask_shape {Json(Integer(static_cast(kRows - 1)))}; column["mask"] = Object(); @@ -46,7 +46,7 @@ TEST(ArrayInterfaceHandler, Error) { column["mask"]["typestr"] = String("(1)); // shape of mask and data doesn't match. - EXPECT_THROW(ArrayInterfaceHandler::ExtractArray(column_obj), dmlc::Error); + EXPECT_THROW(Columnar c(column_obj), dmlc::Error); } template @@ -75,6 +75,23 @@ Json GenerateDenseColumn(std::string const& typestr, size_t kRows, return column; } +void TestGetElement() { + thrust::device_vector data; + auto j_column = GenerateDenseColumn("(j_column); + Columnar foreign_column(column_obj); + + EXPECT_NO_THROW({ + dh::LaunchN(0, 1, [=] __device__(size_t idx) { + KERNEL_CHECK(foreign_column.GetElement(0) == 0.0f); + KERNEL_CHECK(foreign_column.GetElement(1) == 2.0f); + KERNEL_CHECK(foreign_column.GetElement(2) == 4.0f); + }); + }); +} + +TEST(Columnar, GetElement) { TestGetElement(); } + void TestDenseColumn(std::unique_ptr const& source, size_t n_rows, size_t n_cols) { auto const& data = source->page_.data.HostVector(); @@ -384,4 +401,4 @@ TEST(SimpleCSRSource, Types) { TestDenseColumn(source, kRows, kCols); } -} // namespace xgboost \ No newline at end of file +} // namespace xgboost