Optimize array interface input. (#9090)
This commit is contained in:
parent fb941262b4
commit 17ff471616
@ -24,5 +24,9 @@ constexpr StringView LabelScoreSize() {
 constexpr StringView InfInData() {
   return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
 }
+
+constexpr StringView NoF128() {
+  return "128-bit floating point is not supported on current platform.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
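Note: NoF128() follows this header's pattern of constexpr StringView message helpers that call sites stream into CHECK. A minimal self-contained sketch of that pattern, using std::string_view as a stand-in for XGBoost's StringView:

#include <iostream>
#include <string_view>

namespace xgboost::error {
// Compile-time message helper, mirroring the pattern in error_msg.h;
// std::string_view stands in for XGBoost's StringView here.
constexpr std::string_view NoF128() {
  return "128-bit floating point is not supported on current platform.";
}
}  // namespace xgboost::error

int main() {
  // Call sites stream the shared message into a check, as in
  // `CHECK(sizeof(long double) == 16) << error::NoF128();` below.
  if (sizeof(long double) != 16) {
    std::cerr << xgboost::error::NoF128() << "\n";
  }
}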
@ -7,8 +7,9 @@
 #define XGBOOST_DATA_ARRAY_INTERFACE_H_

 #include <algorithm>
-#include <cstddef>  // std::size_t
+#include <cstddef>  // for size_t
 #include <cstdint>
+#include <limits>  // for numeric_limits
 #include <map>
 #include <string>
 #include <type_traits>  // std::alignment_of, std::remove_pointer_t
@ -17,6 +18,7 @@

 #include "../common/bitfield.h"
 #include "../common/common.h"
+#include "../common/error_msg.h"  // for NoF128
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/json.h"
@ -454,9 +456,8 @@ class ArrayInterface {
   void AssignType(StringView typestr) {
     using T = ArrayInterfaceHandler::Type;
     if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
+      CHECK(sizeof(long double) == 16) << error::NoF128();
       type = T::kF16;
-      CHECK(sizeof(long double) == 16)
-          << "128-bit floating point is not supported on current platform.";
     } else if (typestr[1] == 'f' && typestr[2] == '2') {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
       type = T::kF2;
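For context: the typestr parsed by AssignType follows NumPy's __array_interface__ protocol: a byte-order character ('<', '>', '|', '='), a type-kind character ('f', 'i', 'u', ...), then the item size in bytes, so "<f2" is a little-endian half and "<f16" a 16-byte float. A standalone sketch of that decoding (hypothetical helper, not part of this patch):

#include <cassert>
#include <cstddef>
#include <string_view>

// Hypothetical decoder for illustration only: splits a NumPy-style typestr
// such as "<f4" into byte order, type kind, and item size. Assumes the
// string has the two lead characters followed by decimal digits.
struct DType {
  char byte_order;       // '<' little, '>' big, '|' not applicable, '=' native
  char kind;             // 'f' float, 'i' signed int, 'u' unsigned int
  std::size_t itemsize;  // in bytes, e.g. "2" -> 2, "16" -> 16
};

constexpr DType DecodeTypestr(std::string_view s) {
  std::size_t size = 0;
  for (std::size_t i = 2; i < s.size(); ++i) {
    size = size * 10 + static_cast<std::size_t>(s[i] - '0');
  }
  return DType{s[0], s[1], size};
}

int main() {
  static_assert(DecodeTypestr("<f2").itemsize == 2);    // half precision
  static_assert(DecodeTypestr("<f16").itemsize == 16);  // needs a 128-bit long double
  assert(DecodeTypestr("<u8").kind == 'u');
}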
@ -572,19 +573,90 @@ class ArrayInterface {
   // Used only by columnar format.
   RBitField8 valid;
   // Array stride
-  size_t strides[D]{0};
+  std::size_t strides[D]{0};
   // Array shape
-  size_t shape[D]{0};
+  std::size_t shape[D]{0};
   // Type erased pointer referencing the data.
   void const *data{nullptr};
   // Total number of items
-  size_t n{0};
+  std::size_t n{0};
   // Whether the memory is c-contiguous
   bool is_contiguous{false};
   // RTTI, initialized to f16 to avoid masking potential bugs in initialization.
   ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
 };
+
+template <std::int32_t D, typename Fn>
+void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
+  // Only used for cuDF at the moment.
+  CHECK_EQ(array.valid.Size(), 0);
+  auto dispatch = [&](auto t) {
+    using T = std::remove_const_t<decltype(t)> const;
+    // Set the data size to max as we don't know the original size of a sliced array:
+    //
+    // Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
+    // in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
+    // size 24 based on the slice.
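+    // (Illustrative note, not in the original patch: with B's strides (6, 1),
+    // element B(i, j) is read at offset i * 6 + j * 1 from B's base pointer, so
+    // B(3, 2) lands at offset 20 even though B nominally holds 4 * 3 = 12
+    // elements; only a max-sized span keeps such reads within the bounds check.)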
+    fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
+                                                      std::numeric_limits<std::size_t>::max()},
+                                array.shape, array.strides, device});
+  };
+  switch (array.type) {
+    case ArrayInterfaceHandler::kF2: {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      dispatch(__half{});
+#endif
+      break;
+    }
+    case ArrayInterfaceHandler::kF4: {
+      dispatch(float{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF8: {
+      dispatch(double{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF16: {
+      using T = long double;
+      CHECK(sizeof(long double) == 16) << error::NoF128();
+      dispatch(T{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI1: {
+      dispatch(std::int8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI2: {
+      dispatch(std::int16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI4: {
+      dispatch(std::int32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI8: {
+      dispatch(std::int64_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU1: {
+      dispatch(std::uint8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU2: {
+      dispatch(std::uint16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU4: {
+      dispatch(std::uint32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU8: {
+      dispatch(std::uint64_t{});
+      break;
+    }
+  }
+}

 /**
  * \brief Helper for type casting.
  */
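For readers of the new DispatchDType above: it converts the runtime dtype tag into a compile-time element type and passes the caller's generic callback a statically typed view, so one lambda covers every supported dtype. A self-contained sketch of the same technique in plain C++ (hypothetical names, not XGBoost's API):

#include <cstddef>
#include <cstdint>
#include <iostream>

// Runtime tag standing in for ArrayInterfaceHandler::Type.
enum class Type { kF4, kF8, kI4 };

// A runtime tag picks the compile-time element type; the caller's generic
// callback is instantiated once per supported type.
template <typename Fn>
void Dispatch(Type type, void const *data, std::size_t n, Fn fn) {
  auto dispatch = [&](auto t) {
    using T = decltype(t);
    // Give the type-erased pointer back its static type.
    fn(static_cast<T const *>(data), n);
  };
  switch (type) {
    case Type::kF4: dispatch(float{}); break;
    case Type::kF8: dispatch(double{}); break;
    case Type::kI4: dispatch(std::int32_t{}); break;
  }
}

int main() {
  double values[] = {1.0, 2.5, 3.0};
  // One generic lambda serves every dtype, mirroring DispatchDType's callers.
  Dispatch(Type::kF8, values, 3,
           [](auto const *p, std::size_t n) { std::cout << p[n - 1] << "\n"; });
}

The real DispatchDType additionally wraps the pointer in a TensorView carrying shape, strides, and device.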
@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(Context::kCpuId);
-  CHECK(t.CContiguous());
-  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
+  auto t_out = p_out->View(Context::kCpuId);
+  CHECK(t_out.CContiguous());
+  auto const shape = t_out.Shape();
+  DispatchDType(array, Context::kCpuId, [&](auto&& in) {
+    linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
+      return std::apply(in, linalg::UnravelIndex<D>(i, shape));
+    });
   });
 }
 }  // namespace
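The rewritten inner loop walks the output tensor by flat index, unravels the index into D coordinates, and applies the statically typed input view via std::apply. A standalone sketch of that index logic for a 2-d, C-contiguous case (UnravelIndex here is a minimal stand-in, not the linalg implementation):

#include <array>
#include <cstddef>
#include <iostream>
#include <tuple>

// Minimal stand-in for linalg::UnravelIndex: converts a flat, C-contiguous
// index into (row, column) coordinates for a 2-d shape.
std::tuple<std::size_t, std::size_t> UnravelIndex(std::size_t i,
                                                  std::array<std::size_t, 2> shape) {
  return {i / shape[1], i % shape[1]};
}

int main() {
  double in[2][3] = {{1, 2, 3}, {4, 5, 6}};
  std::array<std::size_t, 2> shape{2, 3};
  // A callable indexed by coordinates, like the TensorView handed to the
  // lambda in the patch.
  auto view = [&](std::size_t r, std::size_t c) { return in[r][c]; };
  for (std::size_t i = 0; i < 6; ++i) {
    // Input and output share a shape, so the unraveled coordinates address
    // the matching input element; std::apply expands the tuple into (r, c).
    std::cout << std::apply(view, UnravelIndex(i, shape)) << " ";
  }
  std::cout << "\n";  // prints: 1 2 3 4 5 6
}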