Add CUDA iterator to tensor view. (#10074)

This commit is contained in:
Jiaming Yuan
2024-03-01 14:15:31 +08:00
committed by GitHub
parent d24df52bb9
commit 8189126d51
4 changed files with 66 additions and 16 deletions

View File

@@ -13,15 +13,14 @@
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for TensorView
namespace xgboost {
namespace linalg {
namespace xgboost::linalg {
namespace cuda_impl {
// Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
// lambda inside constexpr if
template <typename T, std::int32_t D>
struct ElementWiseImpl {
template <typename Fn>
void operator()(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
void operator()(TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
static_assert(D > 1);
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
@@ -32,36 +31,59 @@ struct ElementWiseImpl {
template <typename T>
struct ElementWiseImpl<T, 1> {
template <typename Fn>
void operator()(linalg::TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
void operator()(TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); });
}
};
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
}
} // namespace cuda_impl
template <typename T, int32_t D, typename Fn>
void ElementWiseTransformDevice(linalg::TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
if (t.Contiguous()) {
auto ptr = t.Values().data();
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); });
} else {
dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable {
T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape()));
T& v = detail::Apply(t, UnravelIndex(i, t.Shape()));
v = fn(i, v);
});
}
}
template <typename T, int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn) {
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn)
: ElementWiseKernelHost(t, ctx->Threads(), fn);
}
} // namespace linalg
} // namespace xgboost
namespace detail {
template <typename T, std::int32_t kDim>
struct IterOp {
TensorView<T, kDim> v;
XGBOOST_DEVICE T& operator()(std::size_t i) {
return detail::Apply(v, UnravelIndex(i, v.Shape()));
}
};
} // namespace detail
// naming: thrust begin
// returns a thrust iterator for a tensor view.
template <typename T, std::int32_t kDim>
auto tcbegin(TensorView<T, kDim> v) { // NOLINT
return dh::MakeTransformIterator<T>(
thrust::make_counting_iterator(0ul),
detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
}
template <typename T, std::int32_t kDim>
auto tcend(TensorView<T, kDim> v) { // NOLINT
return tcbegin(v) + v.Size();
}
} // namespace xgboost::linalg
#endif // XGBOOST_COMMON_LINALG_OP_CUH_