Optimize array interface input. (#9090)
Parent: fb941262b4
Commit: 17ff471616
src/common/error_msg.h
@@ -24,5 +24,9 @@ constexpr StringView LabelScoreSize() {
 constexpr StringView InfInData() {
   return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
 }
+
+constexpr StringView NoF128() {
+  return "128-bit floating point is not supported on current platform.";
+}
 }  // namespace xgboost::error
 #endif  // XGBOOST_COMMON_ERROR_MSG_H_
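The new helper keeps the f128 message in one place; call sites stream it into a CHECK instead of repeating the literal (see the AssignType hunk below). A minimal standalone sketch of that pattern, using std::string_view as a stand-in for xgboost's own StringView type:

#include <iostream>
#include <string_view>

namespace xgboost::error {
// Stand-in for xgboost's StringView; the real header defines its own type.
using StringView = std::string_view;

constexpr StringView NoF128() {
  return "128-bit floating point is not supported on current platform.";
}
}  // namespace xgboost::error

int main() {
  // Call sites reuse the helper instead of duplicating the literal:
  if (sizeof(long double) != 16) {
    std::cerr << xgboost::error::NoF128() << "\n";
  }
}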
src/data/array_interface.h
@@ -7,8 +7,9 @@
 #define XGBOOST_DATA_ARRAY_INTERFACE_H_
 
 #include <algorithm>
-#include <cstddef>  // std::size_t
+#include <cstddef>  // for size_t
 #include <cstdint>
+#include <limits>  // for numeric_limits
 #include <map>
 #include <string>
 #include <type_traits>  // std::alignment_of,std::remove_pointer_t
@@ -17,6 +18,7 @@
 
 #include "../common/bitfield.h"
 #include "../common/common.h"
+#include "../common/error_msg.h"  // for NoF128
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/json.h"
@@ -454,9 +456,8 @@ class ArrayInterface {
   void AssignType(StringView typestr) {
     using T = ArrayInterfaceHandler::Type;
     if (typestr.size() == 4 && typestr[1] == 'f' && typestr[2] == '1' && typestr[3] == '6') {
+      CHECK(sizeof(long double) == 16) << error::NoF128();
       type = T::kF16;
-      CHECK(sizeof(long double) == 16)
-          << "128-bit floating point is not supported on current platform.";
     } else if (typestr[1] == 'f' && typestr[2] == '2') {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
       type = T::kF2;
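A side note on the typestr being parsed above: it follows the NumPy __array_interface__ convention of byte-order character, type-kind character, and item size in bytes, which is why a length-4 string ending in "16" means a 16-byte float. A few illustrative values; the mapping to handler enums is this reading of the branches above, not an excerpt from the source:

// typestr layout: [byte order][type kind][item size in bytes]
//   "<f2"  little-endian float,  2 bytes  -> kF2  (__half, CUDA sm_60+ only)
//   "<f4"  little-endian float,  4 bytes  -> kF4  (float)
//   "<f8"  little-endian float,  8 bytes  -> kF8  (double)
//   "<f16" little-endian float, 16 bytes  -> kF16 (long double, guarded by the NoF128 check)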
@@ -572,19 +573,90 @@ class ArrayInterface {
   // Used only by columnar format.
   RBitField8 valid;
   // Array stride
-  size_t strides[D]{0};
+  std::size_t strides[D]{0};
   // Array shape
-  size_t shape[D]{0};
+  std::size_t shape[D]{0};
   // Type erased pointer referencing the data.
   void const *data{nullptr};
   // Total number of items
-  size_t n{0};
+  std::size_t n{0};
   // Whether the memory is c-contiguous
   bool is_contiguous{false};
   // RTTI, initialized to the f16 to avoid masking potential bugs in initialization.
   ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
 };
 
+template <std::int32_t D, typename Fn>
+void DispatchDType(ArrayInterface<D> const array, std::int32_t device, Fn fn) {
+  // Only used for cuDF at the moment.
+  CHECK_EQ(array.valid.Size(), 0);
+  auto dispatch = [&](auto t) {
+    using T = std::remove_const_t<decltype(t)> const;
+    // Set the data size to max as we don't know the original size of a sliced array:
+    //
+    // Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
+    // in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
+    // size 24 based on the slice.
+    fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
+                                                      std::numeric_limits<std::size_t>::max()},
+                                array.shape, array.strides, device});
+  };
+  switch (array.type) {
+    case ArrayInterfaceHandler::kF2: {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+      dispatch(__half{});
+#endif
+      break;
+    }
+    case ArrayInterfaceHandler::kF4: {
+      dispatch(float{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF8: {
+      dispatch(double{});
+      break;
+    }
+    case ArrayInterfaceHandler::kF16: {
+      using T = long double;
+      CHECK(sizeof(long double) == 16) << error::NoF128();
+      dispatch(T{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI1: {
+      dispatch(std::int8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI2: {
+      dispatch(std::int16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI4: {
+      dispatch(std::int32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kI8: {
+      dispatch(std::int64_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU1: {
+      dispatch(std::uint8_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU2: {
+      dispatch(std::uint16_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU4: {
+      dispatch(std::uint32_t{});
+      break;
+    }
+    case ArrayInterfaceHandler::kU8: {
+      dispatch(std::uint64_t{});
+      break;
+    }
+  }
+}
+
 /**
  * \brief Helper for type casting.
  */
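Two things in this hunk are worth unpacking. First, the sliced-array comment: from B's shape (4, 3) and strides (6, 1) the furthest reachable flat offset is (4-1)*6 + (3-1)*1 = 20, so only 21 elements are provably live, while the original A held 4*2*3 = 24; the original extent is genuinely unrecoverable from the slice, hence the numeric_limits<std::size_t>::max() span size. Second, DispatchDType itself: it resolves the runtime dtype tag once and hands the callback a statically typed view. A rough standalone sketch of that dispatch-once pattern; the enum, callback signature, and raw pointer here are simplified stand-ins, not xgboost's API:

#include <cstddef>
#include <cstdint>
#include <cstdio>

enum class DType { kF4, kF8, kI4 };  // stand-in for ArrayInterfaceHandler::Type

// Resolve the runtime tag to a static type once, then run the typed callback.
template <typename Fn>
void DispatchDType(DType type, void const *data, std::size_t n, Fn fn) {
  auto dispatch = [&](auto t) {
    using T = decltype(t);
    fn(static_cast<T const *>(data), n);  // fn sees a concrete T const *
  };
  switch (type) {
    case DType::kF4: dispatch(float{}); break;
    case DType::kF8: dispatch(double{}); break;
    case DType::kI4: dispatch(std::int32_t{}); break;
  }
}

int main() {
  float values[] = {1.0f, 2.0f, 3.0f};
  // The generic lambda is instantiated per type; its loop is fully typed.
  DispatchDType(DType::kF4, values, 3, [](auto const *p, std::size_t n) {
    double sum = 0;
    for (std::size_t i = 0; i < n; ++i) sum += p[i];
    std::printf("sum=%f\n", sum);
  });
}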
src/data/data.cc
@@ -427,10 +427,13 @@ void CopyTensorInfoImpl(Context const& ctx, Json arr_interface, linalg::Tensor<T
     return;
   }
   p_out->Reshape(array.shape);
-  auto t = p_out->View(Context::kCpuId);
-  CHECK(t.CContiguous());
-  linalg::ElementWiseTransformHost(t, ctx.Threads(), [&](auto i, auto) {
-    return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
+  auto t_out = p_out->View(Context::kCpuId);
+  CHECK(t_out.CContiguous());
+  auto const shape = t_out.Shape();
+  DispatchDType(array, Context::kCpuId, [&](auto&& in) {
+    linalg::ElementWiseTransformHost(t_out, ctx.Threads(), [&](auto i, auto) {
+      return std::apply(in, linalg::UnravelIndex<D>(i, shape));
+    });
   });
 }
 }  // namespace
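This hunk is where the optimization in the commit title lands: previously the per-element lambda went through TypedIndex, which consults the runtime dtype tag on every call; now DispatchDType resolves the dtype once and the element-wise transform runs over a statically typed view. A minimal before/after illustration of why hoisting the dispatch out of the hot loop matters; the types and functions here are toy stand-ins, not xgboost's:

#include <cstddef>
#include <cstdint>
#include <vector>

enum class DType { kF4, kF8 };

// Before: the type tag is re-examined for every element.
double SumPerElementDispatch(DType type, void const *data, std::size_t n) {
  double sum = 0;
  for (std::size_t i = 0; i < n; ++i) {
    switch (type) {  // branch inside the hot loop
      case DType::kF4: sum += static_cast<float const *>(data)[i]; break;
      case DType::kF8: sum += static_cast<double const *>(data)[i]; break;
    }
  }
  return sum;
}

// After: dispatch once, then run a fully typed loop the compiler can optimize.
double SumDispatchOnce(DType type, void const *data, std::size_t n) {
  auto typed_sum = [&](auto const *p) {
    double sum = 0;
    for (std::size_t i = 0; i < n; ++i) sum += p[i];
    return sum;
  };
  switch (type) {
    case DType::kF4: return typed_sum(static_cast<float const *>(data));
    case DType::kF8: return typed_sum(static_cast<double const *>(data));
  }
  return 0;
}

int main() {
  std::vector<float> v{1.f, 2.f, 3.f};
  return static_cast<int>(SumDispatchOnce(DType::kF4, v.data(), v.size()));
}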