/**
 * Copyright 2021-2023, XGBoost Contributors
 */
#ifndef XGBOOST_COMMON_LINALG_OP_CUH_
#define XGBOOST_COMMON_LINALG_OP_CUH_

#include <cstddef>      // for size_t
#include <cstdint>      // for int32_t
#include <tuple>        // for apply
#include <type_traits>  // for add_const_t, remove_const_t

#include "device_helpers.cuh"  // for LaunchN
#include "linalg_op.h"
#include "xgboost/context.h"  // for Context
#include "xgboost/linalg.h"   // for TensorView

namespace xgboost::linalg {
namespace cuda_impl {
// Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended
// lambda inside constexpr if.

/**
 * @brief Launch `fn` once per element of a tensor view with more than one dimension.
 *
 * Each device work item receives a flat index `i`, converts it into an index tuple with
 * `linalg::UnravelIndex(i, t.Shape())`, and calls `fn` with the tuple elements unpacked
 * as separate arguments (via `std::apply`).
 */
template <typename T, std::int32_t D>
struct ElementWiseImpl {
  template <typename Fn>
  void operator()(TensorView<T, D> t, Fn&& fn, cudaStream_t s) {
    static_assert(D > 1);
    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
      std::apply(fn, linalg::UnravelIndex(i, t.Shape()));
    });
  }
};

/**
 * @brief 1-dimensional specialization: `fn` is called directly with the flat index.
 */
template <typename T>
struct ElementWiseImpl<T, 1> {
  template <typename Fn>
  void operator()(TensorView<T, 1> t, Fn&& fn, cudaStream_t s) {
    // `mutable` to match the N-d case, so functors with a non-const call operator work too.
    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable { fn(i); });
  }
};

/**
 * @brief Run `fn` over every element index of `t` on the GPU.
 *
 * @param t  Tensor view; `t.Device().ordinal` is made the current CUDA device.
 * @param fn Device-callable functor. Called as `fn(i)` for 1-d views, otherwise with the
 *           unpacked result of `UnravelIndex(i, t.Shape())`.
 * @param s  Stream to launch on; defaults to nullptr (the default stream).
 *
 * Note: the current CUDA device is changed and not restored afterwards.
 */
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
  dh::safe_cuda(cudaSetDevice(t.Device().ordinal));
  cuda_impl::ElementWiseImpl<T, D>{}(t, fn, s);
}
}  // namespace cuda_impl

/**
 * @brief In-place GPU transform of every element of `t`: `v = fn(i, v)`.
 *
 * `i` is the flat element index and `v` the element's current value. Contiguous views are
 * indexed through the raw data pointer; other views map `i` back to an element with
 * `UnravelIndex` + `detail::Apply`.
 *
 * NOTE(review): unlike `cuda_impl::ElementWiseKernel`, this does not call `cudaSetDevice`;
 * it presumably assumes the caller has already selected `t`'s device.
 */
template <typename T, std::int32_t D, typename Fn>
void ElementWiseTransformDevice(TensorView<T, D> t, Fn&& fn, cudaStream_t s = nullptr) {
  if (t.Contiguous()) {
    auto ptr = t.Values().data();
    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { ptr[i] = fn(i, ptr[i]); });
  } else {
    dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable {
      T& v = detail::Apply(t, UnravelIndex(i, t.Shape()));
      v = fn(i, v);
    });
  }
}

/**
 * @brief Dispatch an element-wise kernel according to `ctx`.
 *
 * CUDA contexts launch `cuda_impl::ElementWiseKernel` on the default stream; otherwise
 * `ElementWiseKernelHost` is called, passing `ctx->Threads()`.
 */
template <typename T, std::int32_t D, typename Fn>
void ElementWiseKernel(Context const* ctx, TensorView<T, D> t, Fn&& fn) {
  if (ctx->IsCUDA()) {
    cuda_impl::ElementWiseKernel(t, fn);
  } else {
    ElementWiseKernelHost(t, ctx->Threads(), fn);
  }
}

namespace detail {
/**
 * @brief Functor mapping a flat element index to a reference to that element of `v`.
 */
template <typename T, std::int32_t kDim>
struct IterOp {
  TensorView<T, kDim> v;
  XGBOOST_DEVICE T& operator()(std::size_t i) {
    return detail::Apply(v, UnravelIndex(i, v.Shape()));
  }
};
}  // namespace detail

// naming: thrust begin
/**
 * @brief Returns a thrust-compatible, read-only iterator over the elements of a tensor view.
 *
 * Position `i` refers to the element at `UnravelIndex(i, v.Shape())`; the element type is
 * made `const`. The underlying counter is an explicit `std::size_t` instead of `0ul`:
 * `unsigned long` is only 32 bits on Windows and would wrap for tensors with more than
 * 2^32 elements.
 */
template <typename T, std::int32_t kDim>
auto tcbegin(TensorView<T, kDim> v) {  // NOLINT
  return dh::MakeTransformIterator<T>(
      thrust::make_counting_iterator(static_cast<std::size_t>(0)),
      detail::IterOp<std::add_const_t<std::remove_const_t<T>>, kDim>{v});
}

/**
 * @brief Past-the-end counterpart of `tcbegin`: `tcbegin(v) + v.Size()`.
 */
template <typename T, std::int32_t kDim>
auto tcend(TensorView<T, kDim> v) {  // NOLINT
  return tcbegin(v) + v.Size();
}
}  // namespace xgboost::linalg
#endif  // XGBOOST_COMMON_LINALG_OP_CUH_