/** * Copyright 2023-2024, XGBoost Contributors */ #pragma once #include // for int8_t #include // for function #include // for is_invocable_v, enable_if_t #include // for vector #include "../common/type.h" // for EraseType, RestoreType #include "../data/array_interface.h" // for ToDType, ArrayInterfaceHandler #include "comm.h" // for Comm, RestoreType #include "comm_group.h" // for GlobalCommGroup #include "xgboost/collective/result.h" // for Result #include "xgboost/context.h" // for Context #include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { using Func = std::function lhs, common::Span out)>; Result RingAllreduce(Comm const& comm, common::Span data, Func const& op, ArrayInterfaceHandler::Type type); } // namespace cpu_impl template std::enable_if_t, common::Span>, Result> Allreduce( Comm const& comm, common::Span data, Fn redop) { auto erased = common::EraseType(data); auto type = ToDType::kType; auto erased_fn = [redop](common::Span lhs, common::Span out) { CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; auto lhs_t = common::RestoreType(lhs); auto rhs_t = common::RestoreType(out); redop(lhs_t, rhs_t); }; return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); } template [[nodiscard]] Result Allreduce(Context const* ctx, CommGroup const& comm, linalg::TensorView data, Op op) { if (!comm.IsDistributed()) { return Success(); } CHECK(data.Contiguous()); auto erased = common::EraseType(data.Values()); auto type = ToDType::kType; auto backend = comm.Backend(data.Device()); return backend->Allreduce(comm.Ctx(ctx, data.Device()), erased, type, op); } template [[nodiscard]] Result Allreduce(Context const* ctx, linalg::TensorView data, Op op) { return Allreduce(ctx, *GlobalCommGroup(), data, op); } /** * @brief Specialization for std::vector. */ template [[nodiscard]] Result Allreduce(Context const* ctx, std::vector* data, Op op) { return Allreduce(ctx, linalg::MakeVec(data->data(), data->size()), op); } /** * @brief Specialization for scalar value. */ template [[nodiscard]] std::enable_if_t && std::is_trivial_v, Result> Allreduce(Context const* ctx, T* data, Op op) { return Allreduce(ctx, linalg::MakeVec(data, 1), op); } } // namespace xgboost::collective