Optimize array interface input. (#9090)

This commit is contained in:
Jiaming Yuan 2023-04-28 18:01:58 +08:00 committed by GitHub
parent fb941262b4
commit 17ff471616
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 10 deletions

View File

@ -24,5 +24,9 @@ constexpr StringView LabelScoreSize() {
constexpr StringView InfInData() {
  return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
}
// Message emitted when an f16 dtype (16-byte, i.e. 128-bit float) is requested but the
// platform's `long double` is not 16 bytes wide.
constexpr StringView NoF128() {
  return "128-bit floating point is not supported on current platform.";
}
}  // namespace xgboost::error
#endif  // XGBOOST_COMMON_ERROR_MSG_H_

View File

@ -7,8 +7,9 @@
#define XGBOOST_DATA_ARRAY_INTERFACE_H_
#include <algorithm>
#include <cstddef>  // for size_t
#include <cstdint>
#include <limits>  // for numeric_limits
#include <map>
#include <string>
#include <type_traits>  // std::alignment_of,std::remove_pointer_t
@ -17,6 +18,7 @@
#include "../common/bitfield.h"
#include "../common/common.h"
#include "../common/error_msg.h"  // for NoF128
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "xgboost/json.h"
@ -454,9 +456,8 @@ class ArrayInterface {
void AssignType(StringView typestr) {
  using T = ArrayInterfaceHandler::Type;
  if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
    CHECK(sizeof(long double) == 16) << error::NoF128();
    type = T::kF16;
  } else if (typestr[1] == 'f' && typestr[2] == '2') {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
    type = T::kF2;
@ -572,19 +573,90 @@ class ArrayInterface {
// Used only by columnar format.
RBitField8 valid;
// Array stride
std::size_t strides[D]{0};
// Array shape
std::size_t shape[D]{0};
// Type erased pointer referencing the data.
void const *data{nullptr};
// Total number of items
std::size_t n{0};
// Whether the memory is c-contiguous
bool is_contiguous{false};
// RTTI, initialized to the f16 to avoid masking potential bugs in initialization.
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
};
/**
 * \brief Dispatch on the runtime dtype tag of an array interface and invoke `fn` with a
 *        statically-typed, read-only `linalg::TensorView` over the array's data.
 *
 * \tparam D      Number of dimensions of the array interface.
 * \tparam Fn     Callable invoked as `fn(linalg::TensorView<T const, D>)`, where `T` is
 *                the element type matching `array.type`.
 * \param array   Type-erased array interface carrying data pointer, shape, strides, and
 *                the runtime dtype tag.
 * \param device  Device ordinal forwarded to the constructed TensorView.
 */
template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
  // Only used for cuDF at the moment.
  CHECK_EQ(array.valid.Size(), 0);
  auto dispatch = [&](auto t) {
    // `t` is a value-initialized tag whose type selects the view's element type.
    using T = std::remove_const_t<decltype(t)> const;
    // Set the data size to max as we don't know the original size of a sliced array:
    //
    // Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
    // in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
    // size 24 based on the slice.
    fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
                                                      std::numeric_limits<std::size_t>::max()},
                                array.shape, array.strides, device});
  };
  switch (array.type) {
    case ArrayInterfaceHandler::kF2: {
      // __half exists only when compiling device code for sm_60+; otherwise this case is
      // a silent no-op and `fn` is not invoked — TODO confirm callers never hit kF2 on host.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
      dispatch(__half{});
#endif
      break;
    }
    case ArrayInterfaceHandler::kF4: {
      dispatch(float{});
      break;
    }
    case ArrayInterfaceHandler::kF8: {
      dispatch(double{});
      break;
    }
    case ArrayInterfaceHandler::kF16: {
      using T = long double;
      // f16 here means a 16-byte (128-bit) float; bail out on platforms where
      // `long double` is narrower.
      CHECK(sizeof(long double) == 16) << error::NoF128();
      dispatch(T{});
      break;
    }
    case ArrayInterfaceHandler::kI1: {
      dispatch(std::int8_t{});
      break;
    }
    case ArrayInterfaceHandler::kI2: {
      dispatch(std::int16_t{});
      break;
    }
    case ArrayInterfaceHandler::kI4: {
      dispatch(std::int32_t{});
      break;
    }
    case ArrayInterfaceHandler::kI8: {
      dispatch(std::int64_t{});
      break;
    }
    case ArrayInterfaceHandler::kU1: {
      dispatch(std::uint8_t{});
      break;
    }
    case ArrayInterfaceHandler::kU2: {
      dispatch(std::uint16_t{});
      break;
    }
    case ArrayInterfaceHandler::kU4: {
      dispatch(std::uint32_t{});
      break;
    }
    case ArrayInterfaceHandler::kU8: {
      dispatch(std::uint64_t{});
      break;
    }
  }
}
/**
 * \brief Helper for type casting.
 */

View File

@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
return; return;
} }
p_out->Reshape(array.shape); p_out->Reshape(array.shape);
auto t = p_out->View(Context::kCpuId); auto t_out = p_out->View(Context::kCpuId);
CHECK(t.CContiguous()); CHECK(t_out.CContiguous());
linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) { auto const shape = t_out.Shape();
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape())); DispatchDType(array, Context::kCpuId, [&](auto&& in) {
linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
return std::apply(in, linalg::UnravelIndex<D>(i, shape));
});
}); });
} }
} // namespace } // namespace