[coll] allgather. (#9681)

This commit is contained in:
Jiaming Yuan
2023-10-18 10:22:18 +08:00
committed by GitHub
parent 48ac9b6cbe
commit 4c0e4422d0
2 changed files with 91 additions and 0 deletions

View File

@@ -20,4 +20,24 @@ namespace cpu_impl {
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);
} // namespace cpu_impl
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
auto n_total_bytes = data.size_bytes();
auto n_bytes = sizeof(T) * size;
auto erased =
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
auto rank = comm.Rank();
auto prev = BootstrapPrev(rank, comm.World());
auto next = BootstrapNext(rank, comm.World());
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
}
} // namespace xgboost::collective