[coll] allgather. (#9681)
This commit is contained in:
@@ -20,4 +20,24 @@ namespace cpu_impl {
|
||||
std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch);
|
||||
} // namespace cpu_impl
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
auto erased =
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
auto next = BootstrapNext(rank, comm.World());
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
auto rc = cpu_impl::RingAllgather(comm, erased, n_bytes, 0, prev_ch, next_ch);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user