Add CUDA iterator to tensor view. (#10074)
parent d24df52bb9
commit 8189126d51
@@ -295,6 +295,9 @@ class TensorView {
   using ShapeT = std::size_t[kDim];
   using StrideT = ShapeT;
 
+  using element_type = T;                  // NOLINT
+  using value_type = std::remove_cv_t<T>;  // NOLINT
+
  private:
   StrideT stride_{1};
   ShapeT shape_{0};
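Note: the two new aliases follow the member-type convention of std::span: element_type preserves the view's constness while value_type strips cv-qualifiers, so generic code can query a view's mutability at compile time. A minimal sketch of what this enables, assuming only the two aliases added above (the MutView/ConstView names are illustrative):

    #include <type_traits>

    #include "xgboost/linalg.h"  // for TensorView

    using MutView = xgboost::linalg::TensorView<double, 2>;
    using ConstView = xgboost::linalg::TensorView<double const, 2>;

    // element_type keeps the cv-qualifier of the underlying element...
    static_assert(!std::is_const_v<MutView::element_type>);
    static_assert(std::is_const_v<ConstView::element_type>);
    // ...while value_type is the same cv-stripped type for both views.
    static_assert(std::is_same_v<MutView::value_type, ConstView::value_type>);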
@@ -314,7 +317,7 @@ class TensorView {
   }
 
   template <size_t old_dim, size_t new_dim, int32_t D, typename I>
-  LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
+  LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
                                 detail::RangeTag<I> &&range) const {
     static_assert(new_dim < D);
     static_assert(old_dim < kDim);
@@ -528,9 +531,10 @@ class TensorView {
   LINALG_HD auto Stride(size_t i) const { return stride_[i]; }
 
   /**
-   * \brief Number of items in the tensor.
+   * @brief Number of items in the tensor.
    */
   [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
+  [[nodiscard]] bool Empty() const { return Size() == 0; }
   /**
    * \brief Whether this is a contiguous array, both C and F contiguous returns true.
    */
@@ -865,7 +869,9 @@ class Tensor {
   auto HostView() { return this->View(DeviceOrd::CPU()); }
   auto HostView() const { return this->View(DeviceOrd::CPU()); }
 
-  [[nodiscard]] size_t Size() const { return data_.Size(); }
+  [[nodiscard]] std::size_t Size() const { return data_.Size(); }
+  [[nodiscard]] bool Empty() const { return Size() == 0; }
+
   auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
   auto Shape(size_t i) const { return shape_[i]; }
 
@@ -701,10 +701,10 @@ class IterSpan {
     return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
   }
   [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept {  // NOLINT
-    return {this, 0};
+    return it_;
   }
   [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept {  // NOLINT
-    return {this, size()};
+    return it_ + size();
   }
 };
 }  // namespace common
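Note: this hunk switches IterSpan's begin()/end() from index-based proxies to returning the wrapped iterator itself. A simplified, hypothetical sketch of the idea (names and members here are illustrative; the real class also carries XGBOOST_DEVICE annotations and subspan logic):

    #include <cstddef>

    // Iterator-backed span sketch: begin()/end() hand back the underlying
    // iterator directly, so the span composes with any random-access
    // iterator (e.g. Thrust transform iterators) at no extra cost.
    template <typename It>
    class IterSpanSketch {
      It it_;             // iterator to the first element
      std::size_t size_;  // number of elements

     public:
      IterSpanSketch(It it, std::size_t n) : it_{it}, size_{n} {}
      constexpr It begin() const noexcept { return it_; }        // was: {this, 0}
      constexpr It end() const noexcept { return it_ + size_; }  // was: {this, size()}
      constexpr std::size_t size() const noexcept { return size_; }
    };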
@@ -13,15 +13,14 @@
 #include "xgboost/context.h"  // for Context
 #include "xgboost/linalg.h"   // for TensorView
 
-namespace xgboost {
-namespace linalg {
+namespace xgboost::linalg {
 namespace cuda_impl {
 // Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
 // lambda inside constexpr if
 template <typename T, std::int32_t D>
 struct ElementWiseImpl {
   template <typename Fn>
-  void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
+  void operator()(TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
     static_assert(D > 1);
     dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
       std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
@@ -32,36 +31,59 @@ struct ElementWiseImpl {
 template <typename T>
 struct ElementWiseImpl<T, 1> {
   template <typename Fn>
-  void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
+  void operator()(TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
     dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
   }
 };
 
 template <typename T, std::int32_t D, typename Fn>
-void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
   dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
   cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
 }
 }  // namespace cuda_impl
 
 template <typename T, int32_t D, typename Fn>
-void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
+void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
   if (t.Contiguous()) {
     auto ptr = t.Values().data();
     dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
   } else {
     dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
-      T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
+      T& v = detail::Apply(t, UnravelIndex(i, t.Shape()));
       v = fn(i, v);
     });
   }
 }
 
 template <typename T, int32_t D, typename Fn>
-void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
+void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
   ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
                 : ElementWiseKernelHost(t, ctx->Threads(), fn);
 }
-}  // namespace linalg
-}  // namespace xgboost
+
+namespace detail {
+template <typename T, std::int32_t kDim>
+struct IterOp {
+  TensorView<T, kDim> v;
+  XGBOOST_DEVICE T& operator()(std::size_t i) {
+    return detail::Apply(v, UnravelIndex(i, v.Shape()));
+  }
+};
+}  // namespace detail
+
+// naming: thrust begin
+// returns a thrust iterator for a tensor view.
+template <typename T, std::int32_t kDim>
+auto tcbegin(TensorView<T, kDim> v) {  // NOLINT
+  return dh::MakeTransformIterator<T>(
+      thrust::make_counting_iterator(0ul),
+      detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
+}
+
+template <typename T, std::int32_t kDim>
+auto tcend(TensorView<T, kDim> v) {  // NOLINT
+  return tcbegin(v) + v.Size();
+}
+}  // namespace xgboost::linalg
 #endif  // XGBOOST_COMMON_LINALG_OP_CUH_
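Note: tcbegin/tcend enumerate elements through UnravelIndex, so they yield Thrust-compatible const iterators over the view's logical row-major order even when the underlying storage is non-contiguous (e.g. a slice). A hedged usage sketch mirroring the thrust::equal call in the test below (SumView is illustrative, not part of the commit; assumes linalg_op.cuh is included):

    #include <thrust/execution_policy.h>  // for thrust::device
    #include <thrust/reduce.h>            // for thrust::reduce

    #include <cstdint>  // for std::int32_t

    // Illustrative helper: reduce all elements of a device tensor view,
    // sliced or not, through the new const iterators.
    template <typename T, std::int32_t kDim>
    T SumView(xgboost::linalg::TensorView<T const, kDim> t) {
      return thrust::reduce(thrust::device, xgboost::linalg::tcbegin(t),
                            xgboost::linalg::tcend(t), T{0});
    }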
@@ -1,8 +1,11 @@
 /**
- * Copyright 2021-2023 by XGBoost Contributors
+ * Copyright 2021-2024, XGBoost Contributors
  */
 #include <gtest/gtest.h>
+#include <thrust/equal.h>     // for equal
+#include <thrust/sequence.h>  // for sequence
 
+#include "../../../src/common/cuda_context.cuh"
 #include "../../../src/common/linalg_op.cuh"
 #include "../helpers.h"
 #include "xgboost/context.h"
@@ -85,4 +88,23 @@ void TestSlice() {
 TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); }
 
 TEST(Linalg, GPUTensorView) { TestSlice(); }
+
+TEST(Linalg, GPUIter) {
+  auto ctx = MakeCUDACtx(1);
+  auto cuctx = ctx.CUDACtx();
+
+  dh::device_vector<double> data(2 * 3 * 4);
+  thrust::sequence(cuctx->CTP(), data.begin(), data.end(), 1.0);
+
+  auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4);
+  static_assert(!std::is_const_v<decltype(t)::element_type>);
+  static_assert(!std::is_const_v<decltype(t)::value_type>);
+
+  auto n = std::distance(linalg::tcbegin(t), linalg::tcend(t));
+  ASSERT_EQ(n, t.Size());
+  ASSERT_FALSE(t.Empty());
+
+  bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t));
+  ASSERT_TRUE(eq);
+}
 }  // namespace xgboost::linalg