From 17ff471616ae2d4598d1143435236c2e0c191861 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 28 Apr 2023 18:01:58 +0800
Subject: [PATCH] Optimize array interface input. (#9090)

---
 src/common/error_msg.h     |  4 ++
 src/data/array_interface.h | 84 +++++++++++++++++++++++++++++++++++---
 src/data/data.cc           | 11 +++--
 3 files changed, 89 insertions(+), 10 deletions(-)

diff --git a/src/common/error_msg.h b/src/common/error_msg.h
index 3dbb7f52c..4415bf2ee 100644
--- a/src/common/error_msg.h
+++ b/src/common/error_msg.h
@@ -24,5 +24,9 @@ constexpr StringView LabelScoreSize() {
 constexpr StringView InfInData() {
   return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
 }
+
+constexpr StringView NoF128() {
+  return "128-bit floating point is not supported on current platform.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
diff --git a/src/data/array_interface.h b/src/data/array_interface.h
index e9045899b..fee22203c 100644
--- a/src/data/array_interface.h
+++ b/src/data/array_interface.h
@@ -7,8 +7,9 @@
 #define XGBOOST_DATA_ARRAY_INTERFACE_H_
 
 #include <algorithm>
-#include <cstddef>  // std::size_t
+#include <cstddef>  // for size_t
 #include <cstdint>
+#include <limits>  // for numeric_limits
 #include <map>
 #include <string>
 #include <type_traits>  // std::alignment_of,std::remove_pointer_t
@@ -17,6 +18,7 @@
 
 #include "../common/bitfield.h"
 #include "../common/common.h"
+#include "../common/error_msg.h"  // for NoF128
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/json.h"
@@ -454,9 +456,8 @@
   void AssignType(StringView typestr) {
     using T = ArrayInterfaceHandler::Type;
     if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
+      CHECK(sizeof(long double) == 16) << error::NoF128();
       type = T::kF16;
-      CHECK(sizeof(long double) == 16)
-          << "128-bit floating point is not supported on current platform.";
     } else if (typestr[1] == 'f' && typestr[2] == '2') {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
       type = T::kF2;
@@ -572,19 +573,90 @@
   // Used only by columnar format.
   RBitField8 valid;
   // Array stride
-  size_t strides[D]{0};
+  std::size_t strides[D]{0};
   // Array shape
-  size_t shape[D]{0};
+  std::size_t shape[D]{0};
   // Type earsed pointer referencing the data.
   void const *data{nullptr};
   // Total number of items
-  size_t n{0};
+  std::size_t n{0};
   // Whether the memory is c-contiguous
   bool is_contiguous{false};
   // RTTI, initialized to the f16 to avoid masking potential bugs in initialization.
   ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
 };
 
+template <std::int32_t D, typename Fn>
+void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
+  // Only used for cuDF at the moment.
+  CHECK_EQ(array.valid.Size(), 0);
+  auto dispatch = [&](auto t) {
+    using T = std::remove_const_t<decltype(t)> const;
+    // Set the data size to max as we don't know the original size of a sliced array:
+    //
+    // Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
+    // in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
+    // size 24 based on the slice.
+    fn(linalg::TensorView<T, D>{
+        common::Span<T>{static_cast<T *>(array.data), std::numeric_limits<std::size_t>::max()},
+        array.shape, array.strides, device});
+  };
+  switch (array.type) {
+    case ArrayInterfaceHandler::kF2: {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      dispatch(__half{});
+#endif
+      break;
+    }
+    case ArrayInterfaceHandler::kF4: {
+      dispatch(float{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF8: {
+      dispatch(double{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF16: {
+      using T = long double;
+      CHECK(sizeof(long double) == 16) << error::NoF128();
+      dispatch(T{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI1: {
+      dispatch(std::int8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI2: {
+      dispatch(std::int16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI4: {
+      dispatch(std::int32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI8: {
+      dispatch(std::int64_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU1: {
+      dispatch(std::uint8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU2: {
+      dispatch(std::uint16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU4: {
+      dispatch(std::uint32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU8: {
+      dispatch(std::uint64_t{});
+      break;
+    }
+  }
+}
+
 /**
  * \brief Helper for type casting.
  */
diff --git a/src/data/data.cc b/src/data/data.cc
index 9f85e7db2..236bd9131 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T, D>* p_out) {
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(Context::kCpuId);
-  CHECK(t.CContiguous());
-  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex(i, t.Shape()));
+  auto t_out = p_out->View(Context::kCpuId);
+  CHECK(t_out.CContiguous());
+  auto const shape = t_out.Shape();
+  DispatchDType(array, Context::kCpuId, [&](auto&& in) {
+    linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
+      return std::apply(in, linalg::UnravelIndex(i, shape));
+    });
   });
 }
 }  // namespace
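
Note on the pattern (not part of the patch): DispatchDType maps the array's runtime
dtype tag to a compile-time type once, by handing a default-constructed value of the
concrete type to a generic lambda, so the per-element loop in data.cc is instantiated
once per dtype instead of re-dispatching on every element as the old TypedIndex path
did. The stand-alone sketch below illustrates the same technique; the names in it
(DType, Dispatch, the raw buffer) are illustrative only, not XGBoost's API.

    #include <cstdint>
    #include <cstdio>

    // Runtime type tag, standing in for ArrayInterfaceHandler::Type.
    enum class DType { kF4, kF8, kI4 };

    // Invoke the generic callback `fn` with a value of the concrete type
    // selected by `t`; `fn` is compiled once per supported type.
    template <typename Fn>
    void Dispatch(DType t, Fn fn) {
      switch (t) {
        case DType::kF4: fn(float{}); break;
        case DType::kF8: fn(double{}); break;
        case DType::kI4: fn(std::int32_t{}); break;
      }
    }

    int main() {
      float raw[3] = {1.0f, 2.0f, 3.0f};
      void const* data = raw;   // type-erased buffer, like ArrayInterface::data
      DType type = DType::kF4;  // runtime tag describing the buffer

      Dispatch(type, [&](auto t) {
        using T = decltype(t);
        auto const* typed = static_cast<T const*>(data);  // recover the typed view
        double sum = 0.0;
        for (int i = 0; i < 3; ++i) {
          sum += typed[i];  // fully typed loop: no per-element dispatch
        }
        std::printf("sum = %f\n", sum);
      });
      return 0;
    }

In the data.cc hunk the dispatched callback receives a typed TensorView (`in`), and
std::apply(in, linalg::UnravelIndex(i, shape)) expands the index tuple returned by
UnravelIndex into the view's call operator, so the type dispatch happens once per
array rather than once per element.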