Support F order for the tensor type. (#8872)

- Add F order (column-major) support for tensor and tensor view.
- Use a parameter pack for automatic type casting of indices (avoids excessive static_cast calls on shape arguments).
Jiaming Yuan 2023-03-08 03:27:49 +08:00 committed by GitHub
parent f53055f75e
commit f236640427
9 changed files with 194 additions and 94 deletions
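As a quick orientation, the two changes can be seen together in a short usage sketch. This is not part of the commit; it assumes the post-change xgboost/linalg.h, xgboost/context.h, and xgboost/span.h:

#include <vector>

#include "xgboost/context.h"
#include "xgboost/linalg.h"
#include "xgboost/span.h"

int main() {
  std::vector<float> buf(2 * 3, 0.0f);
  xgboost::common::Span<float> s{buf.data(), buf.size()};

  // Column-major (F order) view over the buffer: strides are {1, 2}, so
  // element (i, j) maps to buf[i + j * 2].
  xgboost::linalg::TensorView<float, 2> f_view{s, {2, 3}, xgboost::Context::kCpuId,
                                               xgboost::linalg::Order::kF};
  f_view(0, 1) = 5.0f;  // lands in buf[2]

  // Parameter-pack factory: shape arguments are plain integers, no static_cast needed.
  xgboost::Context ctx;
  auto c_view = xgboost::linalg::MakeTensorView(&ctx, buf, 2, 3);  // C order by default
  return c_view(0, 2) == 5.0f ? 0 : 1;  // the C-order view sees buf[2] at (0, 2)
}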

View File

@ -15,11 +15,11 @@
#include <algorithm>
#include <cassert>
#include <cinttypes>    // for int32_t
#include <cstddef>      // for size_t
#include <limits>
#include <string>
#include <tuple>        // for make_tuple
#include <type_traits>
#include <utility>
#include <vector>
@ -37,8 +37,7 @@
#endif  // defined (__CUDA__) || defined(__NVCC__)
#endif  // LINALG_HD

namespace xgboost::linalg {

namespace detail {
struct ArrayInterfaceHandler {
@ -86,7 +85,7 @@ template <typename I>
struct RangeTag {
  I beg;
  I end;
  [[nodiscard]] constexpr size_t Size() const { return end - beg; }
};

/**
@ -158,14 +157,34 @@ inline LINALG_HD int Popc(uint64_t v) {
#endif  // compiler
}

template <std::size_t D, typename Head>
LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head) {
  static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
  arr[D - 1] = head;
}

/**
 * \brief Convert index from parameter pack to C-style array.
 */
template <std::size_t D, typename Head, typename... Rest>
LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head, Rest &&...index) {
  static_assert(sizeof...(Rest) < D, "Index overflow.");
  static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
  arr[D - sizeof...(Rest) - 1] = head;
  IndexToArr(arr, std::forward<Rest>(index)...);
}

template <class T, std::size_t N, std::size_t... Idx>
constexpr auto ArrToTuple(T (&arr)[N], std::index_sequence<Idx...>) {
  return std::make_tuple(arr[Idx]...);
}

/**
 * \brief Convert C-style array to std::tuple.
 */
template <class T, std::size_t N>
constexpr auto ArrToTuple(T (&arr)[N]) {
  return ArrToTuple(arr, std::make_index_sequence<N>{});
}

// uint division optimization inspired by the CIndexer in cupy.  Division operation is
@ -188,7 +207,7 @@ LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
    }
  }
  index[0] = idx;
  return ArrToTuple(index);
}
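A usage sketch of the two helpers above (assumes the post-change header; detail:: is an internal namespace, so downstream code normally reaches them only through MakeTensorView and UnravelIndex):

#include <cstddef>
#include <tuple>

#include "xgboost/linalg.h"

int main() {
  std::size_t arr[3];
  // Indices of mixed integer types are widened into a std::size_t array...
  xgboost::linalg::detail::IndexToArr(arr, 1, 2u, 3l);
  // ...and then folded into a std::tuple<std::size_t, std::size_t, std::size_t>.
  auto tup = xgboost::linalg::detail::ArrToTuple(arr);
  return std::get<2>(tup) == 3 ? 0 : 1;
}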
template <size_t dim, typename I, int32_t D>
@ -252,6 +271,11 @@ constexpr detail::RangeTag<I> Range(I beg, I end) {
  return {beg, end};
}

enum Order : std::uint8_t {
  kC,  // Row major
  kF,  // Col major
};

/**
 * \brief A tensor view with static type and dimension. It implements indexing and slicing.
 *
@ -377,7 +401,11 @@ class TensorView {
   * \param device Device ordinal
   */
  template <typename I, int32_t D>
  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device)
      : TensorView{data, shape, device, Order::kC} {}

  template <typename I, int32_t D>
  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device, Order order)
      : data_{data}, ptr_{data_.data()}, device_{device} {
    static_assert(D > 0 && D <= kDim, "Invalid shape.");
    // shape
@ -386,7 +414,19 @@ class TensorView {
      shape_[i] = 1;
    }
    // stride
    switch (order) {
      case Order::kC: {
        detail::CalcStride(shape_, stride_);
        break;
      }
      case Order::kF: {
        detail::CalcStride<kDim, true>(shape_, stride_);
        break;
      }
      default: {
        SPAN_CHECK(false);
      }
    }
    // size
    this->CalcSize();
  }
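For reference, a self-contained illustration of the stride rule the two CalcStride calls select between. The helper names below are made up for the illustration; only the row-major/column-major logic mirrors the library:

#include <cstddef>
#include <iostream>

// Hypothetical helpers, illustration only.
template <std::size_t D>
void RowMajorStride(std::size_t const (&shape)[D], std::size_t (&stride)[D]) {
  stride[D - 1] = 1;  // last axis is contiguous
  for (std::size_t i = D - 1; i-- > 0;) stride[i] = stride[i + 1] * shape[i + 1];
}

template <std::size_t D>
void ColMajorStride(std::size_t const (&shape)[D], std::size_t (&stride)[D]) {
  stride[0] = 1;  // first axis is contiguous
  for (std::size_t i = 1; i < D; ++i) stride[i] = stride[i - 1] * shape[i - 1];
}

int main() {
  std::size_t shape[3]{2, 3, 4}, c[3], f[3];
  RowMajorStride(shape, c);  // {12, 4, 1}
  ColMajorStride(shape, f);  // {1, 2, 6}
  std::cout << c[0] << ' ' << f[2] << '\n';  // prints "12 6"
}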
@ -490,17 +530,17 @@ class TensorView {
  /**
   * \brief Number of items in the tensor.
   */
  LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
  /**
   * \brief Whether this is a contiguous array, both C and F contiguous returns true.
   */
  LINALG_HD [[nodiscard]] bool Contiguous() const {
    return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
  }
  /**
   * \brief Whether it's a c-contiguous array.
   */
  LINALG_HD [[nodiscard]] bool CContiguous() const {
    StrideT stride;
    static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
    // It's contiguous if the stride can be calculated from shape.
@ -510,7 +550,7 @@ class TensorView {
  /**
   * \brief Whether it's a f-contiguous array.
   */
  LINALG_HD [[nodiscard]] bool FContiguous() const {
    StrideT stride;
    static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
    // It's contiguous if the stride can be calculated from shape.
@ -530,16 +570,38 @@ class TensorView {
/**
 * \brief Constructor for automatic type deduction.
 */
template <typename Container, typename... S,
          std::enable_if_t<!common::detail::IsSpan<Container>::value &&
                           !std::is_pointer_v<Container>> * = nullptr>
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) {  // NOLINT
  using T = typename Container::value_type;
  std::size_t in_shape[sizeof...(S)];
  detail::IndexToArr(in_shape, std::forward<S>(shape)...);
  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
}

template <typename T, typename... S>
LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...shape) {
  std::size_t in_shape[sizeof...(S)];
  detail::IndexToArr(in_shape, std::forward<S>(shape)...);
  return TensorView<T, sizeof...(S)>{data, in_shape, device};
}

template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
  return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
}

template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
  auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
}

template <typename T, typename... S>
auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
  auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
}
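The HostDeviceVector overloads let call sites drop the ctx->IsCPU() ? HostSpan() : DeviceSpan() boilerplate, which is exactly what the later hunks in this commit do. A minimal sketch, not from the commit, assuming the post-change headers:

#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/linalg.h"

int main() {
  xgboost::Context ctx;  // gpu_id < 0, so the host span is picked
  xgboost::HostDeviceVector<float> predt(6, 0.0f);
  auto view = xgboost::linalg::MakeTensorView(&ctx, &predt, 3, 2);
  static_assert(decltype(view)::kDimension == 2);
  return view.Shape(1) == 2 ? 0 : 1;
}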
/**
@ -559,6 +621,13 @@ LINALG_HD auto UnravelIndex(size_t idx, std::size_t const (&shape)[D]) {
  return UnravelIndex(idx, common::Span<std::size_t const, D>(shape));
}

template <typename... S>
LINALG_HD auto UnravelIndex(std::size_t idx, S... shape) {
  std::size_t s[sizeof...(S)];
  detail::IndexToArr(s, shape...);
  return UnravelIndex(idx, common::Span<std::size_t const, sizeof...(S)>(s));
}
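A usage sketch for the new pack overload (assumes the post-change header); this is the form the quantile objective below switches to with structured bindings:

#include <cstddef>

#include "xgboost/linalg.h"

int main() {
  // Linear index 13 in a 2x3x4 C-order tensor unravels to (1, 0, 1).
  auto [i, j, k] = xgboost::linalg::UnravelIndex(std::size_t{13}, 2, 3, 4);
  return (i == 1 && j == 0 && k == 1) ? 0 : 1;
}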
/**
 * \brief A view over a vector, specialization of Tensor
 *
@ -676,6 +745,7 @@ class Tensor {
 private:
  HostDeviceVector<T> data_;
  ShapeT shape_{0};
  Order order_{Order::kC};

  template <typename I, std::int32_t D>
  void Initialize(I const (&shape)[D], std::int32_t device) {
@ -701,11 +771,12 @@ class Tensor {
   * See \ref TensorView for parameters of this constructor.
   */
  template <typename I, int32_t D>
  explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC)
      : Tensor{common::Span<I const, D>{shape}, device, order} {}

  template <typename I, size_t D>
  explicit Tensor(common::Span<I const, D> shape, std::int32_t device, Order order = kC)
      : order_{order} {
    // No device unroll as this is a host only function.
    std::copy(shape.data(), shape.data() + D, shape_);
    for (auto i = D; i < kDim; ++i) {
@ -724,7 +795,8 @@ class Tensor {
   * Initialize from 2 host iterators.
   */
  template <typename It, typename I, int32_t D>
  explicit Tensor(It begin, It end, I const (&shape)[D], std::int32_t device, Order order = kC)
      : order_{order} {
    auto &h_vec = data_.HostVector();
    h_vec.insert(h_vec.begin(), begin, end);
    // shape
@ -732,8 +804,9 @@ class Tensor {
  }

  template <typename I, int32_t D>
  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], std::int32_t device,
                  Order order = kC)
      : order_{order} {
    auto &h_vec = data_.HostVector();
    h_vec = data;
    // shape
@ -763,27 +836,27 @@ class Tensor {
    if (device >= 0) {
      data_.SetDevice(device);
      auto span = data_.DeviceSpan();
      return {span, shape_, device, order_};
    } else {
      auto span = data_.HostSpan();
      return {span, shape_, device, order_};
    }
  }
  TensorView<T const, kDim> View(int32_t device) const {
    if (device >= 0) {
      data_.SetDevice(device);
      auto span = data_.ConstDeviceSpan();
      return {span, shape_, device, order_};
    } else {
      auto span = data_.ConstHostSpan();
      return {span, shape_, device, order_};
    }
  }

  auto HostView() const { return this->View(-1); }
  auto HostView() { return this->View(-1); }

  [[nodiscard]] size_t Size() const { return data_.Size(); }
  auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
  auto Shape(size_t i) const { return shape_[i]; }
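Since the tensor now stores its order, views obtained through View()/HostView() inherit it. A small sketch, assuming the post-change headers:

#include "xgboost/context.h"
#include "xgboost/linalg.h"

int main() {
  // Owning tensor in column-major order; the views it hands out keep that order.
  xgboost::linalg::Tensor<float, 2> t{{3, 2}, xgboost::Context::kCpuId,
                                      xgboost::linalg::Order::kF};
  auto v = t.HostView();
  return v.FContiguous() ? 0 : 1;  // strides follow F order: {1, 3}
}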
@ -837,12 +910,26 @@ class Tensor {
  void Reshape(size_t (&shape)[D]) {
    this->Reshape(common::Span<size_t const, D>{shape});
  }
  /**
   * \brief Get a host view on the slice.
   */
  template <typename... S>
  auto Slice(S &&...slices) const {
    return this->HostView().Slice(std::forward<S>(slices)...);
  }
  /**
   * \brief Get a host view on the slice.
   */
  template <typename... S>
  auto Slice(S &&...slices) {
    return this->HostView().Slice(std::forward<S>(slices)...);
  }
  /**
   * \brief Set device ordinal for this tensor.
   */
  void SetDevice(int32_t device) const { data_.SetDevice(device); }
  [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
};

template <typename T>
@ -900,8 +987,7 @@ void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
    shape[0] = l->Shape(0) + r.Shape(0);
  });
}
}  // namespace xgboost::linalg
#if defined(LINALG_HD)
#undef LINALG_HD

View File

@ -451,9 +451,8 @@ class QuantileError : public MetricNoCache {
    auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
    std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
    CHECK_NE(n_targets, 0);
    auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast<std::size_t>(info.num_row_),
                                          alpha_.Size(), n_targets);

    info.weights_.SetDevice(ctx->gpu_id);
    common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()

View File

@ -23,9 +23,7 @@
#include "xgboost/span.h" // Span #include "xgboost/span.h" // Span
#include "xgboost/tree_model.h" // RegTree #include "xgboost/tree_model.h" // RegTree
namespace xgboost { namespace xgboost::obj::detail {
namespace obj {
namespace detail {
void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree, void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr, std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) { std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) {
@ -98,8 +96,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
  auto const& h_node_idx = nidx;
  auto const& h_node_ptr = nptr;
  CHECK_LE(h_node_ptr.back(), info.num_row_);
  auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
                                        predt.Size() / info.num_row_);

  // loop over each leaf
  common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
@ -138,11 +136,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
#if !defined(XGBOOST_USE_CUDA)
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
                          MetaInfo const&, float, HostDeviceVector<float> const&, float, RegTree*) {
  common::AssertGPUSupport();
}
#endif  // !defined(XGBOOST_USE_CUDA)
}  // namespace xgboost::obj::detail

View File

@ -157,8 +157,8 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
  HostDeviceVector<float> quantiles;
  predt.SetDevice(ctx->gpu_id);
  auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
                                        predt.Size() / info.num_row_);

  CHECK_LT(group_idx, d_predt.Shape(1));
  auto t_predt = d_predt.Slice(linalg::All(), group_idx);
  auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));

View File

@ -64,8 +64,7 @@ class QuantileRegression : public ObjFunction {
    out_gpair->SetDevice(ctx_->gpu_id);
    out_gpair->Resize(n_targets * info.num_row_);
    auto gpair =
        linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas);

    info.weights_.SetDevice(ctx_->gpu_id);
    common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
@ -80,15 +79,8 @@ class QuantileRegression : public ObjFunction {
    linalg::ElementWiseKernel(
        ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
          auto [sample_id, quantile_id, target_id] =
              linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());

          auto d = predt(i) - labels(sample_id, target_id);
          auto h = weight[sample_id];

View File

@ -274,8 +274,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
    histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
                              collective::IsDistributed(), fmat->IsColumnSplit());
    auto m_gpair = linalg::MakeTensorView(ctx_, *gpair, gpair->size(), static_cast<std::size_t>(1));
    SampleGradient(ctx_, *param_, m_gpair);
  }

View File

@ -6,17 +6,18 @@
#include <xgboost/host_device_vector.h>
#include <xgboost/linalg.h>

#include <cstddef>  // size_t
#include <numeric>  // iota
#include <vector>

#include "../../../src/common/linalg_op.h"

namespace xgboost::linalg {
namespace {
auto kCpuId = Context::kCpuId;
}

auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
  storage->Resize(n_rows * n_cols);
  auto &h_storage = storage->HostVector();
@ -48,10 +49,11 @@ TEST(Linalg, VectorView) {
}

TEST(Linalg, TensorView) {
  Context ctx;
  std::vector<double> data(2 * 3 * 4, 0);
  std::iota(data.begin(), data.end(), 0);
  auto t = MakeTensorView(&ctx, data, 2, 3, 4);
  ASSERT_EQ(t.Shape()[0], 2);
  ASSERT_EQ(t.Shape()[1], 3);
  ASSERT_EQ(t.Shape()[2], 4);
@ -106,12 +108,12 @@ TEST(Linalg, TensorView) {
  {
    // Don't assign the initial dimension, tensor should be able to deduce the correct dim
    // for Slice.
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s = t.Slice(1, 2, All());
    static_assert(decltype(s)::kDimension == 1);
  }
  {
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s = t.Slice(1, linalg::All(), 1);
    ASSERT_EQ(s(0), 13);
    ASSERT_EQ(s(1), 17);
@ -119,7 +121,7 @@ TEST(Linalg, TensorView) {
  }
  {
    // range slice
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s = t.Slice(linalg::All(), linalg::Range(1, 3), 2);
    static_assert(decltype(s)::kDimension == 2);
    std::vector<double> sol{6, 10, 18, 22};
@ -134,7 +136,7 @@ TEST(Linalg, TensorView) {
  }
  {
    // range slice
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s = t.Slice(1, linalg::Range(1, 3), linalg::Range(1, 3));
    static_assert(decltype(s)::kDimension == 2);
    std::vector<double> sol{17, 18, 21, 22};
@ -149,7 +151,7 @@ TEST(Linalg, TensorView) {
  }
  {
    // same as no slice.
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
    static_assert(decltype(s)::kDimension == 3);
    auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
@ -166,7 +168,7 @@ TEST(Linalg, TensorView) {
  {
    // copy and move constructor.
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto from_copy = t;
    auto from_move = std::move(t);
    for (size_t i = 0; i < t.Shape().size(); ++i) {
@ -177,7 +179,7 @@ TEST(Linalg, TensorView) {
  {
    // multiple slices
    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
    auto s_0 = t.Slice(linalg::All(), linalg::Range(0, 2), linalg::Range(1, 4));
    ASSERT_FALSE(s_0.CContiguous());
    auto s_1 = s_0.Slice(1, 1, linalg::Range(0, 2));
@ -208,7 +210,7 @@ TEST(Linalg, TensorView) {
TEST(Linalg, Tensor) {
  {
    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
    auto view = t.View(kCpuId);

    auto const &as_const = t;
@ -227,7 +229,7 @@ TEST(Linalg, Tensor) {
  }
  {
    // Reshape
    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
    t.Reshape(4, 3, 2);
    ASSERT_EQ(t.Size(), 24);
    ASSERT_EQ(t.Shape(2), 2);
@ -245,7 +247,7 @@ TEST(Linalg, Tensor) {
TEST(Linalg, Empty) {
  {
    auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId, Order::kC};
    for (int32_t i : {0, 1, 2}) {
      auto s = t.Slice(All(), i);
      ASSERT_EQ(s.Size(), 0);
@ -254,7 +256,7 @@ TEST(Linalg, Empty) {
    }
  }
  {
    auto t = Tensor<double, 2>{{0, 3}, kCpuId, Order::kC};
    ASSERT_EQ(t.Size(), 0);

    auto view = t.View(kCpuId);
@ -269,7 +271,7 @@ TEST(Linalg, Empty) {
TEST(Linalg, ArrayInterface) {
  auto cpu = kCpuId;
  auto t = Tensor<double, 2>{{3, 3}, cpu, Order::kC};
  auto v = t.View(cpu);
  std::iota(v.Values().begin(), v.Values().end(), 0);
  auto arr = Json::Load(StringView{ArrayInterfaceStr(v)});
@ -313,21 +315,48 @@ TEST(Linalg, Popc) {
}

TEST(Linalg, Stack) {
  Tensor<float, 3> l{{2, 3, 4}, kCpuId, Order::kC};
  ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
                           [=](size_t i, float) { return i; });

  Tensor<float, 3> r_0{{2, 3, 4}, kCpuId, Order::kC};
  ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
                           [=](size_t i, float) { return i; });

  Stack(&l, r_0);

  Tensor<float, 3> r_1{{0, 3, 4}, kCpuId, Order::kC};
  Stack(&l, r_1);
  ASSERT_EQ(l.Shape(0), 4);

  Stack(&r_1, l);
  ASSERT_EQ(r_1.Shape(0), l.Shape(0));
}

TEST(Linalg, FOrder) {
  std::size_t constexpr kRows = 16, kCols = 3;
  std::vector<float> data(kRows * kCols);
  MatrixView<float> mat{data, {kRows, kCols}, Context::kCpuId, Order::kF};
  float k{0};
  for (std::size_t i = 0; i < kRows; ++i) {
    for (std::size_t j = 0; j < kCols; ++j) {
      mat(i, j) = k;
      k++;
    }
  }
  auto column = mat.Slice(linalg::All(), 1);
  ASSERT_TRUE(column.FContiguous());
  ASSERT_EQ(column.Stride(0), 1);
  ASSERT_TRUE(column.CContiguous());

  k = 1;
  for (auto it = linalg::cbegin(column); it != linalg::cend(column); ++it) {
    ASSERT_EQ(*it, k);
    k += kCols;
  }
  k = 1;
  auto ptr = column.Values().data();
  for (auto it = ptr; it != ptr + kRows; ++it) {
    ASSERT_EQ(*it, k);
    k += kCols;
  }
}
} // namespace xgboost::linalg

View File

@ -7,8 +7,7 @@
#include "xgboost/context.h" #include "xgboost/context.h"
#include "xgboost/linalg.h" #include "xgboost/linalg.h"
namespace xgboost { namespace xgboost::linalg {
namespace linalg {
namespace { namespace {
void TestElementWiseKernel() { void TestElementWiseKernel() {
Tensor<float, 3> l{{2, 3, 4}, 0}; Tensor<float, 3> l{{2, 3, 4}, 0};
@ -55,8 +54,10 @@ void TestElementWiseKernel() {
}

void TestSlice() {
  Context ctx;
  ctx.gpu_id = 1;
  thrust::device_vector<double> data(2 * 3 * 4);
  auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4);
  dh::LaunchN(1, [=] __device__(size_t) {
    auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
    auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
@ -75,5 +76,4 @@ void TestSlice() {
TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
TEST(Linalg, GPUTensorView) { TestSlice(); }
}  // namespace xgboost::linalg

View File

@ -433,8 +433,8 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
    auto h_labels = info.labels.HostView().Slice(linalg::All(), t);
    std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0);
    auto h_predt =
        linalg::MakeTensorView(&ctx, predt.HostSpan(), kRows, kTargets).Slice(linalg::All(), t);
    for (size_t i = 0; i < h_predt.Size(); ++i) {
      h_predt(i) = h_labels(i) + i;
    }