Add an in-memory collective communicator (#8494)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "in_memory_communicator.h"
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
@@ -40,6 +41,10 @@ void Communicator::Init(Json const& config) {
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case CommunicatorType::kInMemory: {
|
||||
communicator_.reset(InMemoryCommunicator::Create(config));
|
||||
break;
|
||||
}
|
||||
case CommunicatorType::kUnknown:
|
||||
LOG(FATAL) << "Unknown communicator type.";
|
||||
}
|
||||
|
||||
@@ -23,12 +23,46 @@ enum class DataType {
|
||||
kDouble = 7
|
||||
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/** @brief Reduction operations supported by allreduce. */
enum class Operation {
  kMax = 0,
  kMin = 1,
  kSum = 2
};
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
/**
 * @brief Kinds of communicator that can be selected.
 *
 * Defined exactly once; `kInMemory` is appended after the existing enumerators so their
 * underlying values are unchanged.
 */
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory };
|
||||
|
||||
/** \brief Case-insensitive string comparison. */
|
||||
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
|
||||
|
||||
12
src/collective/in_memory_communicator.cc
Normal file
12
src/collective/in_memory_communicator.cc
Normal file
@@ -0,0 +1,12 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
// Out-of-class definition of the static handler member declared in InMemoryCommunicator.
InMemoryHandler InMemoryCommunicator::handler_{};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
91
src/collective/in_memory_communicator.h
Normal file
91
src/collective/in_memory_communicator.h
Normal file
@@ -0,0 +1,91 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "../c_api/c_api_utils.h"
|
||||
#include "in_memory_handler.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* An in-memory communicator, useful for testing.
|
||||
*/
|
||||
class InMemoryCommunicator : public Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator* Create(Json const& config) {
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
|
||||
// Parse environment variables first.
|
||||
auto* value = getenv("IN_MEMORY_WORLD_SIZE");
|
||||
if (value != nullptr) {
|
||||
world_size = std::stoi(value);
|
||||
}
|
||||
value = getenv("IN_MEMORY_RANK");
|
||||
if (value != nullptr) {
|
||||
rank = std::stoi(value);
|
||||
}
|
||||
|
||||
// Runtime configuration overrides, optional as users can specify them as env vars.
|
||||
world_size = static_cast<int>(OptionalArg<Integer>(config, "in_memory_world_size",
|
||||
static_cast<Integer::Int>(world_size)));
|
||||
rank = static_cast<int>(
|
||||
OptionalArg<Integer>(config, "in_memory_rank", static_cast<Integer::Int>(rank)));
|
||||
|
||||
if (world_size == 0) {
|
||||
LOG(FATAL) << "Federated world size must be set.";
|
||||
}
|
||||
if (rank == -1) {
|
||||
LOG(FATAL) << "Federated rank must be set.";
|
||||
}
|
||||
return new InMemoryCommunicator(world_size, rank);
|
||||
}
|
||||
|
||||
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
|
||||
handler_.Init(world_size, rank);
|
||||
}
|
||||
|
||||
~InMemoryCommunicator() override { handler_.Shutdown(sequence_number_++, GetRank()); }
|
||||
|
||||
bool IsDistributed() const override { return true; }
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
|
||||
auto const bytes = size * GetTypeSize(data_type);
|
||||
std::string output;
|
||||
handler_.Allreduce(static_cast<const char*>(in_out), bytes, &output, sequence_number_++,
|
||||
GetRank(), data_type, operation);
|
||||
output.copy(static_cast<char*>(in_out), bytes);
|
||||
}
|
||||
|
||||
void Broadcast(void* in_out, std::size_t size, int root) override {
|
||||
std::string output;
|
||||
handler_.Broadcast(static_cast<const char*>(in_out), size, &output, sequence_number_++,
|
||||
GetRank(), root);
|
||||
output.copy(static_cast<char*>(in_out), size);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
|
||||
|
||||
void Print(const std::string& message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
|
||||
private:
|
||||
static InMemoryHandler handler_;
|
||||
uint64_t sequence_number_{};
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
200
src/collective/in_memory_handler.cc
Normal file
200
src/collective/in_memory_handler.cc
Normal file
@@ -0,0 +1,200 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "in_memory_handler.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Functor for allreduce.
|
||||
*/
|
||||
class AllreduceFunctor {
|
||||
public:
|
||||
std::string const name{"Allreduce"};
|
||||
|
||||
AllreduceFunctor(DataType dataType, Operation operation)
|
||||
: data_type_(dataType), operation_(operation) {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (buffer->empty()) {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
// Apply the reduce_operation to the input and the buffer.
|
||||
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <class T>
|
||||
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
|
||||
switch (reduce_operation) {
|
||||
case Operation::kMax:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::max(a, b); });
|
||||
break;
|
||||
case Operation::kMin:
|
||||
std::transform(buffer, buffer + size, input, buffer,
|
||||
[](T a, T b) { return std::min(a, b); });
|
||||
break;
|
||||
case Operation::kSum:
|
||||
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid reduce operation");
|
||||
}
|
||||
}
|
||||
|
||||
void Accumulate(char const* input, std::size_t size, char* buffer) const {
|
||||
switch (data_type_) {
|
||||
case DataType::kInt8:
|
||||
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
|
||||
reinterpret_cast<std::int8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
|
||||
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
|
||||
reinterpret_cast<std::int32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
|
||||
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
|
||||
reinterpret_cast<std::int64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
|
||||
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
|
||||
operation_);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid data type");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
Operation operation_;
|
||||
};
|
||||
|
||||
/**
 * @brief Functor for broadcast.
 *
 * Only the call made on behalf of the root worker stores its input in the shared buffer; calls
 * for any other worker leave the buffer untouched.
 */
class BroadcastFunctor {
 public:
  std::string const name{"Broadcast"};

  /**
   * @param rank Index of the worker issuing the call.
   * @param root Index of the worker to broadcast from.
   */
  BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}

  /**
   * @brief Store the input in the shared buffer when the caller is the root.
   * @param input Raw bytes supplied by the caller.
   * @param bytes Number of bytes in `input`.
   * @param buffer Shared buffer that receives the root's data.
   */
  void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
    bool const caller_is_root = (rank_ == root_);
    if (!caller_is_root) {
      return;
    }
    // Copy the input if this is the root.
    buffer->assign(input, bytes);
  }

 private:
  int rank_;
  int root_;
};
|
||||
|
||||
void InMemoryHandler::Init(int world_size, int rank) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
world_size_++;
|
||||
cv_.wait(lock, [this, world_size] { return world_size_ == world_size; });
|
||||
lock.unlock();
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, int rank) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
||||
received_++;
|
||||
cv_.wait(lock, [this] { return received_ == world_size_; });
|
||||
|
||||
received_ = 0;
|
||||
world_size_ = 0;
|
||||
sequence_number_ = 0;
|
||||
lock.unlock();
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type,
|
||||
Operation op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
if (input != output->data()) {
|
||||
output->assign(input, bytes);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
|
||||
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
||||
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
|
||||
functor(input, bytes, &buffer_);
|
||||
received_++;
|
||||
|
||||
if (received_ == world_size_) {
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
|
||||
output->assign(buffer_);
|
||||
sent_++;
|
||||
lock.unlock();
|
||||
cv_.notify_all();
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
|
||||
cv_.wait(lock, [this] { return received_ == world_size_; });
|
||||
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
|
||||
output->assign(buffer_);
|
||||
sent_++;
|
||||
|
||||
if (sent_ == world_size_) {
|
||||
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
|
||||
sent_ = 0;
|
||||
received_ = 0;
|
||||
buffer_.clear();
|
||||
sequence_number_++;
|
||||
lock.unlock();
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
106
src/collective/in_memory_handler.h
Normal file
106
src/collective/in_memory_handler.h
Normal file
@@ -0,0 +1,106 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Handles collective communication primitives in memory.
|
||||
*
|
||||
* This class is thread safe.
|
||||
*/
|
||||
class InMemoryHandler {
|
||||
public:
|
||||
/**
|
||||
* @brief Default constructor.
|
||||
*
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
InMemoryHandler() = default;
|
||||
|
||||
/**
|
||||
* @brief Construct a handler with the given world size.
|
||||
* @param world_size Number of workers.
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
* @param world_size Number of workers.
|
||||
* @param rank Index of the worker.
|
||||
*
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(int world_size, int rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
*
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, int rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
* @param input The input buffer.
|
||||
* @param bytes Number of bytes in the input buffer.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
* @param data_type Type of the data.
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type, Operation op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
* @param input The input buffer.
|
||||
* @param bytes Number of bytes in the input buffer.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root);
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Handle a collective communication primitive.
|
||||
* @tparam HandlerFunctor The functor used to perform the specific primitive.
|
||||
* @param input The input buffer.
|
||||
* @param size Size of the input in terms of the data type.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
* @param functor The functor instance used to perform the specific primitive.
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
int rank, HandlerFunctor const& functor);
|
||||
|
||||
int world_size_{}; /// Number of workers.
|
||||
int received_{}; /// Number of calls received with the current sequence.
|
||||
int sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user