/**
 * Copyright 2023, XGBoost Contributors
 */
#include "coll.h"

#include <algorithm>   // for min, max
#include <cstddef>     // for size_t
#include <cstdint>     // for int8_t, int64_t
#include <functional>  // for bit_and, bit_or, bit_xor, plus

#include "allgather.h"        // for RingAllgatherV, RingAllgather
#include "allreduce.h"        // for Allreduce
#include "broadcast.h"        // for Broadcast
#include "comm.h"             // for Comm
#include "xgboost/context.h"  // for Context

// NOTE(review): the Span element types below (std::int8_t / std::int64_t and their
// const qualifiers) must match the declarations in coll.h exactly -- confirm before merging.

namespace xgboost::collective {
/**
 * @brief Reduce `data` across workers with the element-wise operation selected by `op`.
 *
 * The Context and the element-type arguments are unnamed and are not consulted here;
 * the buffer is combined element by element exactly as it is passed in.
 *
 * @param comm Communicator forwarded to the underlying Allreduce.
 * @param data Buffer forwarded to the underlying Allreduce.
 * @param op   Reduction operation.
 *
 * @return The result of the underlying Allreduce for a handled `op`; `comm.Block()` when
 *         `op` matches none of the cases below.
 */
[[nodiscard]] Result Coll::Allreduce(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, ArrayInterfaceHandler::Type,
                                     Op op) {
  namespace coll = ::xgboost::collective;

  // Combine `lhs` into `out` element by element: out[i] = elem_op(lhs[i], out[i]).
  // The loop bound is lhs.size(); presumably `out` is at least that long -- verify against
  // the callback contract of coll::Allreduce.
  auto redop_fn = [](auto lhs, auto out, auto elem_op) {
    auto p_lhs = lhs.data();
    auto p_out = out.data();
    for (std::size_t i = 0; i < lhs.size(); ++i) {
      p_out[i] = elem_op(p_lhs[i], p_out[i]);
    }
  };
  // Bind an element-wise operation to the reduction loop and run the allreduce with it.
  auto fn = [&](auto elem_op) {
    return coll::Allreduce(
        comm, data, [redop_fn, elem_op](auto lhs, auto rhs) { redop_fn(lhs, rhs, elem_op); });
  };

  switch (op) {
    case Op::kMax: {
      return fn([](auto l, auto r) { return std::max(l, r); });
    }
    case Op::kMin: {
      return fn([](auto l, auto r) { return std::min(l, r); });
    }
    case Op::kSum: {
      return fn(std::plus<>{});
    }
    case Op::kBitwiseAND: {
      return fn(std::bit_and<>{});
    }
    case Op::kBitwiseOR: {
      return fn(std::bit_or<>{});
    }
    case Op::kBitwiseXOR: {
      return fn(std::bit_xor<>{});
    }
  }
  // Reached only when `op` holds a value not covered by the cases above.
  return comm.Block();
}

/**
 * @brief Broadcast `data` with `root` as the root rank; forwards to cpu_impl::Broadcast.
 *
 * The Context argument is unnamed and unused.
 */
[[nodiscard]] Result Coll::Broadcast(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::int32_t root) {
  return cpu_impl::Broadcast(comm, data, root);
}

/**
 * @brief Allgather on `data`; forwards `comm`, `data` and `size` to RingAllgather.
 *
 * The Context argument is unnamed and unused.
 */
[[nodiscard]] Result Coll::Allgather(Context const*, Comm const& comm,
                                     common::Span<std::int8_t> data, std::size_t size) {
  return RingAllgather(comm, data, size);
}

/**
 * @brief Variable-length allgather; forwards to cpu_impl::RingAllgatherV.
 *
 * Note the argument order of the callee: `sizes` is passed before `data`.
 * The Context argument is unnamed and unused.
 */
[[nodiscard]] Result Coll::AllgatherV(Context const*, Comm const& comm,
                                      common::Span<std::int8_t const> data,
                                      common::Span<std::int64_t const> sizes,
                                      common::Span<std::int64_t> recv_segments,
                                      common::Span<std::int8_t> recv) {
  return cpu_impl::RingAllgatherV(comm, sizes, data, recv_segments, recv);
}
}  // namespace xgboost::collective