Add CUDA iterator to tensor view. (#10074)

This commit is contained in:
Jiaming Yuan 2024-03-01 14:15:31 +08:00 committed by GitHub
parent d24df52bb9
commit 8189126d51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 66 additions and 16 deletions

View File

@ -295,6 +295,9 @@ class TensorView {
using ShapeT = std::size_t[kDim]; using ShapeT = std::size_t[kDim];
using StrideT = ShapeT; using StrideT = ShapeT;
using element_type = T; // NOLINT
using value_type = std::remove_cv_t<T>; // NOLINT
private: private:
StrideT stride_{1}; StrideT stride_{1};
ShapeT shape_{0}; ShapeT shape_{0};
@ -314,7 +317,7 @@ class TensorView {
} }
template <size_t old_dim, size_t new_dim, int32_t D, typename I> template <size_t old_dim, size_t new_dim, int32_t D, typename I>
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
detail::RangeTag<I> &&range) const { detail::RangeTag<I> &&range) const {
static_assert(new_dim < D); static_assert(new_dim < D);
static_assert(old_dim < kDim); static_assert(old_dim < kDim);
@ -528,9 +531,10 @@ class TensorView {
LINALG_HD auto Stride(size_t i) const { return stride_[i]; } LINALG_HD auto Stride(size_t i) const { return stride_[i]; }
/** /**
* \brief Number of items in the tensor. * @brief Number of items in the tensor.
*/ */
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; } [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
[[nodiscard]] bool Empty() const { return Size() == 0; }
/** /**
* \brief Whether this is a contiguous array, both C and F contiguous returns true. * \brief Whether this is a contiguous array, both C and F contiguous returns true.
*/ */
@ -865,7 +869,9 @@ class Tensor {
auto HostView() { return this->View(DeviceOrd::CPU()); } auto HostView() { return this->View(DeviceOrd::CPU()); }
auto HostView() const { return this->View(DeviceOrd::CPU()); } auto HostView() const { return this->View(DeviceOrd::CPU()); }
[[nodiscard]] size_t Size() const { return data_.Size(); } [[nodiscard]] std::size_t Size() const { return data_.Size(); }
[[nodiscard]] bool Empty() const { return Size() == 0; }
auto Shape() const { return common::Span<size_t const, kDim>{shape_}; } auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
auto Shape(size_t i) const { return shape_[i]; } auto Shape(size_t i) const { return shape_[i]; }

View File

@ -701,10 +701,10 @@ class IterSpan {
return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
} }
[[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT
return {this, 0}; return it_;
} }
[[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT
return {this, size()}; return it_ + size();
} }
}; };
} // namespace common } // namespace common

View File

@ -13,15 +13,14 @@
#include "xgboost/context.h" // for Context #include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView #include "xgboost/linalg.h" // for TensorView
namespace xgboost { namespace xgboost::linalg {
namespace linalg {
namespace cuda_impl { namespace cuda_impl {
// Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended // Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
// lambda inside constexpr if // lambda inside constexpr if
template <typename T, std::int32_t D> template <typename T, std::int32_t D>
struct ElementWiseImpl { struct ElementWiseImpl {
template <typename Fn> template <typename Fn>
void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) { void operator()(TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
static_assert(D > 1); static_assert(D > 1);
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable { dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
std::apply(fn, linalg::UnravelIndex(i, t.Shape())); std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
@ -32,36 +31,59 @@ struct ElementWiseImpl {
template <typename T> template <typename T>
struct ElementWiseImpl<T, 1> { struct ElementWiseImpl<T, 1> {
template <typename Fn> template <typename Fn>
void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) { void operator()(TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); }); dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
} }
}; };
template <typename T, std::int32_t D, typename Fn> template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) { void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal)); dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s); cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
} }
} // namespace cuda_impl } // namespace cuda_impl
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) { void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
if (t.Contiguous()) { if (t.Contiguous()) {
auto ptr = t.Values().data(); auto ptr = t.Values().data();
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); }); dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
} else { } else {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); T& v = detail::Apply(t, UnravelIndex(i, t.Shape()));
v = fn(i, v); v = fn(i, v);
}); });
} }
} }
template <typename T, int32_t D, typename Fn> template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) { void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn) ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
: ElementWiseKernelHost(t, ctx->Threads(), fn); : ElementWiseKernelHost(t, ctx->Threads(), fn);
} }
} // namespace linalg
} // namespace xgboost namespace detail {
template <typename T, std::int32_t kDim>
struct IterOp {
TensorView<T, kDim> v;
XGBOOST_DEVICE T& operator()(std::size_t i) {
return detail::Apply(v, UnravelIndex(i, v.Shape()));
}
};
} // namespace detail
// naming: thrust begin
// returns a thrust iterator for a tensor view.
template <typename T, std::int32_t kDim>
auto tcbegin(TensorView<T, kDim> v) { // NOLINT
return dh::MakeTransformIterator<T>(
thrust::make_counting_iterator(0ul),
detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
}
template <typename T, std::int32_t kDim>
auto tcend(TensorView<T, kDim> v) { // NOLINT
return tcbegin(v) + v.Size();
}
} // namespace xgboost::linalg
#endif // XGBOOST_COMMON_LINALG_OP_CUH_ #endif // XGBOOST_COMMON_LINALG_OP_CUH_

View File

@ -1,8 +1,11 @@
/** /**
* Copyright 2021-2023 by XGBoost Contributors * Copyright 2021-2024, XGBoost Contributors
*/ */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <thrust/equal.h> // for equal
#include <thrust/sequence.h> // for sequence
#include "../../../src/common/cuda_context.cuh"
#include "../../../src/common/linalg_op.cuh" #include "../../../src/common/linalg_op.cuh"
#include "../helpers.h" #include "../helpers.h"
#include "xgboost/context.h" #include "xgboost/context.h"
@ -85,4 +88,23 @@ void TestSlice() {
TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); } TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
TEST(Linalg, GPUTensorView) { TestSlice(); } TEST(Linalg, GPUTensorView) { TestSlice(); }
TEST(Linalg, GPUIter) {
auto ctx = MakeCUDACtx(1);
auto cuctx = ctx.CUDACtx();
dh::device_vector<double> data(2 * 3 * 4);
thrust::sequence(cuctx->CTP(), data.begin(), data.end(), 1.0);
auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4);
static_assert(!std::is_const_v<decltype(t)::element_type>);
static_assert(!std::is_const_v<decltype(t)::value_type>);
auto n = std::distance(linalg::tcbegin(t), linalg::tcend(t));
ASSERT_EQ(n, t.Size());
ASSERT_FALSE(t.Empty());
bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t));
ASSERT_TRUE(eq);
}
} // namespace xgboost::linalg } // namespace xgboost::linalg