Improve allgather functions (#9649)
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
@@ -55,10 +56,27 @@ class RabitCommunicator : public Communicator {
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
auto const per_rank = size / GetWorldSize();
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
|
||||
std::string result(total_size, '\0');
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
|
||||
Reference in New Issue
Block a user