Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-09-12 15:21:12 -07:00
committed by GitHub
parent bc818316f2
commit a2686543a9
25 changed files with 1771 additions and 95 deletions

View File

@@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
auto const receive_buffer =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
static_cast<xgboost::federated::ReduceOperation>(op));
receive_buffer.copy(buffer, size);
}
@@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
}
private:
/**
 * @brief Map an mpi::DataType onto the matching xgboost::federated::DataType.
 *
 * Unrecognised values are reported through utils::Error; the trailing return of CHAR only
 * matters if that call comes back.
 */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
  namespace fed = xgboost::federated;
  switch (data_type) {
    case mpi::kChar:      return fed::CHAR;
    case mpi::kUChar:     return fed::UCHAR;
    case mpi::kInt:       return fed::INT;
    case mpi::kUInt:      return fed::UINT;
    case mpi::kLong:      return fed::LONG;
    case mpi::kULong:     return fed::ULONG;
    case mpi::kFloat:     return fed::FLOAT;
    case mpi::kDouble:    return fed::DOUBLE;
    case mpi::kLongLong:  return fed::LONGLONG;
    case mpi::kULongLong: return fed::ULONGLONG;
    default:
      break;  // fall out of the switch to the shared error path
  }
  utils::Error("unknown mpi::DataType");
  return fed::CHAR;
}
/**
 * @brief Map an mpi::OpType onto the matching xgboost::federated::ReduceOperation.
 *
 * Max, min and sum are translated directly. Bitwise OR and any unrecognised value are
 * reported through utils::Error; the MAX returned afterwards only matters if that call
 * comes back.
 */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
  namespace fed = xgboost::federated;
  if (op_type == mpi::kMax) {
    return fed::MAX;
  }
  if (op_type == mpi::kMin) {
    return fed::MIN;
  }
  if (op_type == mpi::kSum) {
    return fed::SUM;
  }
  if (op_type == mpi::kBitwiseOR) {
    utils::Error("Bitwise OR is not supported");
    return fed::MAX;
  }
  utils::Error("unknown mpi::OpType");
  return fed::MAX;
}
void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val;

View File

@@ -12,16 +12,14 @@ service Federated {
}
enum DataType {
CHAR = 0;
UCHAR = 1;
INT = 2;
UINT = 3;
LONG = 4;
ULONG = 5;
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
}
enum ReduceOperation {

View File

@@ -0,0 +1,192 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
namespace xgboost {
namespace collective {
/**
* @brief A Federated Learning communicator class that handles collective communication.
*/
class FederatedCommunicator : public Communicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static Communicator *Create(Json const &config) {
std::string server_address{};
int world_size{0};
int rank{-1};
std::string server_cert{};
std::string client_key{};
std::string client_cert{};
// Parse environment variables first.
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
if (value != nullptr) {
server_address = value;
}
value = getenv("FEDERATED_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = getenv("FEDERATED_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}
value = getenv("FEDERATED_SERVER_CERT");
if (value != nullptr) {
server_cert = value;
}
value = getenv("FEDERATED_CLIENT_KEY");
if (value != nullptr) {
client_key = value;
}
value = getenv("FEDERATED_CLIENT_CERT");
if (value != nullptr) {
client_cert = value;
}
// Runtime configuration overrides.
auto const &j_server_address = config["federated_server_address"];
if (IsA<String const>(j_server_address)) {
server_address = get<String const>(j_server_address);
}
auto const &j_world_size = config["federated_world_size"];
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
}
auto const &j_client_key = config["federated_client_key"];
if (IsA<String const>(j_client_key)) {
client_key = get<String const>(j_client_key);
}
auto const &j_client_cert = config["federated_client_cert"];
if (IsA<String const>(j_client_cert)) {
client_cert = get<String const>(j_client_cert);
}
if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";
}
if (world_size == 0) {
LOG(FATAL) << "Federated world size must be set.";
}
if (rank == -1) {
LOG(FATAL) << "Federated rank must be set.";
}
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
client_cert);
}
/**
* @brief Construct a new federated communicator.
*
* @param world_size Total number of processes.
* @param rank Rank of the current process.
* @param server_address Address of the federated server (host:port).
* @param server_cert_path Path to the server cert file.
* @param client_key_path Path to the client key file.
* @param client_cert_path Path to the client cert file.
*/
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
std::string const &server_cert_path, std::string const &client_key_path,
std::string const &client_cert_path)
: Communicator{world_size, rank} {
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
} else {
client_.reset(new xgboost::federated::FederatedClient(
server_address, rank, xgboost::common::ReadAll(server_cert_path),
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
}
}
/**
* @brief Construct an insecure federated communicator without using SSL.
* @param world_size Total number of processes.
* @param rank Rank of the current process.
* @param server_address Address of the federated server (host:port).
*/
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
: Communicator{world_size, rank} {
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
}
~FederatedCommunicator() override { client_.reset(); }
/**
* \brief Get if the communicator is distributed.
* \return True.
*/
bool IsDistributed() const override { return true; }
/**
* \brief Perform in-place allreduce.
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type.
* \param op Enumeration of operation type.
*/
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
count * GetTypeSize(data_type));
auto const received =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
static_cast<xgboost::federated::ReduceOperation>(op));
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
}
/**
* \brief Broadcast a memory region to all others from root.
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
if (GetWorldSize() == 1) return;
if (GetRank() == root) {
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
client_->Broadcast(send_buffer, root);
} else {
auto const received = client_->Broadcast("", root);
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
}
}
/**
* \brief Get the name of the processor.
* \return Name of the processor.
*/
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
/**
* \brief Print the message to the communicator.
* \param message The message to be printed.
*/
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
private:
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
};
} // namespace collective
} // namespace xgboost

