Add Allgather to collective communicator (#8765)
* Add Allgather to collective communicator
This commit is contained in:
@@ -122,6 +122,17 @@ class Communicator {
|
||||
/** @brief Whether the communicator is running in federated mode. */
|
||||
virtual bool IsFederated() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size, and input data has been sliced into the
|
||||
* corresponding position.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
|
||||
@@ -60,6 +60,13 @@ class InMemoryCommunicator : public Communicator {
|
||||
bool IsDistributed() const override { return true; }
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void* in_out, std::size_t size) override {
|
||||
std::string output;
|
||||
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
|
||||
GetRank());
|
||||
output.copy(static_cast<char*>(in_out), size);
|
||||
}
|
||||
|
||||
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
|
||||
auto const bytes = size * GetTypeSize(data_type);
|
||||
std::string output;
|
||||
|
||||
@@ -9,6 +9,32 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Functor for allgather.
|
||||
*/
|
||||
class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (buffer->empty()) {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
// Splice the input into the common buffer.
|
||||
auto const per_rank = bytes / world_size_;
|
||||
auto const index = rank_ * per_rank;
|
||||
buffer->replace(index, per_rank, input + index, per_rank);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
int world_size_;
|
||||
int rank_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Functor for allreduce.
|
||||
*/
|
||||
@@ -17,7 +43,7 @@ class AllreduceFunctor {
|
||||
std::string const name{"Allreduce"};
|
||||
|
||||
AllreduceFunctor(DataType dataType, Operation operation)
|
||||
: data_type_(dataType), operation_(operation) {}
|
||||
: data_type_{dataType}, operation_{operation} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (buffer->empty()) {
|
||||
@@ -128,7 +154,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {}
|
||||
BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@@ -167,6 +193,11 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type,
|
||||
Operation op) {
|
||||
|
||||
@@ -53,6 +53,17 @@ class InMemoryHandler {
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, int rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
* @param input The input buffer.
|
||||
* @param bytes Number of bytes in the input buffer.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
* @param input The input buffer.
|
||||
|
||||
@@ -17,6 +17,7 @@ class NoOpCommunicator : public Communicator {
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
void AllGather(void *, std::size_t) override {}
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return ""; }
|
||||
|
||||
@@ -55,6 +55,12 @@ class RabitCommunicator : public Communicator {
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
auto const per_rank = size / GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
|
||||
Reference in New Issue
Block a user