/*! * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_COMMON_LINALG_OP_H_ #define XGBOOST_COMMON_LINALG_OP_H_ #include // std::int32_t #include #include "common.h" #include "threading_utils.h" #include "xgboost/generic_parameters.h" #include "xgboost/linalg.h" namespace xgboost { namespace linalg { template void ElementWiseTransformHost(linalg::TensorView t, int32_t n_threads, Fn&& fn) { if (t.Contiguous()) { auto ptr = t.Values().data(); common::ParallelFor(t.Size(), n_threads, [&](size_t i) { ptr[i] = fn(i, ptr[i]); }); } else { common::ParallelFor(t.Size(), n_threads, [&](size_t i) { auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); v = fn(i, v); }); } } template void ElementWiseKernelHost(linalg::TensorView t, int32_t n_threads, Fn&& fn) { static_assert(std::is_void>::value, "For function with return, use transform instead."); if (t.Contiguous()) { auto ptr = t.Values().data(); common::ParallelFor(t.Size(), n_threads, [&](size_t i) { fn(i, ptr[i]); }); } else { common::ParallelFor(t.Size(), n_threads, [&](size_t i) { auto& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); fn(i, v); }); } } #if !defined(XGBOOST_USE_CUDA) template void ElementWiseKernelDevice(linalg::TensorView, Fn&&, void* = nullptr) { common::AssertGPUSupport(); } template void ElementWiseTransformDevice(linalg::TensorView, Fn&&, void* = nullptr) { common::AssertGPUSupport(); } template void ElementWiseKernel(GenericParameter const* ctx, linalg::TensorView t, Fn&& fn) { if (!ctx->IsCPU()) { common::AssertGPUSupport(); } ElementWiseKernelHost(t, ctx->Threads(), fn); } #endif // !defined(XGBOOST_USE_CUDA) template auto cbegin(TensorView v) { // NOLINT auto it = common::MakeIndexTransformIter([&](size_t i) -> std::remove_cv_t const& { return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape())); }); return it; } template auto cend(TensorView v) { // NOLINT return cbegin(v) + v.Size(); } template auto begin(TensorView v) { // NOLINT auto it = common::MakeIndexTransformIter( [&](size_t i) -> T& { return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape())); }); return it; } template auto end(TensorView v) { // NOLINT return begin(v) + v.Size(); } } // namespace linalg } // namespace xgboost #endif // XGBOOST_COMMON_LINALG_OP_H_