View File

@@ -10,6 +10,8 @@
#include <fstream>
#include <sstream>
#include "../../src/common/io.h"
namespace xgboost {
namespace federated {
@@ -71,32 +73,35 @@ class AllreduceFunctor {
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
ReduceOperation reduce_operation) const {
switch (data_type) {
case DataType::CHAR:
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
case DataType::INT8:
Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::UCHAR:
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
case DataType::UINT8:
Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
reduce_operation);
break;
case DataType::INT:
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
buffer.size() / sizeof(int), reduce_operation);
case DataType::INT32:
Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
reinterpret_cast<std::int32_t const*>(input.data()),
buffer.size() / sizeof(std::uint32_t), reduce_operation);
break;
case DataType::UINT:
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
reinterpret_cast<unsigned int const*>(input.data()),
buffer.size() / sizeof(unsigned int), reduce_operation);
case DataType::UINT32:
Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
reinterpret_cast<std::uint32_t const*>(input.data()),
buffer.size() / sizeof(std::uint32_t), reduce_operation);
break;
case DataType::LONG:
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
buffer.size() / sizeof(long), reduce_operation);
case DataType::INT64:
Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
reinterpret_cast<std::int64_t const*>(input.data()),
buffer.size() / sizeof(std::int64_t), reduce_operation);
break;
case DataType::ULONG:
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
reinterpret_cast<unsigned long const*>(input.data()),
buffer.size() / sizeof(unsigned long), reduce_operation);
case DataType::UINT64:
Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
reinterpret_cast<std::uint64_t const*>(input.data()),
buffer.size() / sizeof(std::uint64_t), reduce_operation);
break;
case DataType::FLOAT:
Accumulate(reinterpret_cast<float*>(&buffer[0]),
@@ -108,16 +113,6 @@ class AllreduceFunctor {
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
reduce_operation);
break;
case DataType::LONGLONG:
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
reinterpret_cast<long long const*>(input.data()),
buffer.size() / sizeof(long long), reduce_operation);
break;
case DataType::ULONGLONG:
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
reinterpret_cast<unsigned long long const*>(input.data()),
buffer.size() / sizeof(unsigned long long), reduce_operation);
break;
default:
throw std::invalid_argument("Invalid data type");
}
@@ -201,13 +196,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
return grpc::Status::OK;
}
/**
 * Read the entire file at `path` into a string.
 *
 * A file that cannot be opened yields an empty string, the same as an empty file.
 */
std::string ReadFile(char const* path) {
  std::ifstream input(path);
  std::string contents;
  char chunk[4096];
  do {
    // A short (or failed) read still reports how many bytes arrived via gcount().
    input.read(chunk, sizeof(chunk));
    contents.append(chunk, static_cast<std::size_t>(input.gcount()));
  } while (input);
  return contents;
}
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
@@ -216,10 +204,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
grpc::ServerBuilder builder;
auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = ReadFile(client_cert_file);
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = ReadFile(server_key_file);
key.cert_chain = ReadFile(server_cert_file);
key.private_key = xgboost::common::ReadAll(server_key_file);
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));