Add range-based slicing to tensor view. (#7453)

This commit is contained in:
Jiaming Yuan 2021-11-27 13:42:36 +08:00 committed by GitHub
parent 6f38f5affa
commit 85cbd32c5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 361 additions and 132 deletions

View File

@ -20,6 +20,15 @@
#include <utility>
#include <vector>
// decouple it from xgboost.
#ifndef LINALG_HD
#if defined(__CUDA__) || defined(__NVCC__)
#define LINALG_HD __host__ __device__
#else
#define LINALG_HD
#endif // defined (__CUDA__) || defined(__NVCC__)
#endif // LINALG_HD
namespace xgboost {
namespace linalg {
namespace detail {
@ -46,17 +55,32 @@ constexpr std::enable_if_t<sizeof...(Tail) != 0, size_t> Offset(S (&strides)[D],
return Offset<dim + 1>(strides, n + (head * strides[dim]), std::forward<Tail>(rest)...);
}
template <int32_t D>
constexpr void CalcStride(size_t (&shape)[D], size_t (&stride)[D]) {
template <int32_t D, bool f_array = false>
constexpr void CalcStride(size_t const (&shape)[D], size_t (&stride)[D]) {
if (f_array) {
stride[0] = 1;
for (int32_t s = 1; s < D; ++s) {
stride[s] = shape[s - 1] * stride[s - 1];
}
} else {
stride[D - 1] = 1;
for (int32_t s = D - 2; s >= 0; --s) {
stride[s] = shape[s + 1] * stride[s + 1];
}
}
}
struct AllTag {};
struct IntTag {};
template <typename I>
struct RangeTag {
I beg;
I end;
constexpr size_t Size() const { return end - beg; }
};
/**
* \brief Calculate the dimension of sliced tensor.
*/
@ -83,10 +107,10 @@ template <typename S>
using RemoveCRType = std::remove_const_t<std::remove_reference_t<S>>;
template <typename S>
using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value, IntTag, AllTag>;
using IndexToTag = std::conditional_t<std::is_integral<RemoveCRType<S>>::value, IntTag, S>;
template <int32_t n, typename Fn>
XGBOOST_DEVICE constexpr auto UnrollLoop(Fn fn) {
LINALG_HD constexpr auto UnrollLoop(Fn fn) {
#if defined __CUDA_ARCH__
#pragma unroll n
#endif // defined __CUDA_ARCH__
@ -102,7 +126,7 @@ int32_t NativePopc(T v) {
return c;
}
inline XGBOOST_DEVICE int Popc(uint32_t v) {
inline LINALG_HD int Popc(uint32_t v) {
#if defined(__CUDA_ARCH__)
return __popc(v);
#elif defined(__GNUC__) || defined(__clang__)
@ -114,7 +138,7 @@ inline XGBOOST_DEVICE int Popc(uint32_t v) {
#endif // compiler
}
inline XGBOOST_DEVICE int Popc(uint64_t v) {
inline LINALG_HD int Popc(uint64_t v) {
#if defined(__CUDA_ARCH__)
return __popcll(v);
#elif defined(__GNUC__) || defined(__clang__)
@ -140,7 +164,7 @@ constexpr auto Arr2Tup(T (&arr)[N]) {
// slow on both CPU and GPU, especially 64 bit integer. So here we first try to avoid 64
// bit when the index is smaller, then try to avoid division when it's exp of 2.
template <typename I, int32_t D>
XGBOOST_DEVICE auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
size_t index[D]{0};
static_assert(std::is_signed<decltype(D)>::value,
"Don't change the type without changing the for loop.");
@ -174,7 +198,7 @@ void ReshapeImpl(size_t (&out_shape)[D], I &&s, S &&...rest) {
}
template <typename Fn, typename Tup, size_t... I>
XGBOOST_DEVICE decltype(auto) constexpr Apply(Fn &&f, Tup &&t, std::index_sequence<I...>) {
LINALG_HD decltype(auto) constexpr Apply(Fn &&f, Tup &&t, std::index_sequence<I...>) {
return f(std::get<I>(t)...);
}
@ -185,19 +209,26 @@ XGBOOST_DEVICE decltype(auto) constexpr Apply(Fn &&f, Tup &&t, std::index_sequen
* \param t tuple of arguments
*/
template <typename Fn, typename Tup>
XGBOOST_DEVICE decltype(auto) constexpr Apply(Fn &&f, Tup &&t) {
LINALG_HD decltype(auto) constexpr Apply(Fn &&f, Tup &&t) {
constexpr auto kSize = std::tuple_size<Tup>::value;
return Apply(std::forward<Fn>(f), std::forward<Tup>(t), std::make_index_sequence<kSize>{});
}
} // namespace detail
/**
* \brief Specify all elements in the axis is used for slice.
* \brief Specify all elements in the axis for slicing.
*/
constexpr detail::AllTag All() { return {}; }
/**
* \brief Specify a range of elements in the axis for slicing.
*/
template <typename I>
constexpr detail::RangeTag<I> Range(I beg, I end) {
return {beg, end};
}
/**
* \brief A tensor view with static type and shape. It implements indexing and slicing.
* \brief A tensor view with static type and dimension. It implements indexing and slicing.
*
* Most of the algorithms in XGBoost are implemented for both CPU and GPU without using
* much linear algebra routines, this class is a helper intended to ease some high level
@ -209,7 +240,7 @@ constexpr detail::AllTag All() { return {}; }
* some functions expect data types that can be used in everywhere (update prediction
* cache for example).
*/
template <typename T, int32_t kDim = 5>
template <typename T, int32_t kDim>
class TensorView {
public:
using ShapeT = size_t[kDim];
@ -225,7 +256,7 @@ class TensorView {
int32_t device_{-1};
// Unlike `Tensor`, the data_ can have arbitrary size since this is just a view.
XGBOOST_DEVICE void CalcSize() {
LINALG_HD void CalcSize() {
if (data_.empty()) {
size_ = 0;
} else {
@ -233,9 +264,38 @@ class TensorView {
}
}
template <size_t old_dim, size_t new_dim, int32_t D, typename... S>
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
detail::AllTag) const {
template <size_t old_dim, size_t new_dim, int32_t D, typename I>
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
detail::RangeTag<I> &&range) const {
static_assert(new_dim < D, "");
static_assert(old_dim < kDim, "");
new_stride[new_dim] = stride_[old_dim];
new_shape[new_dim] = range.Size();
assert(static_cast<decltype(shape_[old_dim])>(range.end) <= shape_[old_dim]);
auto offset = stride_[old_dim] * range.beg;
return offset;
}
/**
* \brief Slice dimension for Range tag.
*/
template <size_t old_dim, size_t new_dim, int32_t D, typename I, typename... S>
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
detail::RangeTag<I> &&range, S &&...slices) const {
static_assert(new_dim < D, "");
static_assert(old_dim < kDim, "");
new_stride[new_dim] = stride_[old_dim];
new_shape[new_dim] = range.Size();
assert(static_cast<decltype(shape_[old_dim])>(range.end) <= shape_[old_dim]);
auto offset = stride_[old_dim] * range.beg;
return MakeSliceDim<old_dim + 1, new_dim + 1, D>(new_shape, new_stride,
std::forward<S>(slices)...) +
offset;
}
template <size_t old_dim, size_t new_dim, int32_t D>
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], detail::AllTag) const {
static_assert(new_dim < D, "");
static_assert(old_dim < kDim, "");
new_stride[new_dim] = stride_[old_dim];
@ -246,7 +306,7 @@ class TensorView {
* \brief Slice dimension for All tag.
*/
template <size_t old_dim, size_t new_dim, int32_t D, typename... S>
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], detail::AllTag,
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], detail::AllTag,
S &&...slices) const {
static_assert(new_dim < D, "");
static_assert(old_dim < kDim, "");
@ -257,7 +317,7 @@ class TensorView {
}
template <size_t old_dim, size_t new_dim, int32_t D, typename Index>
XGBOOST_DEVICE size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], Index i) const {
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], Index i) const {
static_assert(old_dim < kDim, "");
return stride_[old_dim] * i;
}
@ -265,7 +325,7 @@ class TensorView {
* \brief Slice dimension for Index tag.
*/
template <size_t old_dim, size_t new_dim, int32_t D, typename Index, typename... S>
XGBOOST_DEVICE std::enable_if_t<std::is_integral<Index>::value, size_t> MakeSliceDim(
LINALG_HD std::enable_if_t<std::is_integral<Index>::value, size_t> MakeSliceDim(
size_t new_shape[D], size_t new_stride[D], Index i, S &&...slices) const {
static_assert(old_dim < kDim, "");
auto offset = stride_[old_dim] * i;
@ -291,7 +351,7 @@ class TensorView {
* \param device Device ordinal
*/
template <typename I, int32_t D>
XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], int32_t device)
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], int32_t device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D > 0 && D <= kDim, "Invalid shape.");
// shape
@ -310,7 +370,7 @@ class TensorView {
* stride can be calculated from shape.
*/
template <typename I, int32_t D>
XGBOOST_DEVICE TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], I const (&stride)[D],
int32_t device)
: data_{data}, ptr_{data_.data()}, device_{device} {
static_assert(D == kDim, "Invalid shape & stride.");
@ -321,11 +381,14 @@ class TensorView {
this->CalcSize();
}
XGBOOST_DEVICE TensorView(TensorView const &that)
: data_{that.data_}, ptr_{data_.data()}, size_{that.size_}, device_{that.device_} {
template <
typename U,
std::enable_if_t<common::detail::IsAllowedElementTypeConversion<U, T>::value> * = nullptr>
LINALG_HD TensorView(TensorView<U, kDim> const &that) // NOLINT
: data_{that.Values()}, ptr_{data_.data()}, size_{that.Size()}, device_{that.DeviceIdx()} {
detail::UnrollLoop<kDim>([&](auto i) {
stride_[i] = that.stride_[i];
shape_[i] = that.shape_[i];
stride_[i] = that.Stride(i);
shape_[i] = that.Shape(i);
});
}
@ -343,7 +406,7 @@ class TensorView {
* \endcode
*/
template <typename... Index>
XGBOOST_DEVICE T &operator()(Index &&...index) {
LINALG_HD T &operator()(Index &&...index) {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
assert(offset < data_.size() && "Out of bound access.");
@ -353,7 +416,7 @@ class TensorView {
* \brief Index the tensor to obtain a scalar value.
*/
template <typename... Index>
XGBOOST_DEVICE T const &operator()(Index &&...index) const {
LINALG_HD T const &operator()(Index &&...index) const {
static_assert(sizeof...(index) <= kDim, "Invalid index.");
size_t offset = detail::Offset<0ul>(stride_, 0ul, std::forward<Index>(index)...);
assert(offset < data_.size() && "Out of bound access.");
@ -374,7 +437,7 @@ class TensorView {
* \endcode
*/
template <typename... S>
XGBOOST_DEVICE auto Slice(S &&...slices) const {
LINALG_HD auto Slice(S &&...slices) const {
static_assert(sizeof...(slices) <= kDim, "Invalid slice.");
int32_t constexpr kNewDim{detail::CalcSliceDim<detail::IndexToTag<S>...>()};
size_t new_shape[kNewDim];
@ -387,99 +450,77 @@ class TensorView {
return ret;
}
XGBOOST_DEVICE auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
LINALG_HD auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
/**
* Get the shape for i^th dimension
*/
XGBOOST_DEVICE auto Shape(size_t i) const { return shape_[i]; }
XGBOOST_DEVICE auto Stride() const { return common::Span<size_t const, kDim>{stride_}; }
LINALG_HD auto Shape(size_t i) const { return shape_[i]; }
LINALG_HD auto Stride() const { return common::Span<size_t const, kDim>{stride_}; }
/**
* Get the stride for i^th dimension, stride is specified as number of items instead of bytes.
*/
XGBOOST_DEVICE auto Stride(size_t i) const { return stride_[i]; }
LINALG_HD auto Stride(size_t i) const { return stride_[i]; }
XGBOOST_DEVICE auto cbegin() const { return data_.cbegin(); } // NOLINT
XGBOOST_DEVICE auto cend() const { return data_.cend(); } // NOLINT
XGBOOST_DEVICE auto begin() { return data_.begin(); } // NOLINT
XGBOOST_DEVICE auto end() { return data_.end(); } // NOLINT
/**
* \brief Number of items in the tensor.
*/
XGBOOST_DEVICE size_t Size() const { return size_; }
LINALG_HD size_t Size() const { return size_; }
/**
* \brief Whether it's a contiguous array. (c and f contiguous are both contiguous)
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
*/
XGBOOST_DEVICE bool Contiguous() const { return size_ == data_.size(); }
LINALG_HD bool Contiguous() const {
return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
}
/**
* \brief Obtain the raw data.
* \brief Whether it's a c-contiguous array.
*/
XGBOOST_DEVICE auto Values() const { return data_; }
LINALG_HD bool CContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value, "");
// It's contiguous if the stride can be calculated from shape.
detail::CalcStride(shape_, stride);
return common::Span<size_t const, kDim>{stride_} == common::Span<size_t const, kDim>{stride};
}
/**
* \brief Whether it's a f-contiguous array.
*/
LINALG_HD bool FContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value, "");
// It's contiguous if the stride can be calculated from shape.
detail::CalcStride<kDim, true>(shape_, stride);
return common::Span<size_t const, kDim>{stride_} == common::Span<size_t const, kDim>{stride};
}
/**
* \brief Obtain a reference to the raw data.
*/
LINALG_HD auto Values() const -> decltype(data_) const & { return data_; }
/**
* \brief Obtain the CUDA device ordinal.
*/
XGBOOST_DEVICE auto DeviceIdx() const { return device_; }
/**
* \brief Array Interface defined by
* <a href="https://numpy.org/doc/stable/reference/arrays.interface.html">numpy</a>.
*
* `stream` is optionally included when data is on CUDA device.
*/
Json ArrayInterface() const {
Json array_interface{Object{}};
array_interface["data"] = std::vector<Json>(2);
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(data_.data()));
array_interface["data"][1] = Boolean{true};
if (this->DeviceIdx() >= 0) {
// Change this once we have different CUDA stream.
array_interface["stream"] = Null{};
}
std::vector<Json> shape(Shape().size());
std::vector<Json> stride(Stride().size());
for (size_t i = 0; i < Shape().size(); ++i) {
shape[i] = Integer(Shape(i));
stride[i] = Integer(Stride(i) * sizeof(T));
}
array_interface["shape"] = Array{shape};
array_interface["strides"] = Array{stride};
array_interface["version"] = 3;
char constexpr kT = detail::ArrayInterfaceHandler::TypeChar<T>();
static_assert(kT != '\0', "");
if (DMLC_LITTLE_ENDIAN) {
array_interface["typestr"] = String{"<" + (kT + std::to_string(sizeof(T)))};
} else {
array_interface["typestr"] = String{">" + (kT + std::to_string(sizeof(T)))};
}
return array_interface;
}
/**
* \brief Same as const version, but returns non-readonly data pointer.
*/
Json ArrayInterface() {
auto const &as_const = *this;
auto res = as_const.ArrayInterface();
res["data"][1] = Boolean{false};
return res;
}
auto ArrayInterfaceStr() const {
std::string str;
Json::Dump(this->ArrayInterface(), &str);
return str;
}
auto ArrayInterfaceStr() {
std::string str;
Json::Dump(this->ArrayInterface(), &str);
return str;
}
LINALG_HD auto DeviceIdx() const { return device_; }
};
/**
* \brief Constructor for automatic type deduction.
*/
template <typename Container, typename I, int32_t D,
std::enable_if_t<!common::detail::IsSpan<Container>::value> * = nullptr>
auto MakeTensorView(Container &data, I const (&shape)[D], int32_t device) { // NOLINT
using T = typename Container::value_type;
return TensorView<T, D>{data, shape, device};
}
template <typename T, typename I, int32_t D>
LINALG_HD auto MakeTensorView(common::Span<T> data, I const (&shape)[D], int32_t device) {
return TensorView<T, D>{data, shape, device};
}
/**
* \brief Turns linear index into multi-dimension index. Similar to numpy unravel.
*/
template <size_t D>
XGBOOST_DEVICE auto UnravelIndex(size_t idx, common::Span<size_t const, D> shape) {
LINALG_HD auto UnravelIndex(size_t idx, common::Span<size_t const, D> shape) {
if (idx > std::numeric_limits<uint32_t>::max()) {
return detail::UnravelImpl<uint64_t, D>(static_cast<uint64_t>(idx), shape);
} else {
@ -516,6 +557,70 @@ auto MakeVec(T *ptr, size_t s, int32_t device = -1) {
template <typename T>
using MatrixView = TensorView<T, 2>;
/**
* \brief Array Interface defined by
* <a href="https://numpy.org/doc/stable/reference/arrays.interface.html">numpy</a>.
*
* `stream` is optionally included when data is on CUDA device.
*/
template <typename T, int32_t D>
Json ArrayInterface(TensorView<T const, D> const &t) {
Json array_interface{Object{}};
array_interface["data"] = std::vector<Json>(2);
array_interface["data"][0] = Integer(reinterpret_cast<int64_t>(t.Values().data()));
array_interface["data"][1] = Boolean{true};
if (t.DeviceIdx() >= 0) {
// Change this once we have different CUDA stream.
array_interface["stream"] = Null{};
}
std::vector<Json> shape(t.Shape().size());
std::vector<Json> stride(t.Stride().size());
for (size_t i = 0; i < t.Shape().size(); ++i) {
shape[i] = Integer(t.Shape(i));
stride[i] = Integer(t.Stride(i) * sizeof(T));
}
array_interface["shape"] = Array{shape};
array_interface["strides"] = Array{stride};
array_interface["version"] = 3;
char constexpr kT = detail::ArrayInterfaceHandler::TypeChar<T>();
static_assert(kT != '\0', "");
if (DMLC_LITTLE_ENDIAN) {
array_interface["typestr"] = String{"<" + (kT + std::to_string(sizeof(T)))};
} else {
array_interface["typestr"] = String{">" + (kT + std::to_string(sizeof(T)))};
}
return array_interface;
}
/**
* \brief Same as const version, but returns non-readonly data pointer.
*/
template <typename T, int32_t D>
Json ArrayInterface(TensorView<T, D> const &t) {
TensorView<T const, D> const &as_const = t;
auto res = ArrayInterface(as_const);
res["data"][1] = Boolean{false};
return res;
}
/**
* \brief Return string representation of array interface.
*/
template <typename T, int32_t D>
auto ArrayInterfaceStr(TensorView<T const, D> const &t) {
std::string str;
Json::Dump(ArrayInterface(t), &str);
return str;
}
template <typename T, int32_t D>
auto ArrayInterfaceStr(TensorView<T, D> const &t) {
std::string str;
Json::Dump(ArrayInterface(t), &str);
return str;
}
/**
* \brief A tensor storage. To use it for other functionality like slicing one needs to
* obtain a view first. This way we can use it on both host and device.
@ -674,4 +779,8 @@ void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
}
} // namespace linalg
} // namespace xgboost
#if defined(LINALG_HD)
#undef LINALG_HD
#endif // defined(LINALG_HD)
#endif // XGBOOST_LINALG_H_

View File

@ -413,7 +413,7 @@ void CopyTensorInfoImpl(Json arr_interface, linalg::Tensor<T, D>* p_out) {
}
p_out->Reshape(array.shape);
auto t = p_out->View(GenericParameter::kCpuId);
CHECK(t.Contiguous());
CHECK(t.CContiguous());
// FIXME(jiamingy): Remove the use of this default thread.
linalg::ElementWiseKernelHost(t, common::OmpGetNumThreads(0), [&](auto i, auto) {
return linalg::detail::Apply(TypedIndex<T, D>{array}, linalg::UnravelIndex<D>(i, t.Shape()));
@ -531,8 +531,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
using T = std::remove_pointer_t<decltype(cast_d_ptr)>;
auto t =
linalg::TensorView<T, 1>(common::Span<T>{cast_d_ptr, num}, {num}, GenericParameter::kCpuId);
CHECK(t.Contiguous());
Json interface { t.ArrayInterface() };
CHECK(t.CContiguous());
Json interface { linalg::ArrayInterface(t) };
assert(ArrayInterface<1>{interface}.is_contiguous);
return interface;
};

View File

@ -61,9 +61,9 @@ class FileIterator {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = MakeVec(row_block_.offset, row_block_.size + 1).ArrayInterfaceStr();
values_ = MakeVec(row_block_.value, row_block_.offset[row_block_.size]).ArrayInterfaceStr();
indices_ = MakeVec(row_block_.index, row_block_.offset[row_block_.size]).ArrayInterfaceStr();
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);

View File

@ -85,8 +85,7 @@ double MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
auto const &labels = info.labels_.ConstHostVector();
std::vector<double> results_storage(n_classes * 3, 0);
linalg::TensorView<double> results(results_storage,
{n_classes, static_cast<size_t>(3)},
linalg::TensorView<double, 2> results(results_storage, {n_classes, static_cast<size_t>(3)},
GenericParameter::kCpuId);
auto local_area = results.Slice(linalg::All(), 0);
auto tp = results.Slice(linalg::All(), 1);

View File

@ -51,7 +51,7 @@ TEST(Linalg, TensorView) {
std::vector<double> data(2 * 3 * 4, 0);
std::iota(data.begin(), data.end(), 0);
TensorView<double> t{data, {2, 3, 4}, -1};
auto t = MakeTensorView(data, {2, 3, 4}, -1);
ASSERT_EQ(t.Shape()[0], 2);
ASSERT_EQ(t.Shape()[1], 3);
ASSERT_EQ(t.Shape()[2], 4);
@ -96,17 +96,114 @@ TEST(Linalg, TensorView) {
// assignment
TensorView<double, 3> t{data, {2, 3, 4}, 0};
double pi = 3.14159;
auto old = t(1, 2, 3);
t(1, 2, 3) = pi;
ASSERT_EQ(t(1, 2, 3), pi);
t(1, 2, 3) = old;
ASSERT_EQ(t(1, 2, 3), old);
}
{
// Don't assign the initial dimension, tensor should be able to deduce the correct dim
// for Slice.
TensorView<double> t{data, {2, 3, 4}, 0};
auto t = MakeTensorView(data, {2, 3, 4}, 0);
auto s = t.Slice(1, 2, All());
static_assert(decltype(s)::kDimension == 1, "");
}
{
auto t = MakeTensorView(data, {2, 3, 4}, 0);
auto s = t.Slice(1, linalg::All(), 1);
ASSERT_EQ(s(0), 13);
ASSERT_EQ(s(1), 17);
ASSERT_EQ(s(2), 21);
}
{
// range slice
auto t = MakeTensorView(data, {2, 3, 4}, 0);
auto s = t.Slice(linalg::All(), linalg::Range(1, 3), 2);
static_assert(decltype(s)::kDimension == 2, "");
std::vector<double> sol{6, 10, 18, 22};
auto k = 0;
for (size_t i = 0; i < s.Shape(0); ++i) {
for (size_t j = 0; j < s.Shape(1); ++j) {
ASSERT_EQ(s(i, j), sol.at(k));
k++;
}
}
ASSERT_FALSE(s.CContiguous());
}
{
// range slice
auto t = MakeTensorView(data, {2, 3, 4}, 0);
auto s = t.Slice(1, linalg::Range(1, 3), linalg::Range(1, 3));
static_assert(decltype(s)::kDimension == 2, "");
std::vector<double> sol{17, 18, 21, 22};
auto k = 0;
for (size_t i = 0; i < s.Shape(0); ++i) {
for (size_t j = 0; j < s.Shape(1); ++j) {
ASSERT_EQ(s(i, j), sol.at(k));
k++;
}
}
ASSERT_FALSE(s.CContiguous());
}
{
// same as no slice.
auto t = MakeTensorView(data, {2, 3, 4}, 0);
auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
static_assert(decltype(s)::kDimension == 3, "");
auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
for (size_t i = 0; i < s.Shape(0); ++i) {
for (size_t j = 0; j < s.Shape(1); ++j) {
for (size_t k = 0; k < s.Shape(2); ++k) {
ASSERT_EQ(s(i, j, k), all(i, j, k));
}
}
}
ASSERT_TRUE(s.CContiguous());
ASSERT_TRUE(all.CContiguous());
}
{
// copy and move constructor.
auto t = MakeTensorView(data, {2, 3, 4}, kCpuId);
auto from_copy = t;
auto from_move = std::move(t);
for (size_t i = 0; i < t.Shape().size(); ++i) {
ASSERT_EQ(from_copy.Shape(i), from_move.Shape(i));
ASSERT_EQ(from_copy.Stride(i), from_copy.Stride(i));
}
}
{
// multiple slices
auto t = MakeTensorView(data, {2, 3, 4}, kCpuId);
auto s_0 = t.Slice(linalg::All(), linalg::Range(0, 2), linalg::Range(1, 4));
ASSERT_FALSE(s_0.CContiguous());
auto s_1 = s_0.Slice(1, 1, linalg::Range(0, 2));
ASSERT_EQ(s_1.Size(), 2);
ASSERT_TRUE(s_1.CContiguous());
ASSERT_TRUE(s_1.Contiguous());
ASSERT_EQ(s_1(0), 17);
ASSERT_EQ(s_1(1), 18);
auto s_2 = s_0.Slice(1, linalg::All(), linalg::Range(0, 2));
std::vector<double> sol{13, 14, 17, 18};
auto k = 0;
for (size_t i = 0; i < s_2.Shape(0); i++) {
for (size_t j = 0; j < s_2.Shape(1); ++j) {
ASSERT_EQ(s_2(i, j), sol[k]);
k++;
}
}
}
{
// f-contiguous
TensorView<double, 3> t{data, {4, 3, 2}, {1, 4, 12}, kCpuId};
ASSERT_TRUE(t.Contiguous());
ASSERT_TRUE(t.FContiguous());
ASSERT_FALSE(t.CContiguous());
}
}
TEST(Linalg, Tensor) {
@ -119,7 +216,8 @@ TEST(Linalg, Tensor) {
size_t n = 2 * 3 * 4;
ASSERT_EQ(t.Size(), n);
ASSERT_TRUE(std::equal(k_view.cbegin(), k_view.cbegin(), view.begin()));
ASSERT_TRUE(
std::equal(k_view.Values().cbegin(), k_view.Values().cend(), view.Values().cbegin()));
Tensor<float, 3> t_0{std::move(t)};
ASSERT_EQ(t_0.Size(), n);
@ -173,13 +271,17 @@ TEST(Linalg, ArrayInterface) {
auto cpu = kCpuId;
auto t = Tensor<double, 2>{{3, 3}, cpu};
auto v = t.View(cpu);
std::iota(v.begin(), v.end(), 0);
auto arr = Json::Load(StringView{v.ArrayInterfaceStr()});
std::iota(v.Values().begin(), v.Values().end(), 0);
auto arr = Json::Load(StringView{ArrayInterfaceStr(v)});
ASSERT_EQ(get<Integer>(arr["shape"][0]), 3);
ASSERT_EQ(get<Integer>(arr["strides"][0]), 3 * sizeof(double));
ASSERT_FALSE(get<Boolean>(arr["data"][1]));
ASSERT_EQ(reinterpret_cast<double *>(get<Integer>(arr["data"][0])), v.Values().data());
TensorView<double const, 2> as_const = v;
auto const_arr = ArrayInterface(as_const);
ASSERT_TRUE(get<Boolean>(const_arr["data"][1]));
}
TEST(Linalg, Popc) {

View File

@ -18,7 +18,7 @@ void TestElementWiseKernel() {
*/
// GPU view
auto t = l.View(0).Slice(linalg::All(), 1, linalg::All());
ASSERT_FALSE(t.Contiguous());
ASSERT_FALSE(t.CContiguous());
ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
// CPU view
t = l.View(GenericParameter::kCpuId).Slice(linalg::All(), 1, linalg::All());
@ -42,7 +42,7 @@ void TestElementWiseKernel() {
*/
auto t = l.View(0);
ElementWiseKernelDevice(t, [] __device__(size_t i, float) { return i; });
ASSERT_TRUE(t.Contiguous());
ASSERT_TRUE(t.CContiguous());
// CPU view
t = l.View(GenericParameter::kCpuId);
@ -56,7 +56,27 @@ void TestElementWiseKernel() {
}
}
}
void TestSlice() {
thrust::device_vector<double> data(2 * 3 * 4);
auto t = MakeTensorView(dh::ToSpan(data), {2, 3, 4}, 0);
dh::LaunchN(1, [=] __device__(size_t) {
auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
static_assert(decltype(s)::kDimension == 3, "");
for (size_t i = 0; i < s.Shape(0); ++i) {
for (size_t j = 0; j < s.Shape(1); ++j) {
for (size_t k = 0; k < s.Shape(2); ++k) {
SPAN_CHECK(s(i, j, k) == all(i, j, k));
}
}
}
});
}
} // anonymous namespace
TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
TEST(Linalg, GPUTensorView) { TestSlice(); }
} // namespace linalg
} // namespace xgboost

View File

@ -42,9 +42,9 @@ TEST(Adapter, CSRArrayAdapter) {
size_t n_features = 100, n_samples = 10;
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateCSR(&values, &indptr, &indices);
using linalg::MakeVec;
auto indptr_arr = MakeVec(indptr.HostPointer(), indptr.Size()).ArrayInterfaceStr();
auto values_arr = MakeVec(values.HostPointer(), values.Size()).ArrayInterfaceStr();
auto indices_arr = MakeVec(indices.HostPointer(), indices.Size()).ArrayInterfaceStr();
auto indptr_arr = ArrayInterfaceStr(MakeVec(indptr.HostPointer(), indptr.Size()));
auto values_arr = ArrayInterfaceStr(MakeVec(values.HostPointer(), values.Size()));
auto indices_arr = ArrayInterfaceStr(MakeVec(indices.HostPointer(), indices.Size()));
auto adapter = data::CSRArrayAdapter(
StringView{indptr_arr.c_str(), indptr_arr.size()},
StringView{values_arr.c_str(), values_arr.size()},

View File

@ -19,9 +19,8 @@ TEST(ArrayInterface, Initialize) {
ASSERT_EQ(arr_interface.type, ArrayInterfaceHandler::kF4);
HostDeviceVector<size_t> u64_storage(storage.Size());
std::string u64_arr_str{linalg::TensorView<size_t const, 2>{
u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId}
.ArrayInterfaceStr()};
std::string u64_arr_str{ArrayInterfaceStr(linalg::TensorView<size_t const, 2>{
u64_storage.ConstHostSpan(), {kRows, kCols}, GenericParameter::kCpuId})};
std::copy(storage.ConstHostVector().cbegin(), storage.ConstHostVector().cend(),
u64_storage.HostSpan().begin());
auto u64_arr = ArrayInterface<2>{u64_arr_str};

View File

@ -127,7 +127,8 @@ TEST(MetaInfo, SaveLoadBinary) {
auto orig_margin = info.base_margin_.View(xgboost::GenericParameter::kCpuId);
auto read_margin = inforead.base_margin_.View(xgboost::GenericParameter::kCpuId);
EXPECT_TRUE(std::equal(orig_margin.cbegin(), orig_margin.cend(), read_margin.cbegin()));
EXPECT_TRUE(std::equal(orig_margin.Values().cbegin(), orig_margin.Values().cend(),
read_margin.Values().cbegin()));
EXPECT_EQ(inforead.feature_type_names.size(), kCols);
EXPECT_EQ(inforead.feature_types.Size(), kCols);
@ -259,9 +260,8 @@ TEST(MetaInfo, Validate) {
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
d_groups.SetDevice(0);
d_groups.DevicePointer(); // pull to device
std::string arr_interface_str{
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0)
.ArrayInterfaceStr()};
std::string arr_interface_str{ArrayInterfaceStr(
xgboost::linalg::MakeVec(d_groups.ConstDevicePointer(), d_groups.Size(), 0))};
EXPECT_THROW(info.SetInfo("group", xgboost::StringView{arr_interface_str}), dmlc::Error);
#endif // defined(XGBOOST_USE_CUDA)
}

View File

@ -30,7 +30,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
is_gpu ? labels.ConstDeviceSpan() : labels.ConstHostSpan(), {32, 2}, device};
auto s = t.Slice(linalg::All(), 0);
auto str = s.ArrayInterfaceStr();
auto str = ArrayInterfaceStr(s);
ASSERT_EQ(s.Size(), 32);
info.SetInfo("label", StringView{str});
@ -48,7 +48,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
auto& h_qid = qid.Data()->HostVector();
std::iota(h_qid.begin(), h_qid.end(), 0);
auto s = qid.View(device).Slice(linalg::All(), 0);
auto str = s.ArrayInterfaceStr();
auto str = ArrayInterfaceStr(s);
info.SetInfo("qid", StringView{str});
auto const& h_result = info.group_ptr_;
ASSERT_EQ(h_result.size(), s.Size() + 1);
@ -62,7 +62,7 @@ inline void TestMetaInfoStridedData(int32_t device) {
auto t_margin = base_margin.View(device).Slice(linalg::All(), 0, linalg::All());
ASSERT_EQ(t_margin.Shape().size(), 2);
info.SetInfo("base_margin", StringView{t_margin.ArrayInterfaceStr()});
info.SetInfo("base_margin", StringView{ArrayInterfaceStr(t_margin)});
auto const& h_result = info.base_margin_.View(-1);
ASSERT_EQ(h_result.Shape().size(), 2);
auto in_margin = base_margin.View(-1);