/** * Copyright 2023, XGBoost Contributors */ #pragma once #include // for size_t #include // for int32_t #include // for shared_ptr #include // for accumulate #include // for remove_cv_t #include // for vector #include "../common/type.h" // for EraseType #include "comm.h" // for Comm, Channel #include "xgboost/collective/result.h" // for Result #include "xgboost/span.h" // for Span namespace xgboost::collective { namespace cpu_impl { /** * @param worker_off Segment offset. For example, if the rank 2 worker specifis worker_off * = 1, then it owns the third segment. */ [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t segment_size, std::int32_t worker_off, std::shared_ptr prev_ch, std::shared_ptr next_ch); [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span sizes, common::Span data, common::Span offset, common::Span erased_result); } // namespace cpu_impl template [[nodiscard]] Result RingAllgather(Comm const& comm, common::Span data, std::size_t size) { auto n_bytes = sizeof(T) * size; auto erased = common::EraseType(data); auto rank = comm.Rank(); auto prev = BootstrapPrev(rank, comm.World()); auto next = BootstrapNext(rank, comm.World()); auto prev_ch = comm.Chan(prev); auto next_ch = comm.Chan(next); auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch); if (!rc.OK()) { return rc; } return comm.Block(); } template [[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span data, std::vector>* p_out) { auto world = comm.World(); auto rank = comm.Rank(); std::vector sizes(world, 0); sizes[rank] = data.size_bytes(); auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1); if (!rc.OK()) { return rc; } std::vector& result = *p_out; auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0); result.resize(n_total_bytes / sizeof(T)); auto h_result = common::Span{result.data(), result.size()}; auto erased_result = common::EraseType(h_result); auto erased_data = common::EraseType(data); std::vector offset(world + 1); return cpu_impl::RingAllgatherV(comm, sizes, erased_data, common::Span{offset.data(), offset.size()}, erased_result); } } // namespace xgboost::collective