Support F order for the tensor type. (#8872)

- Add F order support for tensor and view.
- Use a parameter pack for automatic type casts (avoids excessive static_cast for the shape).

parent f53055f75e
commit f236640427
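
For orientation, here is a minimal usage sketch of the API introduced by the diff below. It is not part of the commit; the function and variable names are illustrative only:

    #include <vector>

    #include "xgboost/context.h"  // Context
    #include "xgboost/linalg.h"   // MakeTensorView, Tensor, Order

    void ExampleUsage() {
      xgboost::Context ctx;  // defaults to CPU (gpu_id == -1)
      std::vector<float> values(2 * 3, 0.0f);

      // The shape is now a parameter pack; integral arguments are converted to
      // std::size_t internally, so callers no longer need static_cast.
      auto view = xgboost::linalg::MakeTensorView(&ctx, values, 2, 3);

      // Tensors and views can be column-major (Order::kF) as well as row-major (Order::kC).
      xgboost::linalg::Tensor<float, 2> fmat{{2, 3}, xgboost::Context::kCpuId,
                                             xgboost::linalg::Order::kF};
      auto f_view = fmat.HostView();  // strides are computed for F order
      (void)view;
      (void)f_view;
    }
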
@@ -15,11 +15,11 @@

 #include <algorithm>
 #include <cassert>
-#include <cinttypes>  // std::int32_t
-#include <cstddef>    // std::size_t
+#include <cinttypes>  // for int32_t
+#include <cstddef>    // for size_t
 #include <limits>
 #include <string>
-#include <tuple>
+#include <tuple>      // for make_tuple
 #include <type_traits>
 #include <utility>
 #include <vector>
@@ -37,8 +37,7 @@
 #endif  // defined (__CUDA__) || defined(__NVCC__)
 #endif  // LINALG_HD

-namespace xgboost {
-namespace linalg {
+namespace xgboost::linalg {
 namespace detail {

 struct ArrayInterfaceHandler {
@@ -86,7 +85,7 @@ template <typename I>
 struct RangeTag {
   I beg;
   I end;
-  constexpr size_t Size() const { return end - beg; }
+  [[nodiscard]] constexpr size_t Size() const { return end - beg; }
 };

 /**
@@ -158,14 +157,34 @@ inline LINALG_HD int Popc(uint64_t v) {
 #endif  // compiler
 }

+template <std::size_t D, typename Head>
+LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head) {
+  static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
+  arr[D - 1] = head;
+}
+
+/**
+ * \brief Convert index from parameter pack to C-style array.
+ */
+template <std::size_t D, typename Head, typename... Rest>
+LINALG_HD void IndexToArr(std::size_t (&arr)[D], Head head, Rest &&...index) {
+  static_assert(sizeof...(Rest) < D, "Index overflow.");
+  static_assert(std::is_integral<std::remove_reference_t<Head>>::value, "Invalid index type.");
+  arr[D - sizeof...(Rest) - 1] = head;
+  IndexToArr(arr, std::forward<Rest>(index)...);
+}
+
 template <class T, std::size_t N, std::size_t... Idx>
-constexpr auto Arr2Tup(T (&arr)[N], std::index_sequence<Idx...>) {
+constexpr auto ArrToTuple(T (&arr)[N], std::index_sequence<Idx...>) {
   return std::make_tuple(arr[Idx]...);
 }

 /**
  * \brief Convert C-styple array to std::tuple.
  */
 template <class T, std::size_t N>
-constexpr auto Arr2Tup(T (&arr)[N]) {
-  return Arr2Tup(arr, std::make_index_sequence<N>{});
+constexpr auto ArrToTuple(T (&arr)[N]) {
+  return ArrToTuple(arr, std::make_index_sequence<N>{});
 }

 // uint division optimization inspired by the CIndexer in cupy. Division operation is
@@ -188,7 +207,7 @@ LINALG_HD auto UnravelImpl(I idx, common::Span<size_t const, D> shape) {
     }
   }
   index[0] = idx;
-  return Arr2Tup(index);
+  return ArrToTuple(index);
 }

 template <size_t dim, typename I, int32_t D>
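
A quick reference for the two helpers added above: the parameter pack is first written into a C-style array and then converted to a tuple. The snippet is a sketch that calls the detail helpers directly, which application code normally would not do:

    #include <cstddef>  // std::size_t
    #include <tuple>    // std::tuple_size

    #include "xgboost/linalg.h"

    void IndexHelpersSketch() {
      std::size_t arr[3];
      xgboost::linalg::detail::IndexToArr(arr, 1, 2, 3);    // arr == {1, 2, 3}
      auto tup = xgboost::linalg::detail::ArrToTuple(arr);  // tuple of three std::size_t
      static_assert(std::tuple_size<decltype(tup)>::value == 3, "three indices in, three out");
    }
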
@@ -252,6 +271,11 @@ constexpr detail::RangeTag<I> Range(I beg, I end) {
   return {beg, end};
 }

+enum Order : std::uint8_t {
+  kC,  // Row major
+  kF,  // Col major
+};
+
 /**
  * \brief A tensor view with static type and dimension. It implements indexing and slicing.
  *
@@ -377,7 +401,11 @@ class TensorView {
    * \param device Device ordinal
    */
   template <typename I, int32_t D>
-  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], int32_t device)
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device)
+      : TensorView{data, shape, device, Order::kC} {}
+
+  template <typename I, int32_t D>
+  LINALG_HD TensorView(common::Span<T> data, I const (&shape)[D], std::int32_t device, Order order)
       : data_{data}, ptr_{data_.data()}, device_{device} {
     static_assert(D > 0 && D <= kDim, "Invalid shape.");
     // shape
@@ -386,7 +414,19 @@ class TensorView {
       shape_[i] = 1;
     }
     // stride
+    switch (order) {
+      case Order::kC: {
         detail::CalcStride(shape_, stride_);
+        break;
+      }
+      case Order::kF: {
+        detail::CalcStride<kDim, true>(shape_, stride_);
+        break;
+      }
+      default: {
+        SPAN_CHECK(false);
+      }
+    }
     // size
     this->CalcSize();
   }
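
For reference, the column-major strides selected by detail::CalcStride<kDim, true> above can be sketched as follows; CalcStrideF below is a hypothetical stand-in, not a function from the header:

    #include <cstddef>

    // F order: the first axis moves fastest, so
    // stride[0] = 1 and stride[i] = stride[i - 1] * shape[i - 1].
    template <std::size_t kDim>
    void CalcStrideF(std::size_t const (&shape)[kDim], std::size_t (&stride)[kDim]) {
      stride[0] = 1;
      for (std::size_t i = 1; i < kDim; ++i) {
        stride[i] = stride[i - 1] * shape[i - 1];
      }
    }

With shape {2, 3, 4} this yields strides {1, 2, 6}, whereas C order gives {12, 4, 1}.
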
@@ -490,17 +530,17 @@ class TensorView {
   /**
    * \brief Number of items in the tensor.
    */
-  LINALG_HD size_t Size() const { return size_; }
+  LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
   /**
    * \brief Whether this is a contiguous array, both C and F contiguous returns true.
    */
-  LINALG_HD bool Contiguous() const {
+  LINALG_HD [[nodiscard]] bool Contiguous() const {
     return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
   }
   /**
    * \brief Whether it's a c-contiguous array.
    */
-  LINALG_HD bool CContiguous() const {
+  LINALG_HD [[nodiscard]] bool CContiguous() const {
     StrideT stride;
     static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
     // It's contiguous if the stride can be calculated from shape.
@@ -510,7 +550,7 @@ class TensorView {
   /**
    * \brief Whether it's a f-contiguous array.
    */
-  LINALG_HD bool FContiguous() const {
+  LINALG_HD [[nodiscard]] bool FContiguous() const {
     StrideT stride;
     static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
     // It's contiguous if the stride can be calculated from shape.
@@ -530,16 +570,38 @@ class TensorView {
 /**
  * \brief Constructor for automatic type deduction.
  */
-template <typename Container, typename I, int32_t D,
-          std::enable_if_t<!common::detail::IsSpan<Container>::value> * = nullptr>
-auto MakeTensorView(Container &data, I const (&shape)[D], int32_t device) {  // NOLINT
+template <typename Container, typename... S,
+          std::enable_if_t<!common::detail::IsSpan<Container>::value &&
+                           !std::is_pointer_v<Container>> * = nullptr>
+auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) {  // NOLINT
   using T = typename Container::value_type;
-  return TensorView<T, D>{data, shape, device};
+  std::size_t in_shape[sizeof...(S)];
+  detail::IndexToArr(in_shape, std::forward<S>(shape)...);
+  return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
 }

-template <typename T, typename I, int32_t D>
-LINALG_HD auto MakeTensorView(common::Span<T> data, I const (&shape)[D], int32_t device) {
-  return TensorView<T, D>{data, shape, device};
+template <typename T, typename... S>
+LINALG_HD auto MakeTensorView(std::int32_t device, common::Span<T> data, S &&...shape) {
+  std::size_t in_shape[sizeof...(S)];
+  detail::IndexToArr(in_shape, std::forward<S>(shape)...);
+  return TensorView<T, sizeof...(S)>{data, in_shape, device};
+}
+
+template <typename T, typename... S>
+auto MakeTensorView(Context const *ctx, common::Span<T> data, S &&...shape) {
+  return MakeTensorView(ctx->gpu_id, data, std::forward<S>(shape)...);
+}
+
+template <typename T, typename... S>
+auto MakeTensorView(Context const *ctx, HostDeviceVector<T> *data, S &&...shape) {
+  auto span = ctx->IsCPU() ? data->HostSpan() : data->DeviceSpan();
+  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
+}
+
+template <typename T, typename... S>
+auto MakeTensorView(Context const *ctx, HostDeviceVector<T> const *data, S &&...shape) {
+  auto span = ctx->IsCPU() ? data->ConstHostSpan() : data->ConstDeviceSpan();
+  return MakeTensorView(ctx->gpu_id, span, std::forward<S>(shape)...);
 }

 /**
@@ -559,6 +621,13 @@ LINALG_HD auto UnravelIndex(size_t idx, std::size_t const (&shape)[D]) {
   return UnravelIndex(idx, common::Span<std::size_t const, D>(shape));
 }

+template <typename... S>
+LINALG_HD auto UnravelIndex(std::size_t idx, S... shape) {
+  std::size_t s[sizeof...(S)];
+  detail::IndexToArr(s, shape...);
+  return UnravelIndex(idx, common::Span<std::size_t const, sizeof...(S)>(s));
+}
+
 /**
  * \brief A view over a vector, specialization of Tensor
  *
@@ -676,6 +745,7 @@ class Tensor {
  private:
   HostDeviceVector<T> data_;
   ShapeT shape_{0};
+  Order order_{Order::kC};

   template <typename I, std::int32_t D>
   void Initialize(I const (&shape)[D], std::int32_t device) {
@@ -701,11 +771,12 @@ class Tensor {
    * See \ref TensorView for parameters of this constructor.
    */
   template <typename I, int32_t D>
-  explicit Tensor(I const (&shape)[D], int32_t device)
-      : Tensor{common::Span<I const, D>{shape}, device} {}
+  explicit Tensor(I const (&shape)[D], std::int32_t device, Order order = kC)
+      : Tensor{common::Span<I const, D>{shape}, device, order} {}

   template <typename I, size_t D>
-  explicit Tensor(common::Span<I const, D> shape, int32_t device) {
+  explicit Tensor(common::Span<I const, D> shape, std::int32_t device, Order order = kC)
+      : order_{order} {
     // No device unroll as this is a host only function.
     std::copy(shape.data(), shape.data() + D, shape_);
     for (auto i = D; i < kDim; ++i) {
@@ -724,7 +795,8 @@ class Tensor {
    * Initialize from 2 host iterators.
    */
   template <typename It, typename I, int32_t D>
-  explicit Tensor(It begin, It end, I const (&shape)[D], int32_t device) {
+  explicit Tensor(It begin, It end, I const (&shape)[D], std::int32_t device, Order order = kC)
+      : order_{order} {
     auto &h_vec = data_.HostVector();
     h_vec.insert(h_vec.begin(), begin, end);
     // shape
@@ -732,8 +804,9 @@ class Tensor {
   }

   template <typename I, int32_t D>
-  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D],
-                  int32_t device = Context::kCpuId) {
+  explicit Tensor(std::initializer_list<T> data, I const (&shape)[D], std::int32_t device,
+                  Order order = kC)
+      : order_{order} {
     auto &h_vec = data_.HostVector();
     h_vec = data;
     // shape
@@ -763,27 +836,27 @@ class Tensor {
     if (device >= 0) {
       data_.SetDevice(device);
       auto span = data_.DeviceSpan();
-      return {span, shape_, device};
+      return {span, shape_, device, order_};
     } else {
       auto span = data_.HostSpan();
-      return {span, shape_, device};
+      return {span, shape_, device, order_};
     }
   }
   TensorView<T const, kDim> View(int32_t device) const {
     if (device >= 0) {
       data_.SetDevice(device);
       auto span = data_.ConstDeviceSpan();
-      return {span, shape_, device};
+      return {span, shape_, device, order_};
     } else {
       auto span = data_.ConstHostSpan();
-      return {span, shape_, device};
+      return {span, shape_, device, order_};
     }
   }

   auto HostView() const { return this->View(-1); }
   auto HostView() { return this->View(-1); }

-  size_t Size() const { return data_.Size(); }
+  [[nodiscard]] size_t Size() const { return data_.Size(); }
   auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
   auto Shape(size_t i) const { return shape_[i]; }

@@ -837,12 +910,26 @@ class Tensor {
   void Reshape(size_t (&shape)[D]) {
     this->Reshape(common::Span<size_t const, D>{shape});
   }
+  /**
+   * \brief Get a host view on the slice.
+   */
+  template <typename... S>
+  auto Slice(S &&...slices) const {
+    return this->HostView().Slice(std::forward<S>(slices)...);
+  }
+  /**
+   * \brief Get a host view on the slice.
+   */
+  template <typename... S>
+  auto Slice(S &&...slices) {
+    return this->HostView().Slice(std::forward<S>(slices)...);
+  }

   /**
    * \brief Set device ordinal for this tensor.
    */
   void SetDevice(int32_t device) const { data_.SetDevice(device); }
-  int32_t DeviceIdx() const { return data_.DeviceIdx(); }
+  [[nodiscard]] int32_t DeviceIdx() const { return data_.DeviceIdx(); }
 };

 template <typename T>
@@ -900,8 +987,7 @@ void Stack(Tensor<T, D> *l, Tensor<T, D> const &r) {
     shape[0] = l->Shape(0) + r.Shape(0);
   });
 }
-}  // namespace linalg
-}  // namespace xgboost
+}  // namespace xgboost::linalg

 #if defined(LINALG_HD)
 #undef LINALG_HD
@@ -451,9 +451,8 @@ class QuantileError : public MetricNoCache {
     auto alpha = ctx->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
     std::size_t n_targets = preds.Size() / info.num_row_ / alpha_.Size();
     CHECK_NE(n_targets, 0);
-    auto y_predt = linalg::MakeTensorView(
-        ctx->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(),
-        {static_cast<std::size_t>(info.num_row_), alpha_.Size(), n_targets}, ctx->gpu_id);
+    auto y_predt = linalg::MakeTensorView(ctx, &preds, static_cast<std::size_t>(info.num_row_),
+                                          alpha_.Size(), n_targets);

     info.weights_.SetDevice(ctx->gpu_id);
     common::OptionalWeights weight{ctx->IsCPU() ? info.weights_.ConstHostSpan()
@@ -23,9 +23,7 @@
 #include "xgboost/span.h"        // Span
 #include "xgboost/tree_model.h"  // RegTree

-namespace xgboost {
-namespace obj {
-namespace detail {
+namespace xgboost::obj::detail {
 void EncodeTreeLeafHost(Context const* ctx, RegTree const& tree,
                         std::vector<bst_node_t> const& position, std::vector<size_t>* p_nptr,
                         std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_ridx) {
@@ -98,8 +96,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
   auto const& h_node_idx = nidx;
   auto const& h_node_ptr = nptr;
   CHECK_LE(h_node_ptr.back(), info.num_row_);
-  auto h_predt = linalg::MakeTensorView(predt.ConstHostSpan(),
-                                        {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
+  auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
+                                        predt.Size() / info.num_row_);

   // loop over each leaf
   common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
@@ -138,11 +136,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit

 #if !defined(XGBOOST_USE_CUDA)
 void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
-                          MetaInfo const&, float learning_rate, HostDeviceVector<float> const&,
-                          float, RegTree*) {
+                          MetaInfo const&, float, HostDeviceVector<float> const&, float, RegTree*) {
   common::AssertGPUSupport();
 }
 #endif  // !defined(XGBOOST_USE_CUDA)
-}  // namespace detail
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj::detail
@@ -157,8 +157,8 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
   HostDeviceVector<float> quantiles;
   predt.SetDevice(ctx->gpu_id);

-  auto d_predt = linalg::MakeTensorView(predt.ConstDeviceSpan(),
-                                        {info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
+  auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), info.num_row_,
+                                        predt.Size() / info.num_row_);
   CHECK_LT(group_idx, d_predt.Shape(1));
   auto t_predt = d_predt.Slice(linalg::All(), group_idx);
   auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));
@@ -64,8 +64,7 @@ class QuantileRegression : public ObjFunction {
     out_gpair->SetDevice(ctx_->gpu_id);
     out_gpair->Resize(n_targets * info.num_row_);
     auto gpair =
-        linalg::MakeTensorView(ctx_->IsCPU() ? out_gpair->HostSpan() : out_gpair->DeviceSpan(),
-                               {info.num_row_, n_alphas, n_targets / n_alphas}, ctx_->gpu_id);
+        linalg::MakeTensorView(ctx_, out_gpair, info.num_row_, n_alphas, n_targets / n_alphas);

     info.weights_.SetDevice(ctx_->gpu_id);
     common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
@@ -80,15 +79,8 @@ class QuantileRegression : public ObjFunction {

     linalg::ElementWiseKernel(
         ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
-          auto idx = linalg::UnravelIndex(static_cast<std::size_t>(i),
-                                          {static_cast<std::size_t>(n_samples),
-                                           static_cast<std::size_t>(alpha.size()),
-                                           static_cast<std::size_t>(n_targets / alpha.size())});
-
-          // std::tie is not available for cuda kernel.
-          std::size_t sample_id = std::get<0>(idx);
-          std::size_t quantile_id = std::get<1>(idx);
-          std::size_t target_id = std::get<2>(idx);
+          auto [sample_id, quantile_id, target_id] =
+              linalg::UnravelIndex(i, n_samples, alpha.size(), n_targets / alpha.size());

           auto d = predt(i) - labels(sample_id, target_id);
           auto h = weight[sample_id];
@@ -274,8 +274,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree,
     histogram_builder_->Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
                               collective::IsDistributed(), fmat->IsColumnSplit());

-    auto m_gpair =
-        linalg::MakeTensorView(*gpair, {gpair->size(), static_cast<std::size_t>(1)}, ctx_->gpu_id);
+    auto m_gpair = linalg::MakeTensorView(ctx_, *gpair, gpair->size(), static_cast<std::size_t>(1));
     SampleGradient(ctx_, *param_, m_gpair);
   }

@@ -6,17 +6,18 @@
 #include <xgboost/host_device_vector.h>
 #include <xgboost/linalg.h>

-#include <numeric>
+#include <cstddef>  // size_t
+#include <numeric>  // iota
 #include <vector>

 #include "../../../src/common/linalg_op.h"

-namespace xgboost {
-namespace linalg {
+namespace xgboost::linalg {
 namespace {
 auto kCpuId = Context::kCpuId;
 }

-auto MakeMatrixFromTest(HostDeviceVector<float> *storage, size_t n_rows, size_t n_cols) {
+auto MakeMatrixFromTest(HostDeviceVector<float> *storage, std::size_t n_rows, std::size_t n_cols) {
   storage->Resize(n_rows * n_cols);
   auto &h_storage = storage->HostVector();
@@ -48,10 +49,11 @@ TEST(Linalg, VectorView) {
 }

 TEST(Linalg, TensorView) {
+  Context ctx;
   std::vector<double> data(2 * 3 * 4, 0);
   std::iota(data.begin(), data.end(), 0);

-  auto t = MakeTensorView(data, {2, 3, 4}, -1);
+  auto t = MakeTensorView(&ctx, data, 2, 3, 4);
   ASSERT_EQ(t.Shape()[0], 2);
   ASSERT_EQ(t.Shape()[1], 3);
   ASSERT_EQ(t.Shape()[2], 4);
@@ -106,12 +108,12 @@ TEST(Linalg, TensorView) {
   {
     // Don't assign the initial dimension, tensor should be able to deduce the correct dim
     // for Slice.
-    auto t = MakeTensorView(data, {2, 3, 4}, 0);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s = t.Slice(1, 2, All());
     static_assert(decltype(s)::kDimension == 1);
   }
   {
-    auto t = MakeTensorView(data, {2, 3, 4}, 0);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s = t.Slice(1, linalg::All(), 1);
     ASSERT_EQ(s(0), 13);
     ASSERT_EQ(s(1), 17);
@@ -119,7 +121,7 @@
   }
   {
     // range slice
-    auto t = MakeTensorView(data, {2, 3, 4}, 0);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s = t.Slice(linalg::All(), linalg::Range(1, 3), 2);
     static_assert(decltype(s)::kDimension == 2);
     std::vector<double> sol{6, 10, 18, 22};
@@ -134,7 +136,7 @@
   }
   {
     // range slice
-    auto t = MakeTensorView(data, {2, 3, 4}, 0);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s = t.Slice(1, linalg::Range(1, 3), linalg::Range(1, 3));
     static_assert(decltype(s)::kDimension == 2);
     std::vector<double> sol{17, 18, 21, 22};
@@ -149,7 +151,7 @@
   }
   {
     // same as no slice.
-    auto t = MakeTensorView(data, {2, 3, 4}, 0);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
     static_assert(decltype(s)::kDimension == 3);
     auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
@@ -166,7 +168,7 @@

   {
     // copy and move constructor.
-    auto t = MakeTensorView(data, {2, 3, 4}, kCpuId);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto from_copy = t;
     auto from_move = std::move(t);
     for (size_t i = 0; i < t.Shape().size(); ++i) {
@@ -177,7 +179,7 @@

   {
     // multiple slices
-    auto t = MakeTensorView(data, {2, 3, 4}, kCpuId);
+    auto t = MakeTensorView(&ctx, data, 2, 3, 4);
     auto s_0 = t.Slice(linalg::All(), linalg::Range(0, 2), linalg::Range(1, 4));
     ASSERT_FALSE(s_0.CContiguous());
     auto s_1 = s_0.Slice(1, 1, linalg::Range(0, 2));
@@ -208,7 +210,7 @@

 TEST(Linalg, Tensor) {
   {
-    Tensor<float, 3> t{{2, 3, 4}, kCpuId};
+    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
     auto view = t.View(kCpuId);

     auto const &as_const = t;
@@ -227,7 +229,7 @@
   }
   {
     // Reshape
-    Tensor<float, 3> t{{2, 3, 4}, kCpuId};
+    Tensor<float, 3> t{{2, 3, 4}, kCpuId, Order::kC};
     t.Reshape(4, 3, 2);
     ASSERT_EQ(t.Size(), 24);
     ASSERT_EQ(t.Shape(2), 2);
@@ -245,7 +247,7 @@

 TEST(Linalg, Empty) {
   {
-    auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId};
+    auto t = TensorView<double, 2>{{}, {0, 3}, kCpuId, Order::kC};
     for (int32_t i : {0, 1, 2}) {
       auto s = t.Slice(All(), i);
       ASSERT_EQ(s.Size(), 0);
@@ -254,7 +256,7 @@
     }
   }
   {
-    auto t = Tensor<double, 2>{{0, 3}, kCpuId};
+    auto t = Tensor<double, 2>{{0, 3}, kCpuId, Order::kC};
     ASSERT_EQ(t.Size(), 0);
     auto view = t.View(kCpuId);

@@ -269,7 +271,7 @@

 TEST(Linalg, ArrayInterface) {
   auto cpu = kCpuId;
-  auto t = Tensor<double, 2>{{3, 3}, cpu};
+  auto t = Tensor<double, 2>{{3, 3}, cpu, Order::kC};
   auto v = t.View(cpu);
   std::iota(v.Values().begin(), v.Values().end(), 0);
   auto arr = Json::Load(StringView{ArrayInterfaceStr(v)});
@@ -313,21 +315,48 @@ TEST(Linalg, Popc) {
 }

 TEST(Linalg, Stack) {
-  Tensor<float, 3> l{{2, 3, 4}, kCpuId};
+  Tensor<float, 3> l{{2, 3, 4}, kCpuId, Order::kC};
   ElementWiseTransformHost(l.View(kCpuId), omp_get_max_threads(),
                            [=](size_t i, float) { return i; });
-  Tensor<float, 3> r_0{{2, 3, 4}, kCpuId};
+  Tensor<float, 3> r_0{{2, 3, 4}, kCpuId, Order::kC};
   ElementWiseTransformHost(r_0.View(kCpuId), omp_get_max_threads(),
                            [=](size_t i, float) { return i; });

   Stack(&l, r_0);

-  Tensor<float, 3> r_1{{0, 3, 4}, kCpuId};
+  Tensor<float, 3> r_1{{0, 3, 4}, kCpuId, Order::kC};
   Stack(&l, r_1);
   ASSERT_EQ(l.Shape(0), 4);

   Stack(&r_1, l);
   ASSERT_EQ(r_1.Shape(0), l.Shape(0));
 }
-}  // namespace linalg
-}  // namespace xgboost
+
+TEST(Linalg, FOrder) {
+  std::size_t constexpr kRows = 16, kCols = 3;
+  std::vector<float> data(kRows * kCols);
+  MatrixView<float> mat{data, {kRows, kCols}, Context::kCpuId, Order::kF};
+  float k{0};
+  for (std::size_t i = 0; i < kRows; ++i) {
+    for (std::size_t j = 0; j < kCols; ++j) {
+      mat(i, j) = k;
+      k++;
+    }
+  }
+  auto column = mat.Slice(linalg::All(), 1);
+  ASSERT_TRUE(column.FContiguous());
+  ASSERT_EQ(column.Stride(0), 1);
+  ASSERT_TRUE(column.CContiguous());
+  k = 1;
+  for (auto it = linalg::cbegin(column); it != linalg::cend(column); ++it) {
+    ASSERT_EQ(*it, k);
+    k += kCols;
+  }
+  k = 1;
+  auto ptr = column.Values().data();
+  for (auto it = ptr; it != ptr + kRows; ++it) {
+    ASSERT_EQ(*it, k);
+    k += kCols;
+  }
+}
+}  // namespace xgboost::linalg
@@ -7,8 +7,7 @@
 #include "xgboost/context.h"
 #include "xgboost/linalg.h"

-namespace xgboost {
-namespace linalg {
+namespace xgboost::linalg {
 namespace {
 void TestElementWiseKernel() {
   Tensor<float, 3> l{{2, 3, 4}, 0};
@@ -55,8 +54,10 @@
 }

 void TestSlice() {
+  Context ctx;
+  ctx.gpu_id = 1;
   thrust::device_vector<double> data(2 * 3 * 4);
-  auto t = MakeTensorView(dh::ToSpan(data), {2, 3, 4}, 0);
+  auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4);
   dh::LaunchN(1, [=] __device__(size_t) {
     auto s = t.Slice(linalg::All(), linalg::Range(0, 3), linalg::Range(0, 4));
     auto all = t.Slice(linalg::All(), linalg::All(), linalg::All());
@@ -75,5 +76,4 @@
 TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }

 TEST(Linalg, GPUTensorView) { TestSlice(); }
-}  // namespace linalg
-}  // namespace xgboost
+}  // namespace xgboost::linalg
@@ -433,8 +433,8 @@ TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) {
     auto h_labels = info.labels.HostView().Slice(linalg::All(), t);
     std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0);

-    auto h_predt = linalg::MakeTensorView(predt.HostSpan(), {kRows, kTargets}, Context::kCpuId)
-                       .Slice(linalg::All(), t);
+    auto h_predt =
+        linalg::MakeTensorView(&ctx, predt.HostSpan(), kRows, kTargets).Slice(linalg::All(), t);
     for (size_t i = 0; i < h_predt.Size(); ++i) {
       h_predt(i) = h_labels(i) + i;
     }
