[coll] allgatherv. (#9688)

This commit is contained in:
Jiaming Yuan
2023-10-19 03:13:50 +08:00
committed by GitHub
parent ea9f09716b
commit 5d1bcde719
8 changed files with 157 additions and 35 deletions

View File

@@ -2,12 +2,16 @@
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <numeric> // for accumulate
#include <type_traits> // for remove_cv_t
#include <vector> // for vector
#include "comm.h" // for Comm, Channel
#include "xgboost/span.h" // for Span
#include "comm.h" // for Comm, Channel, EraseType
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
namespace cpu_impl {
@@ -19,14 +23,16 @@ namespace cpu_impl {
std::size_t segment_size, std::int32_t worker_off,
std::shared_ptr<Channel> prev_ch,
std::shared_ptr<Channel> next_ch);
/**
 * @brief Allgather on type-erased (std::int8_t) buffers where each worker may contribute a
 *        segment of a different size.
 *
 * Only the declaration is visible here; the notes below are derived from how the templated
 * RingAllgatherV wrapper in this header calls it.
 *
 * @param comm          Communicator used for the operation.
 * @param sizes         Per-worker segment sizes. The templated wrapper passes byte sizes
 *                      indexed by rank, gathered from all workers beforehand.
 * @param data          Type-erased view of the calling worker's own data.
 * @param erased_result Type-erased view of the output buffer. The templated wrapper sizes it
 *                      from the sum of `sizes`.
 *
 * @return Result of the operation; marked [[nodiscard]] so callers must inspect it.
 *
 * NOTE(review): presumably segments are laid out in rank order in `erased_result` -- confirm
 * against the definition.
 */
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
common::Span<std::int8_t const> data,
common::Span<std::int8_t> erased_result);
} // namespace cpu_impl
template <typename T>
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
auto n_total_bytes = data.size_bytes();
auto n_bytes = sizeof(T) * size;
auto erased =
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
auto erased = EraseType(data);
auto rank = comm.Rank();
auto prev = BootstrapPrev(rank, comm.World());
@@ -40,4 +46,27 @@ template <typename T>
}
return comm.Block();
}
/**
 * @brief Gather variable-length data from every worker into a single vector.
 *
 * The byte size of the local data is first exchanged with RingAllgather so that every worker
 * knows the size of each segment. The output vector is then resized to hold the sum of all
 * segments and the type-erased buffers are forwarded to cpu_impl::RingAllgatherV.
 *
 * @tparam T Element type of the input span; may be const-qualified, the output always stores
 *           std::remove_cv_t<T>.
 *
 * @param comm  Communicator providing the rank of this worker and the world size.
 * @param data  Local data contributed by this worker.
 * @param p_out Output vector, must not be null. It is resized to
 *              (total gathered bytes / sizeof(T)) elements.
 *
 * @return The failing result of the size exchange if it did not succeed, otherwise the
 *         result of cpu_impl::RingAllgatherV.
 */
template <typename T>
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<T> data,
                                    std::vector<std::remove_cv_t<T>>* p_out) {
  auto world = comm.World();
  auto rank = comm.Rank();

  // Every worker fills in its own slot, the allgather populates the remaining ones.
  std::vector<std::int64_t> sizes(world, 0);
  sizes[rank] = data.size_bytes();
  auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
  if (!rc.OK()) {
    return rc;
  }

  // Bind with the cv-stripped element type so that the function also compiles when T is
  // const-qualified (std::vector<T const> is a different type from *p_out).
  std::vector<std::remove_cv_t<T>>& result = *p_out;
  // The init value determines the accumulator type. A plain `0` would sum the 64-bit sizes
  // in `int` and overflow once the total exceeds INT_MAX bytes.
  auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), std::int64_t{0});
  result.resize(static_cast<std::size_t>(n_total_bytes) / sizeof(T));

  auto h_result = common::Span{result.data(), result.size()};
  auto erased_result = EraseType(h_result);
  auto erased_data = EraseType(data);

  return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
}
} // namespace xgboost::collective