From a8255ea678269045007fe3d38a665283f3d4d1cc Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 30 Nov 2022 08:24:12 -0800 Subject: [PATCH] Add an in-memory collective communicator (#8494) --- R-package/src/Makevars.in | 2 + R-package/src/Makevars.win | 2 + plugin/federated/federated.proto | 12 -- plugin/federated/federated_client.h | 19 -- plugin/federated/federated_communicator.h | 34 --- plugin/federated/federated_server.cc | 181 +--------------- plugin/federated/federated_server.h | 19 +- src/collective/communicator.cc | 5 + src/collective/communicator.h | 36 +++- src/collective/in_memory_communicator.cc | 12 ++ src/collective/in_memory_communicator.h | 91 ++++++++ src/collective/in_memory_handler.cc | 200 ++++++++++++++++++ src/collective/in_memory_handler.h | 106 ++++++++++ .../collective/test_in_memory_communicator.cc | 112 ++++++++++ tests/cpp/plugin/test_federated_server.cc | 23 +- 15 files changed, 577 insertions(+), 277 deletions(-) create mode 100644 src/collective/in_memory_communicator.cc create mode 100644 src/collective/in_memory_communicator.h create mode 100644 src/collective/in_memory_handler.cc create mode 100644 src/collective/in_memory_handler.h create mode 100644 tests/cpp/collective/test_in_memory_communicator.cc diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 54f3acaa5..5251bb2a5 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -72,6 +72,8 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ $(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/communicator.o \ + $(PKGROOT)/src/collective/in_memory_communicator.o \ + $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index c08153532..65000ef20 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -72,6 +72,8 @@ OBJECTS= \ $(PKGROOT)/src/logging.o \ 
$(PKGROOT)/src/global_config.o \ $(PKGROOT)/src/collective/communicator.o \ + $(PKGROOT)/src/collective/in_memory_communicator.o \ + $(PKGROOT)/src/collective/in_memory_handler.o \ $(PKGROOT)/src/collective/socket.o \ $(PKGROOT)/src/common/charconv.o \ $(PKGROOT)/src/common/column_matrix.o \ diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index 5a338ba0d..751861d9b 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -6,7 +6,6 @@ syntax = "proto3"; package xgboost.federated; service Federated { - rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} } @@ -28,17 +27,6 @@ enum ReduceOperation { SUM = 2; } -message AllgatherRequest { - // An incrementing counter that is unique to each round to operations. - uint64 sequence_number = 1; - int32 rank = 2; - bytes send_buffer = 3; -} - -message AllgatherReply { - bytes receive_buffer = 1; -} - message AllreduceRequest { // An incrementing counter that is unique to each round to operations. 
uint64 sequence_number = 1; diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index 2b4637339..3d0cdb729 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -46,25 +46,6 @@ class FederatedClient { }()}, rank_{rank} {} - std::string Allgather(std::string const &send_buffer) { - AllgatherRequest request; - request.set_sequence_number(sequence_number_++); - request.set_rank(rank_); - request.set_send_buffer(send_buffer); - - AllgatherReply reply; - grpc::ClientContext context; - context.set_wait_for_ready(true); - grpc::Status status = stub_->Allgather(&context, request, &reply); - - if (status.ok()) { - return reply.receive_buffer(); - } else { - std::cout << status.error_code() << ": " << status.error_message() << '\n'; - throw std::runtime_error("Allgather RPC failed"); - } - } - std::string Allreduce(std::string const &send_buffer, DataType data_type, ReduceOperation reduce_operation) { AllreduceRequest request; diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 2dc50962a..856ed6aac 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -12,40 +12,6 @@ namespace xgboost { namespace collective { -/** @brief Get the size of the data type. 
*/ -inline std::size_t GetTypeSize(DataType data_type) { - std::size_t size{0}; - switch (data_type) { - case DataType::kInt8: - size = sizeof(std::int8_t); - break; - case DataType::kUInt8: - size = sizeof(std::uint8_t); - break; - case DataType::kInt32: - size = sizeof(std::int32_t); - break; - case DataType::kUInt32: - size = sizeof(std::uint32_t); - break; - case DataType::kInt64: - size = sizeof(std::int64_t); - break; - case DataType::kUInt64: - size = sizeof(std::uint64_t); - break; - case DataType::kFloat: - size = sizeof(float); - break; - case DataType::kDouble: - size = sizeof(double); - break; - default: - LOG(FATAL) << "Unknown data type."; - } - return size; -} - /** * @brief A Federated Learning communicator class that handles collective communication. */ diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index 0738f776b..ec0b451a9 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -7,7 +7,6 @@ #include #include -#include #include #include "../../src/common/io.h" @@ -15,184 +14,20 @@ namespace xgboost { namespace federated { -class AllgatherFunctor { - public: - std::string const name{"Allgather"}; - - explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {} - - void operator()(AllgatherRequest const* request, std::string& buffer) const { - auto const rank = request->rank(); - auto const& send_buffer = request->send_buffer(); - auto const send_size = send_buffer.size(); - // Resize the buffer if this is the first request. - if (buffer.size() != send_size * world_size_) { - buffer.resize(send_size * world_size_); - } - // Splice the send_buffer into the common buffer. 
- buffer.replace(rank * send_size, send_size, send_buffer); - } - - private: - int const world_size_; -}; - -class AllreduceFunctor { - public: - std::string const name{"Allreduce"}; - - void operator()(AllreduceRequest const* request, std::string& buffer) const { - if (buffer.empty()) { - // Copy the send_buffer if this is the first request. - buffer = request->send_buffer(); - } else { - // Apply the reduce_operation to the send_buffer and the common buffer. - Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation()); - } - } - - private: - template - void Accumulate(T* buffer, T const* input, std::size_t n, - ReduceOperation reduce_operation) const { - switch (reduce_operation) { - case ReduceOperation::MAX: - std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); }); - break; - case ReduceOperation::MIN: - std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); }); - break; - case ReduceOperation::SUM: - std::transform(buffer, buffer + n, input, buffer, std::plus()); - break; - default: - throw std::invalid_argument("Invalid reduce operation"); - } - } - - void Accumulate(std::string& buffer, std::string const& input, DataType data_type, - ReduceOperation reduce_operation) const { - switch (data_type) { - case DataType::INT8: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size(), - reduce_operation); - break; - case DataType::UINT8: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size(), - reduce_operation); - break; - case DataType::INT32: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(std::uint32_t), reduce_operation); - break; - case DataType::UINT32: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(std::uint32_t), reduce_operation); - break; - case DataType::INT64: - 
Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(std::int64_t), reduce_operation); - break; - case DataType::UINT64: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), - buffer.size() / sizeof(std::uint64_t), reduce_operation); - break; - case DataType::FLOAT: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size() / sizeof(float), - reduce_operation); - break; - case DataType::DOUBLE: - Accumulate(reinterpret_cast(&buffer[0]), - reinterpret_cast(input.data()), buffer.size() / sizeof(double), - reduce_operation); - break; - default: - throw std::invalid_argument("Invalid data type"); - } - } -}; - -class BroadcastFunctor { - public: - std::string const name{"Broadcast"}; - - void operator()(BroadcastRequest const* request, std::string& buffer) const { - if (request->rank() == request->root()) { - // Copy the send_buffer if this is the root. - buffer = request->send_buffer(); - } - } -}; - -grpc::Status FederatedService::Allgather(grpc::ServerContext* context, - AllgatherRequest const* request, AllgatherReply* reply) { - return Handle(request, reply, AllgatherFunctor{world_size_}); -} - grpc::Status FederatedService::Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) { - return Handle(request, reply, AllreduceFunctor{}); + handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), + static_cast(request->data_type()), + static_cast(request->reduce_operation())); + return grpc::Status::OK; } grpc::Status FederatedService::Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, BroadcastReply* reply) { - return Handle(request, reply, BroadcastFunctor{}); -} - -template -grpc::Status FederatedService::Handle(Request const* request, Reply* reply, - RequestFunctor const& functor) { - // Pass through 
if there is only 1 client. - if (world_size_ == 1) { - reply->set_receive_buffer(request->send_buffer()); - return grpc::Status::OK; - } - - std::unique_lock lock(mutex_); - - auto const sequence_number = request->sequence_number(); - auto const rank = request->rank(); - - LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number"; - cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); - - LOG(INFO) << functor.name << " rank " << rank << ": handling request"; - functor(request, buffer_); - received_++; - - if (received_ == world_size_) { - LOG(INFO) << functor.name << " rank " << rank << ": all requests received"; - reply->set_receive_buffer(buffer_); - sent_++; - lock.unlock(); - cv_.notify_all(); - return grpc::Status::OK; - } - - LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients"; - cv_.wait(lock, [this] { return received_ == world_size_; }); - - LOG(INFO) << functor.name << " rank " << rank << ": sending reply"; - reply->set_receive_buffer(buffer_); - sent_++; - - if (sent_ == world_size_) { - LOG(INFO) << functor.name << " rank " << rank << ": all replies sent"; - sent_ = 0; - received_ = 0; - buffer_.clear(); - sequence_number_++; - lock.unlock(); - cv_.notify_all(); - } - + handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), + request->root()); return grpc::Status::OK; } diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 122499d0d..3a5abc4c9 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -5,18 +5,14 @@ #include -#include -#include +#include "../../src/collective/in_memory_handler.h" namespace xgboost { namespace federated { class FederatedService final : public Federated::Service { public: - explicit FederatedService(int const world_size) : world_size_{world_size} {} - - 
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, - AllgatherReply* reply) override; + explicit FederatedService(int const world_size) : handler_{world_size} {} grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, AllreduceReply* reply) override; @@ -25,16 +21,7 @@ class FederatedService final : public Federated::Service { BroadcastReply* reply) override; private: - template - grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor); - - int const world_size_; - int received_{}; - int sent_{}; - std::string buffer_{}; - uint64_t sequence_number_{}; - mutable std::mutex mutex_; - mutable std::condition_variable cv_; + xgboost::collective::InMemoryHandler handler_; }; void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, diff --git a/src/collective/communicator.cc b/src/collective/communicator.cc index 4b45f1e31..22c85f3ad 100644 --- a/src/collective/communicator.cc +++ b/src/collective/communicator.cc @@ -3,6 +3,7 @@ */ #include "communicator.h" +#include "in_memory_communicator.h" #include "noop_communicator.h" #include "rabit_communicator.h" @@ -40,6 +41,10 @@ void Communicator::Init(Json const& config) { #endif break; } + case CommunicatorType::kInMemory: { + communicator_.reset(InMemoryCommunicator::Create(config)); + break; + } case CommunicatorType::kUnknown: LOG(FATAL) << "Unknown communicator type."; } diff --git a/src/collective/communicator.h b/src/collective/communicator.h index ac9346c64..65da9320f 100644 --- a/src/collective/communicator.h +++ b/src/collective/communicator.h @@ -23,12 +23,46 @@ enum class DataType { kDouble = 7 }; +/** @brief Get the size of the data type. 
*/ +inline std::size_t GetTypeSize(DataType data_type) { + std::size_t size{0}; + switch (data_type) { + case DataType::kInt8: + size = sizeof(std::int8_t); + break; + case DataType::kUInt8: + size = sizeof(std::uint8_t); + break; + case DataType::kInt32: + size = sizeof(std::int32_t); + break; + case DataType::kUInt32: + size = sizeof(std::uint32_t); + break; + case DataType::kInt64: + size = sizeof(std::int64_t); + break; + case DataType::kUInt64: + size = sizeof(std::uint64_t); + break; + case DataType::kFloat: + size = sizeof(float); + break; + case DataType::kDouble: + size = sizeof(double); + break; + default: + LOG(FATAL) << "Unknown data type."; + } + return size; +} + /** @brief Defines the reduction operation. */ enum class Operation { kMax = 0, kMin = 1, kSum = 2 }; class DeviceCommunicator; -enum class CommunicatorType { kUnknown, kRabit, kFederated }; +enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory }; /** \brief Case-insensitive string comparison. */ inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) { diff --git a/src/collective/in_memory_communicator.cc b/src/collective/in_memory_communicator.cc new file mode 100644 index 000000000..535a15bc9 --- /dev/null +++ b/src/collective/in_memory_communicator.cc @@ -0,0 +1,12 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "in_memory_communicator.h" + +namespace xgboost { +namespace collective { + +InMemoryHandler InMemoryCommunicator::handler_{}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/in_memory_communicator.h b/src/collective/in_memory_communicator.h new file mode 100644 index 000000000..c1c5d4493 --- /dev/null +++ b/src/collective/in_memory_communicator.h @@ -0,0 +1,91 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#pragma once +#include + +#include + +#include "../c_api/c_api_utils.h" +#include "in_memory_handler.h" + +namespace xgboost { +namespace collective { + +/** + * An in-memory communicator, useful for testing. + */ +class InMemoryCommunicator : public Communicator { + public: + /** + * @brief Create a new communicator based on JSON configuration. + * @param config JSON configuration. + * @return Communicator as specified by the JSON configuration. + */ + static Communicator* Create(Json const& config) { + int world_size{0}; + int rank{-1}; + + // Parse environment variables first. + auto* value = getenv("IN_MEMORY_WORLD_SIZE"); + if (value != nullptr) { + world_size = std::stoi(value); + } + value = getenv("IN_MEMORY_RANK"); + if (value != nullptr) { + rank = std::stoi(value); + } + + // Runtime configuration overrides, optional as users can specify them as env vars. + world_size = static_cast(OptionalArg(config, "in_memory_world_size", + static_cast(world_size))); + rank = static_cast( + OptionalArg(config, "in_memory_rank", static_cast(rank))); + + if (world_size == 0) { + LOG(FATAL) << "Federated world size must be set."; + } + if (rank == -1) { + LOG(FATAL) << "Federated rank must be set."; + } + return new InMemoryCommunicator(world_size, rank); + } + + InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) { + handler_.Init(world_size, rank); + } + + ~InMemoryCommunicator() override { handler_.Shutdown(sequence_number_++, GetRank()); } + + bool IsDistributed() const override { return true; } + bool IsFederated() const override { return false; } + + void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override { + auto const bytes = size * GetTypeSize(data_type); + std::string output; + handler_.Allreduce(static_cast(in_out), bytes, &output, sequence_number_++, + GetRank(), data_type, operation); + output.copy(static_cast(in_out), bytes); + } + + void 
Broadcast(void* in_out, std::size_t size, int root) override { + std::string output; + handler_.Broadcast(static_cast(in_out), size, &output, sequence_number_++, + GetRank(), root); + output.copy(static_cast(in_out), size); + } + + std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); } + + void Print(const std::string& message) override { LOG(CONSOLE) << message; } + + protected: + void Shutdown() override {} + + private: + static InMemoryHandler handler_; + uint64_t sequence_number_{}; +}; + +} // namespace collective +} // namespace xgboost diff --git a/src/collective/in_memory_handler.cc b/src/collective/in_memory_handler.cc new file mode 100644 index 000000000..da425b708 --- /dev/null +++ b/src/collective/in_memory_handler.cc @@ -0,0 +1,200 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#include "in_memory_handler.h" + +#include +#include + +namespace xgboost { +namespace collective { + +/** + * @brief Functor for allreduce. + */ +class AllreduceFunctor { + public: + std::string const name{"Allreduce"}; + + AllreduceFunctor(DataType dataType, Operation operation) + : data_type_(dataType), operation_(operation) {} + + void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + if (buffer->empty()) { + // Copy the input if this is the first request. + buffer->assign(input, bytes); + } else { + // Apply the reduce_operation to the input and the buffer. 
+ Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front()); + } + } + + private: + template + void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const { + switch (reduce_operation) { + case Operation::kMax: + std::transform(buffer, buffer + size, input, buffer, + [](T a, T b) { return std::max(a, b); }); + break; + case Operation::kMin: + std::transform(buffer, buffer + size, input, buffer, + [](T a, T b) { return std::min(a, b); }); + break; + case Operation::kSum: + std::transform(buffer, buffer + size, input, buffer, std::plus()); + break; + default: + throw std::invalid_argument("Invalid reduce operation"); + } + } + + void Accumulate(char const* input, std::size_t size, char* buffer) const { + switch (data_type_) { + case DataType::kInt8: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kUInt8: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kInt32: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kUInt32: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kInt64: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kUInt64: + Accumulate(reinterpret_cast(buffer), + reinterpret_cast(input), size, operation_); + break; + case DataType::kFloat: + Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, + operation_); + break; + case DataType::kDouble: + Accumulate(reinterpret_cast(buffer), reinterpret_cast(input), size, + operation_); + break; + default: + throw std::invalid_argument("Invalid data type"); + } + } + + private: + DataType data_type_; + Operation operation_; +}; + +/** + * @brief Functor for broadcast. 
+ */ +class BroadcastFunctor { + public: + std::string const name{"Broadcast"}; + + BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {} + + void operator()(char const* input, std::size_t bytes, std::string* buffer) const { + if (rank_ == root_) { + // Copy the input if this is the root. + buffer->assign(input, bytes); + } + } + + private: + int rank_; + int root_; +}; + +void InMemoryHandler::Init(int world_size, int rank) { + CHECK(world_size_ < world_size) << "In memory handler already initialized."; + + std::unique_lock lock(mutex_); + world_size_++; + cv_.wait(lock, [this, world_size] { return world_size_ == world_size; }); + lock.unlock(); + cv_.notify_all(); +} + +void InMemoryHandler::Shutdown(uint64_t sequence_number, int rank) { + CHECK(world_size_ > 0) << "In memory handler already shutdown."; + + std::unique_lock lock(mutex_); + cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); + received_++; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + received_ = 0; + world_size_ = 0; + sequence_number_ = 0; + lock.unlock(); + cv_.notify_all(); +} + +void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank, DataType data_type, + Operation op) { + Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op}); +} + +void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank, int root) { + Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root}); +} + +template +void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank, HandlerFunctor const& functor) { + // Pass through if there is only 1 client. 
+ if (world_size_ == 1) { + if (input != output->data()) { + output->assign(input, bytes); + } + return; + } + + std::unique_lock lock(mutex_); + + LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number"; + cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; }); + + LOG(INFO) << functor.name << " rank " << rank << ": handling request"; + functor(input, bytes, &buffer_); + received_++; + + if (received_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all requests received"; + output->assign(buffer_); + sent_++; + lock.unlock(); + cv_.notify_all(); + return; + } + + LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients"; + cv_.wait(lock, [this] { return received_ == world_size_; }); + + LOG(INFO) << functor.name << " rank " << rank << ": sending reply"; + output->assign(buffer_); + sent_++; + + if (sent_ == world_size_) { + LOG(INFO) << functor.name << " rank " << rank << ": all replies sent"; + sent_ = 0; + received_ = 0; + buffer_.clear(); + sequence_number_++; + lock.unlock(); + cv_.notify_all(); + } +} +} // namespace collective +} // namespace xgboost diff --git a/src/collective/in_memory_handler.h b/src/collective/in_memory_handler.h new file mode 100644 index 000000000..3ab2d9a0b --- /dev/null +++ b/src/collective/in_memory_handler.h @@ -0,0 +1,106 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +#pragma once +#include +#include + +#include "communicator.h" + +namespace xgboost { +namespace collective { + +/** + * @brief Handles collective communication primitives in memory. + * + * This class is thread safe. + */ +class InMemoryHandler { + public: + /** + * @brief Default constructor. + * + * This is used when multiple objects/threads are accessing the same handler and need to + * initialize it collectively. + */ + InMemoryHandler() = default; + + /** + * @brief Construct a handler with the given world size. 
+ * @param worldSize Number of workers. + * + * This is used when the handler only needs to be initialized once with a known world size. + */ + explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {} + + /** + * @brief Initialize the handler with the world size and rank. + * @param world_size Number of workers. + * @param rank Index of the worker. + * + * This is used when multiple objects/threads are accessing the same handler and need to + * initialize it collectively. + */ + void Init(int world_size, int rank); + + /** + * @brief Shut down the handler. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + * + * This is used when multiple objects/threads are accessing the same handler and need to + * shut it down collectively. + */ + void Shutdown(uint64_t sequence_number, int rank); + + /** + * @brief Perform allreduce. + * @param input The input buffer. + * @param bytes Number of bytes in the input buffer. + * @param output The output buffer. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + * @param data_type Type of the data. + * @param op The reduce operation. + */ + void Allreduce(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank, DataType data_type, Operation op); + + /** + * @brief Perform broadcast. + * @param input The input buffer. + * @param bytes Number of bytes in the input buffer. + * @param output The output buffer. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + * @param root Index of the worker to broadcast from. + */ + void Broadcast(char const* input, std::size_t bytes, std::string* output, + std::size_t sequence_number, int rank, int root); + + private: + /** + * @brief Handle a collective communication primitive. + * @tparam HandlerFunctor The functor used to perform the specific primitive. + * @param input The input buffer. 
+ * @param size Number of bytes in the input buffer. + * @param output The output buffer. + * @param sequence_number Call sequence number. + * @param rank Index of the worker. + * @param functor The functor instance used to perform the specific primitive. + */ + template + void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number, + int rank, HandlerFunctor const& functor); + + int world_size_{}; /// Number of workers. + int received_{}; /// Number of calls received with the current sequence. + int sent_{}; /// Number of calls completed with the current sequence. + std::string buffer_{}; /// A shared common buffer. + uint64_t sequence_number_{}; /// Call sequence number. + mutable std::mutex mutex_; /// Lock. + mutable std::condition_variable cv_; /// Condition variable to wait on. +}; + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/collective/test_in_memory_communicator.cc b/tests/cpp/collective/test_in_memory_communicator.cc new file mode 100644 index 000000000..ef70e292e --- /dev/null +++ b/tests/cpp/collective/test_in_memory_communicator.cc @@ -0,0 +1,112 @@ +/*! 
+ * Copyright 2022 XGBoost contributors + */ +#include +#include + +#include + +#include "../../../src/collective/in_memory_communicator.h" + +namespace xgboost { +namespace collective { + +class InMemoryCommunicatorTest : public ::testing::Test { + public: + static void VerifyAllreduce(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + int buffer[] = {1, 2, 3, 4, 5}; + comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum); + int expected[] = {3, 6, 9, 12, 15}; + for (auto i = 0; i < 5; i++) { + EXPECT_EQ(buffer[i], expected[i]); + } + } + + static void VerifyBroadcast(int rank) { + InMemoryCommunicator comm{kWorldSize, rank}; + if (rank == 0) { + std::string buffer{"hello"}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } else { + std::string buffer{" "}; + comm.Broadcast(&buffer[0], buffer.size(), 0); + EXPECT_EQ(buffer, "hello"); + } + } + + protected: + static int const kWorldSize{3}; +}; + +TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) { + auto construct = []() { InMemoryCommunicator comm{0, 0}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall) { + auto construct = []() { InMemoryCommunicator comm{1, -1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig) { + auto construct = []() { InMemoryCommunicator comm{1, 1}; }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["in_memory_world_size"] = std::string("1"); + config["in_memory_rank"] = Integer(0); + auto *comm = InMemoryCommunicator::Create(config); + delete comm; + }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["in_memory_world_size"] = 1; 
+ config["in_memory_rank"] = std::string("0"); + auto *comm = InMemoryCommunicator::Create(config); + delete comm; + }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank) { + InMemoryCommunicator comm{1, 0}; + EXPECT_EQ(comm.GetWorldSize(), 1); + EXPECT_EQ(comm.GetRank(), 0); +} + +TEST(InMemoryCommunicatorSimpleTest, IsDistributed) { + InMemoryCommunicator comm{1, 0}; + EXPECT_TRUE(comm.IsDistributed()); +} + +TEST_F(InMemoryCommunicatorTest, Allreduce) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyAllreduce, rank)); + } + for (auto &thread : threads) { + thread.join(); + } +} + +TEST_F(InMemoryCommunicatorTest, Broadcast) { + std::vector threads; + for (auto rank = 0; rank < kWorldSize; rank++) { + threads.emplace_back(std::thread(&InMemoryCommunicatorTest::VerifyBroadcast, rank)); + } + for (auto &thread : threads) { + thread.join(); + } +} + +} // namespace collective +} // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_server.cc b/tests/cpp/plugin/test_federated_server.cc index 2e7afe5a2..61828975b 100644 --- a/tests/cpp/plugin/test_federated_server.cc +++ b/tests/cpp/plugin/test_federated_server.cc @@ -26,11 +26,6 @@ namespace xgboost { class FederatedServerTest : public ::testing::Test { public: - static void VerifyAllgather(int rank, const std::string& server_address) { - federated::FederatedClient client{server_address, rank}; - CheckAllgather(client, rank); - } - static void VerifyAllreduce(int rank, const std::string& server_address) { federated::FederatedClient client{server_address, rank}; CheckAllreduce(client); @@ -44,7 +39,6 @@ class FederatedServerTest : public ::testing::Test { static void VerifyMixture(int rank, const std::string& server_address) { federated::FederatedClient client{server_address, rank}; for (auto i = 0; i < 10; i++) { - CheckAllgather(client, rank); 
CheckAllreduce(client); CheckBroadcast(client, rank); } @@ -68,11 +62,6 @@ class FederatedServerTest : public ::testing::Test { server_thread_->join(); } - static void CheckAllgather(federated::FederatedClient& client, int rank) { - auto reply = client.Allgather("hello " + std::to_string(rank) + " "); - EXPECT_EQ(reply, "hello 0 hello 1 hello 2 "); - } - static void CheckAllreduce(federated::FederatedClient& client) { int data[] = {1, 2, 3, 4, 5}; std::string send_buffer(reinterpret_cast(data), sizeof(data)); @@ -90,7 +79,7 @@ class FederatedServerTest : public ::testing::Test { send_buffer = "hello broadcast"; } auto reply = client.Broadcast(send_buffer, 0); - EXPECT_EQ(reply, "hello broadcast"); + EXPECT_EQ(reply, "hello broadcast") << "rank " << rank; } static int const kWorldSize{3}; @@ -99,16 +88,6 @@ class FederatedServerTest : public ::testing::Test { std::unique_ptr server_; }; -TEST_F(FederatedServerTest, Allgather) { - std::vector threads; - for (auto rank = 0; rank < kWorldSize; rank++) { - threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_)); - } - for (auto& thread : threads) { - thread.join(); - } -} - TEST_F(FederatedServerTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) {