/** * Copyright 2023, XGBoost Contributors */ #pragma once #include // for int8_t #include // for function #include // for is_invocable_v, enable_if_t #include "../common/type.h" // for EraseType, RestoreType #include "../data/array_interface.h" // for ArrayInterfaceHandler #include "comm.h" // for Comm, RestoreType #include "xgboost/collective/result.h" // for Result #include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { using Func = std::function lhs, common::Span out)>; Result RingAllreduce(Comm const& comm, common::Span data, Func const& op, ArrayInterfaceHandler::Type type); } // namespace cpu_impl template std::enable_if_t, common::Span>, Result> Allreduce( Comm const& comm, common::Span data, Fn redop) { auto erased = common::EraseType(data); auto type = ToDType::kType; auto erased_fn = [type, redop](common::Span lhs, common::Span out) { CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction."; auto lhs_t = common::RestoreType(lhs); auto rhs_t = common::RestoreType(out); redop(lhs_t, rhs_t); }; return cpu_impl::RingAllreduce(comm, erased, erased_fn, type); } } // namespace xgboost::collective