[coll] allgatherv. (#9688)
This commit is contained in:
@@ -3,13 +3,16 @@
|
||||
*/
|
||||
#include "allgather.h"
|
||||
|
||||
#include <algorithm> // for min
|
||||
#include <algorithm> // for min, copy_n
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int8_t
|
||||
#include <cstdint> // for int8_t, int32_t, int64_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for partial_sum
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective::cpu_impl {
|
||||
Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size_t segment_size,
|
||||
@@ -39,4 +42,47 @@ Result RingAllgather(Comm const& comm, common::Span<std::int8_t> data, std::size
|
||||
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int8_t> erased_result) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
auto next = BootstrapNext(rank, comm.World());
|
||||
|
||||
auto prev_ch = comm.Chan(prev);
|
||||
auto next_ch = comm.Chan(next);
|
||||
|
||||
// get worker offset
|
||||
std::vector<std::int64_t> offset(world + 1, 0);
|
||||
std::partial_sum(sizes.cbegin(), sizes.cend(), offset.begin() + 1);
|
||||
CHECK_EQ(*offset.cbegin(), 0);
|
||||
|
||||
// copy data
|
||||
auto current = erased_result.subspan(offset[rank], data.size_bytes());
|
||||
auto erased_data = EraseType(data);
|
||||
std::copy_n(erased_data.data(), erased_data.size(), current.data());
|
||||
|
||||
for (std::int32_t r = 0; r < world; ++r) {
|
||||
auto send_rank = (rank + world - r) % world;
|
||||
auto send_off = offset[send_rank];
|
||||
auto send_size = sizes[send_rank];
|
||||
auto send_seg = erased_result.subspan(send_off, send_size);
|
||||
next_ch->SendAll(send_seg);
|
||||
|
||||
auto recv_rank = (rank + world - r - 1) % world;
|
||||
auto recv_off = offset[recv_rank];
|
||||
auto recv_size = sizes[recv_rank];
|
||||
auto recv_seg = erased_result.subspan(recv_off, recv_size);
|
||||
prev_ch->RecvAll(recv_seg.data(), recv_seg.size_bytes());
|
||||
|
||||
auto rc = prev_ch->Block();
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
} // namespace xgboost::collective::cpu_impl
|
||||
|
||||
@@ -2,12 +2,16 @@
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <type_traits> // for remove_cv_t
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "comm.h" // for Comm, Channel
|
||||
#include "xgboost/span.h" // for Span
|
||||
#include "comm.h" // for Comm, Channel, EraseType
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace cpu_impl {
|
||||
@@ -19,14 +23,16 @@ namespace cpu_impl {
|
||||
std::size_t segment_size, std::int32_t worker_off,
|
||||
std::shared_ptr<Channel> prev_ch,
|
||||
std::shared_ptr<Channel> next_ch);
|
||||
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int8_t> erased_result);
|
||||
} // namespace cpu_impl
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgather(Comm const& comm, common::Span<T> data, std::size_t size) {
|
||||
auto n_total_bytes = data.size_bytes();
|
||||
auto n_bytes = sizeof(T) * size;
|
||||
auto erased =
|
||||
common::Span<std::int8_t>{reinterpret_cast<std::int8_t*>(data.data()), n_total_bytes};
|
||||
auto erased = EraseType(data);
|
||||
|
||||
auto rank = comm.Rank();
|
||||
auto prev = BootstrapPrev(rank, comm.World());
|
||||
@@ -40,4 +46,27 @@ template <typename T>
|
||||
}
|
||||
return comm.Block();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] Result RingAllgatherV(Comm const& comm, common::Span<T> data,
|
||||
std::vector<std::remove_cv_t<T>>* p_out) {
|
||||
auto world = comm.World();
|
||||
auto rank = comm.Rank();
|
||||
|
||||
std::vector<std::int64_t> sizes(world, 0);
|
||||
sizes[rank] = data.size_bytes();
|
||||
auto rc = RingAllgather(comm, common::Span{sizes.data(), sizes.size()}, 1);
|
||||
if (!rc.OK()) {
|
||||
return rc;
|
||||
}
|
||||
|
||||
std::vector<T>& result = *p_out;
|
||||
auto n_total_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), 0);
|
||||
result.resize(n_total_bytes / sizeof(T));
|
||||
auto h_result = common::Span{result.data(), result.size()};
|
||||
auto erased_result = EraseType(h_result);
|
||||
auto erased_data = EraseType(data);
|
||||
|
||||
return cpu_impl::RingAllgatherV(comm, sizes, erased_data, erased_result);
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "broadcast.h"
|
||||
|
||||
#include <cmath> // for ceil, log2
|
||||
#include <cstdint> // for int32_t, int8_t
|
||||
#include <utility> // for move
|
||||
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
|
||||
#include "allgather.h"
|
||||
#include "protocol.h" // for kMagic
|
||||
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
|
||||
#include "xgboost/collective/socket.h" // for TCPSocket
|
||||
#include "xgboost/json.h" // for Json, Object
|
||||
#include "xgboost/string_view.h" // for StringView
|
||||
|
||||
namespace xgboost::collective {
|
||||
Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds timeout,
|
||||
|
||||
@@ -2,20 +2,16 @@
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <chrono> // for seconds
|
||||
#include <condition_variable> // for condition_variable
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <mutex> // for mutex
|
||||
#include <queue> // for queue
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
#include <chrono> // for seconds
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <memory> // for shared_ptr
|
||||
#include <string> // for string
|
||||
#include <thread> // for thread
|
||||
#include <type_traits> // for remove_const_t
|
||||
#include <utility> // for move
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/timer.h"
|
||||
#include "loop.h" // for Loop
|
||||
#include "protocol.h" // for PeerInfo
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
Reference in New Issue
Block a user