Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou 2022-09-12 15:21:12 -07:00 committed by GitHub
parent bc818316f2
commit a2686543a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1771 additions and 95 deletions

View File

@ -71,6 +71,9 @@
#include "../src/logging.cc" #include "../src/logging.cc"
#include "../src/global_config.cc" #include "../src/global_config.cc"
// collective
#include "../src/collective/communicator.cc"
// common // common
#include "../src/common/common.cc" #include "../src/common/common.cc"
#include "../src/common/column_matrix.cc" #include "../src/common/column_matrix.cc"

View File

@ -9,10 +9,12 @@
#ifdef __cplusplus #ifdef __cplusplus
#define XGB_EXTERN_C extern "C" #define XGB_EXTERN_C extern "C"
#include <cstddef>
#include <cstdio> #include <cstdio>
#include <cstdint> #include <cstdint>
#else #else
#define XGB_EXTERN_C #define XGB_EXTERN_C
#include <stddef.h>
#include <stdio.h> #include <stdio.h>
#include <stdint.h> #include <stdint.h>
#endif // __cplusplus #endif // __cplusplus
@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
bst_ulong *out_dim, bst_ulong *out_dim,
bst_ulong const **out_shape, bst_ulong const **out_shape,
float const **out_scores); float const **out_scores);
/*!
 * \brief Initialize the collective communicator.
 *
 *  Currently the communicator API is experimental, function signatures may change in the future
 *  without notice.
 *
 *  Call this once before using anything.
 *
 *  The additional configuration is not required. Usually the communicator will detect settings
 *  from environment variables.
 *
 * \param json_config JSON encoded configuration. Accepted JSON keys are:
 *   - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
 *     * rabit: Use Rabit. This is the default if the type is unspecified.
 *     * federated: Use the gRPC interface for Federated Learning.
 * Only applicable to the Rabit communicator (these are case-sensitive):
 *   - rabit_tracker_uri: Hostname of the tracker.
 *   - rabit_tracker_port: Port number of the tracker.
 *   - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
 *   - rabit_world_size: Total number of workers.
 *   - rabit_hadoop_mode: Enable Hadoop support.
 *   - rabit_tree_reduce_minsize: Minimal size for tree reduce.
 *   - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
 *   - rabit_reduce_buffer: Size of the reduce buffer.
 *   - rabit_bootstrap_cache: Size of the bootstrap cache.
 *   - rabit_debug: Enable debugging.
 *   - rabit_timeout: Enable timeout.
 *   - rabit_timeout_sec: Timeout in seconds.
 *   - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
 * Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
 * environment variables):
 *   - DMLC_TRACKER_URI: Hostname of the tracker.
 *   - DMLC_TRACKER_PORT: Port number of the tracker.
 *   - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
 *   - DMLC_ROLE: Role of the current task, "worker" or "server".
 *   - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
 *   - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
 * Only applicable to the Federated communicator (use upper case for environment variables, use
 * lower case for runtime configuration):
 *   - federated_server_address: Address of the federated server.
 *   - federated_world_size: Number of federated workers.
 *   - federated_rank: Rank of the current worker.
 *   - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
 *   - federated_client_key: Client key file path. Only needed for the SSL mode.
 *   - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
 * \return 0 for success, -1 for failure.
 */
XGB_DLL int XGCommunicatorInit(char const* json_config);
/*!
* \brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorFinalize(void);
/*!
 * \brief Get rank of current process.
 *
 *  Unlike most functions in this API, the return value is the rank itself and not a
 *  success/failure status code.
 *
 * \return Rank of the worker.
 */
XGB_DLL int XGCommunicatorGetRank(void);
/*!
 * \brief Get total number of processes.
 *
 *  Unlike most functions in this API, the return value is the world size itself and not a
 *  success/failure status code.
 *
 * \return Total world size.
 */
XGB_DLL int XGCommunicatorGetWorldSize(void);
/*!
 * \brief Get if the communicator is distributed.
 *
 *  The return value is the answer itself and not a success/failure status code.
 *
 * \return Non-zero if the communicator is distributed, 0 otherwise.
 */
XGB_DLL int XGCommunicatorIsDistributed(void);
/*!
* \brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
* \return 0 for success, -1 for failure.
*/
XGB_DLL int XGCommunicatorPrint(char const *message);
/*!
 * \brief Get the name of the processor.
 *
 * \param name_str Output pointer that receives the processor name as a C string. The string is
 *                 owned by XGBoost and must not be freed by the caller; copy it if it needs to
 *                 be kept.
 * \return 0 for success, -1 for failure.
 */
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
/*!
 * \brief Broadcast a memory region to all others from root.  This function is NOT thread-safe.
 *
 * Example:
 *   int a = 1;
 *   Broadcast(&a, sizeof(a), root);
 *
 * \param send_receive_buffer Pointer to the send or receive buffer.
 * \param size Size of the data in bytes.
 * \param root The process rank to broadcast from.
 * \return 0 for success, -1 for failure.
 */
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
/*!
 * \brief Perform in-place allreduce. This function is NOT thread-safe.
 *
 * Example Usage: the following code gives sum of the result
 *   vector<int> data(10);
 *   ...
 *   Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
 *   ...
 * \param send_receive_buffer Buffer for both sending and receiving data.
 * \param count Number of elements to be reduced (elements, not bytes).
 * \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
 * \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
 * \return 0 for success, -1 for failure.
 */
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
#endif // XGBOOST_C_API_H_ #endif // XGBOOST_C_API_H_

View File

@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) { void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
auto *buffer = reinterpret_cast<char *>(sendrecvbuf); auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
std::string const send_buffer(buffer, size); std::string const send_buffer(buffer, size);
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op)); auto const receive_buffer =
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
static_cast<xgboost::federated::ReduceOperation>(op));
receive_buffer.copy(buffer, size); receive_buffer.copy(buffer, size);
} }
@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
} }
private: private:
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
switch (data_type) {
case mpi::kChar:
return xgboost::federated::CHAR;
case mpi::kUChar:
return xgboost::federated::UCHAR;
case mpi::kInt:
return xgboost::federated::INT;
case mpi::kUInt:
return xgboost::federated::UINT;
case mpi::kLong:
return xgboost::federated::LONG;
case mpi::kULong:
return xgboost::federated::ULONG;
case mpi::kFloat:
return xgboost::federated::FLOAT;
case mpi::kDouble:
return xgboost::federated::DOUBLE;
case mpi::kLongLong:
return xgboost::federated::LONGLONG;
case mpi::kULongLong:
return xgboost::federated::ULONGLONG;
}
utils::Error("unknown mpi::DataType");
return xgboost::federated::CHAR;
}
/** @brief Transform mpi::OpType to enum to MPI OP */
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
switch (op_type) {
case mpi::kMax:
return xgboost::federated::MAX;
case mpi::kMin:
return xgboost::federated::MIN;
case mpi::kSum:
return xgboost::federated::SUM;
case mpi::kBitwiseOR:
utils::Error("Bitwise OR is not supported");
return xgboost::federated::MAX;
}
utils::Error("unknown mpi::OpType");
return xgboost::federated::MAX;
}
void SetParam(std::string const &name, std::string const &val) { void SetParam(std::string const &name, std::string const &val) {
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) { if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
server_address_ = val; server_address_ = val;

View File

@ -12,16 +12,14 @@ service Federated {
} }
enum DataType { enum DataType {
CHAR = 0; INT8 = 0;
UCHAR = 1; UINT8 = 1;
INT = 2; INT32 = 2;
UINT = 3; UINT32 = 3;
LONG = 4; INT64 = 4;
ULONG = 5; UINT64 = 5;
FLOAT = 6; FLOAT = 6;
DOUBLE = 7; DOUBLE = 7;
LONGLONG = 8;
ULONGLONG = 9;
} }
enum ReduceOperation { enum ReduceOperation {

View File

@ -0,0 +1,192 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
namespace xgboost {
namespace collective {
/**
 * @brief Communicator for Federated Learning; every collective operation is delegated to a
 * FederatedClient.
 */
class FederatedCommunicator : public Communicator {
 public:
  /**
   * @brief Build a federated communicator from environment variables and JSON configuration.
   *
   * The upper-case FEDERATED_* environment variables are consulted first. A lower-case key in the
   * JSON configuration that holds a value of the expected type then takes precedence. A fatal
   * error is logged when the server address, the world size or the rank is still unset.
   *
   * @param config JSON configuration.
   * @return Newly allocated communicator described by the environment and the configuration.
   */
  static Communicator *Create(Json const &config) {
    std::string address{};
    int n_workers{0};
    int worker_rank{-1};
    std::string server_cert_file{};
    std::string client_key_file{};
    std::string client_cert_file{};

    // Step 1: pick up whatever the environment provides.
    if (auto const *env = getenv("FEDERATED_SERVER_ADDRESS")) {
      address = env;
    }
    if (auto const *env = getenv("FEDERATED_WORLD_SIZE")) {
      n_workers = std::stoi(env);
    }
    if (auto const *env = getenv("FEDERATED_RANK")) {
      worker_rank = std::stoi(env);
    }
    if (auto const *env = getenv("FEDERATED_SERVER_CERT")) {
      server_cert_file = env;
    }
    if (auto const *env = getenv("FEDERATED_CLIENT_KEY")) {
      client_key_file = env;
    }
    if (auto const *env = getenv("FEDERATED_CLIENT_CERT")) {
      client_cert_file = env;
    }

    // Step 2: values given in the JSON configuration replace the ones from the environment.
    auto const &cfg_address = config["federated_server_address"];
    if (IsA<String const>(cfg_address)) {
      address = get<String const>(cfg_address);
    }
    auto const &cfg_world_size = config["federated_world_size"];
    if (IsA<Integer const>(cfg_world_size)) {
      n_workers = static_cast<int>(get<Integer const>(cfg_world_size));
    }
    auto const &cfg_rank = config["federated_rank"];
    if (IsA<Integer const>(cfg_rank)) {
      worker_rank = static_cast<int>(get<Integer const>(cfg_rank));
    }
    auto const &cfg_server_cert = config["federated_server_cert"];
    if (IsA<String const>(cfg_server_cert)) {
      server_cert_file = get<String const>(cfg_server_cert);
    }
    auto const &cfg_client_key = config["federated_client_key"];
    if (IsA<String const>(cfg_client_key)) {
      client_key_file = get<String const>(cfg_client_key);
    }
    auto const &cfg_client_cert = config["federated_client_cert"];
    if (IsA<String const>(cfg_client_cert)) {
      client_cert_file = get<String const>(cfg_client_cert);
    }

    // Step 3: the three mandatory settings must have been supplied by one of the two sources.
    if (address.empty()) {
      LOG(FATAL) << "Federated server address must be set.";
    }
    if (n_workers == 0) {
      LOG(FATAL) << "Federated world size must be set.";
    }
    if (worker_rank == -1) {
      LOG(FATAL) << "Federated rank must be set.";
    }

    return new FederatedCommunicator(n_workers, worker_rank, address, server_cert_file,
                                     client_key_file, client_cert_file);
  }

  /**
   * @brief Construct a new federated communicator.
   *
   * The SSL client is only used when all three certificate/key paths are non-empty; otherwise the
   * client is created from the server address and rank alone.
   *
   * @param world_size       Total number of processes.
   * @param rank             Rank of the current process.
   * @param server_address   Address of the federated server (host:port).
   * @param server_cert_path Path to the server cert file.
   * @param client_key_path  Path to the client key file.
   * @param client_cert_path Path to the client cert file.
   */
  FederatedCommunicator(int world_size, int rank, std::string const &server_address,
                        std::string const &server_cert_path, std::string const &client_key_path,
                        std::string const &client_cert_path)
      : Communicator{world_size, rank} {
    bool const has_all_credentials =
        !server_cert_path.empty() && !client_key_path.empty() && !client_cert_path.empty();
    if (has_all_credentials) {
      client_.reset(new xgboost::federated::FederatedClient(
          server_address, rank, xgboost::common::ReadAll(server_cert_path),
          xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
    } else {
      client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
    }
  }

  /**
   * @brief Construct an insecure federated communicator without using SSL.
   *
   * @param world_size     Total number of processes.
   * @param rank           Rank of the current process.
   * @param server_address Address of the federated server (host:port).
   */
  FederatedCommunicator(int world_size, int rank, std::string const &server_address)
      : Communicator{world_size, rank} {
    client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
  }

  /** @brief Release the client when the communicator is destroyed. */
  ~FederatedCommunicator() override { client_.reset(); }

  /**
   * \brief Get if the communicator is distributed.
   * \return Always true for the federated communicator.
   */
  bool IsDistributed() const override { return true; }

  /**
   * \brief Perform in-place allreduce.
   *
   * The buffer is sent as raw bytes; the reply is copied back over the same bytes.
   *
   * \param send_receive_buffer Buffer for both sending and receiving data.
   * \param count Number of elements to be reduced.
   * \param data_type Enumeration of data type.
   * \param op Enumeration of operation type.
   */
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    std::size_t const n_bytes = count * GetTypeSize(data_type);
    auto *bytes = reinterpret_cast<char *>(send_receive_buffer);
    std::string const request(bytes, n_bytes);
    // The enum values are forwarded numerically to the federated protocol enums.
    auto const reply =
        client_->Allreduce(request, static_cast<xgboost::federated::DataType>(data_type),
                           static_cast<xgboost::federated::ReduceOperation>(op));
    reply.copy(bytes, n_bytes);
  }

  /**
   * \brief Broadcast a memory region to all others from root.
   *
   * Nothing is sent when the world size is 1.
   *
   * \param send_receive_buffer Pointer to the send or receive buffer.
   * \param size Size of the data.
   * \param root The process rank to broadcast from.
   */
  void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
    if (GetWorldSize() == 1) {
      return;
    }
    auto *bytes = reinterpret_cast<char *>(send_receive_buffer);
    if (GetRank() != root) {
      // Receivers send an empty payload and copy the reply into their buffer.
      auto const reply = client_->Broadcast("", root);
      reply.copy(bytes, size);
      return;
    }
    std::string const payload(bytes, size);
    client_->Broadcast(payload, root);
  }

  /**
   * \brief Get the name of the processor.
   * \return The string "rank" followed by the rank of this process.
   */
  std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }

  /**
   * \brief Print the message to the communicator.
   * \param message The message to be printed.
   */
  void Print(const std::string &message) override { LOG(CONSOLE) << message; }

 protected:
  /** @brief Nothing to do on shutdown for the federated communicator. */
  void Shutdown() override {}

 private:
  /** @brief Client used to reach the federated server. */
  std::unique_ptr<xgboost::federated::FederatedClient> client_{};
};
} // namespace collective
} // namespace xgboost

View File

@ -10,6 +10,8 @@
#include <fstream> #include <fstream>
#include <sstream> #include <sstream>
#include "../../src/common/io.h"
namespace xgboost { namespace xgboost {
namespace federated { namespace federated {
@ -71,32 +73,35 @@ class AllreduceFunctor {
void Accumulate(std::string& buffer, std::string const& input, DataType data_type, void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
ReduceOperation reduce_operation) const { ReduceOperation reduce_operation) const {
switch (data_type) { switch (data_type) {
case DataType::CHAR: case DataType::INT8:
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(), Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
reduce_operation); reduce_operation);
break; break;
case DataType::UCHAR: case DataType::UINT8:
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]), Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(), reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
reduce_operation); reduce_operation);
break; break;
case DataType::INT: case DataType::INT32:
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()), Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
buffer.size() / sizeof(int), reduce_operation); reinterpret_cast<std::int32_t const*>(input.data()),
buffer.size() / sizeof(std::uint32_t), reduce_operation);
break; break;
case DataType::UINT: case DataType::UINT32:
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]), Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
reinterpret_cast<unsigned int const*>(input.data()), reinterpret_cast<std::uint32_t const*>(input.data()),
buffer.size() / sizeof(unsigned int), reduce_operation); buffer.size() / sizeof(std::uint32_t), reduce_operation);
break; break;
case DataType::LONG: case DataType::INT64:
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()), Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
buffer.size() / sizeof(long), reduce_operation); reinterpret_cast<std::int64_t const*>(input.data()),
buffer.size() / sizeof(std::int64_t), reduce_operation);
break; break;
case DataType::ULONG: case DataType::UINT64:
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]), Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
reinterpret_cast<unsigned long const*>(input.data()), reinterpret_cast<std::uint64_t const*>(input.data()),
buffer.size() / sizeof(unsigned long), reduce_operation); buffer.size() / sizeof(std::uint64_t), reduce_operation);
break; break;
case DataType::FLOAT: case DataType::FLOAT:
Accumulate(reinterpret_cast<float*>(&buffer[0]), Accumulate(reinterpret_cast<float*>(&buffer[0]),
@ -108,16 +113,6 @@ class AllreduceFunctor {
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double), reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
reduce_operation); reduce_operation);
break; break;
case DataType::LONGLONG:
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
reinterpret_cast<long long const*>(input.data()),
buffer.size() / sizeof(long long), reduce_operation);
break;
case DataType::ULONGLONG:
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
reinterpret_cast<unsigned long long const*>(input.data()),
buffer.size() / sizeof(unsigned long long), reduce_operation);
break;
default: default:
throw std::invalid_argument("Invalid data type"); throw std::invalid_argument("Invalid data type");
} }
@ -201,13 +196,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
return grpc::Status::OK; return grpc::Status::OK;
} }
std::string ReadFile(char const* path) {
auto stream = std::ifstream(path);
std::ostringstream out;
out << stream.rdbuf();
return out.str();
}
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file) { char const* client_cert_file) {
std::string const server_address = "0.0.0.0:" + std::to_string(port); std::string const server_address = "0.0.0.0:" + std::to_string(port);
@ -216,10 +204,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
grpc::ServerBuilder builder; grpc::ServerBuilder builder;
auto options = auto options =
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY); grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
options.pem_root_certs = ReadFile(client_cert_file); options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair(); auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
key.private_key = ReadFile(server_key_file); key.private_key = xgboost::common::ReadAll(server_key_file);
key.cert_chain = ReadFile(server_cert_file); key.cert_chain = xgboost::common::ReadAll(server_cert_file);
options.pem_key_cert_pairs.push_back(key); options.pem_key_cert_pairs.push_back(key);
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));

View File

@ -0,0 +1,243 @@
"""XGBoost collective communication related API."""
import ctypes
import json
import logging
import pickle
from enum import IntEnum, unique
from typing import Any, List
import numpy as np
from ._typing import _T
from .core import _LIB, _check_call, c_str, py_str, from_pystr_to_cstr
LOGGER = logging.getLogger("[xgboost.collective]")
def init(**args: Any) -> None:
    """Initialize the collective library with arguments.

    The keyword arguments are serialized to JSON and handed to the native library.

    Parameters
    ----------
    args: Dict[str, Any]
        Keyword arguments representing the parameters and their values.

        Accepted parameters:
          - xgboost_communicator: The type of the communicator. Can be set as an environment
            variable.
            * rabit: Use Rabit. This is the default if the type is unspecified.
            * federated: Use the gRPC interface for Federated Learning.
        Only applicable to the Rabit communicator (these are case-sensitive):
          - rabit_tracker_uri: Hostname of the tracker.
          - rabit_tracker_port: Port number of the tracker.
          - rabit_task_id: ID of the current task, can be used to obtain deterministic rank
            assignment.
          - rabit_world_size: Total number of workers.
          - rabit_hadoop_mode: Enable Hadoop support.
          - rabit_tree_reduce_minsize: Minimal size for tree reduce.
          - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
          - rabit_reduce_buffer: Size of the reduce buffer.
          - rabit_bootstrap_cache: Size of the bootstrap cache.
          - rabit_debug: Enable debugging.
          - rabit_timeout: Enable timeout.
          - rabit_timeout_sec: Timeout in seconds.
          - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
        Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
        environment variables):
          - DMLC_TRACKER_URI: Hostname of the tracker.
          - DMLC_TRACKER_PORT: Port number of the tracker.
          - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
            assignment.
          - DMLC_ROLE: Role of the current task, "worker" or "server".
          - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
          - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
        Only applicable to the Federated communicator (use upper case for environment variables,
        use lower case for runtime configuration):
          - federated_server_address: Address of the federated server.
          - federated_world_size: Number of federated workers.
          - federated_rank: Rank of the current worker.
          - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
          - federated_client_key: Client key file path. Only needed for the SSL mode.
          - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
    """
    encoded = json.dumps(args)
    status = _LIB.XGCommunicatorInit(from_pystr_to_cstr(encoded))
    _check_call(status)
def finalize() -> None:
    """Finalize the communicator; call this once all collective work is done."""
    status = _LIB.XGCommunicatorFinalize()
    _check_call(status)
def get_rank() -> int:
    """Get rank of current process.

    Returns
    -------
    rank : int
        Rank of current process.
    """
    return _LIB.XGCommunicatorGetRank()
def get_world_size() -> int:
    """Get total number of workers.

    Returns
    -------
    n : int
        Total number of processes.
    """
    return _LIB.XGCommunicatorGetWorldSize()
def is_distributed() -> int:
    """Tell whether the collective communicator is distributed.

    Returns
    -------
    flag : int
        The integer returned by the native library; non-zero means distributed.
    """
    return _LIB.XGCommunicatorIsDistributed()
def communicator_print(msg: Any) -> None:
    """Print message to the communicator.

    This function can be used to communicate the information of
    the progress to the communicator.

    Parameters
    ----------
    msg : str
        The message to be printed to the communicator. Values that are not strings are
        converted with ``str``.
    """
    text = msg if isinstance(msg, str) else str(msg)
    if _LIB.XGCommunicatorIsDistributed() == 0:
        # Not distributed: print locally instead of going through the communicator.
        print(text.strip(), flush=True)
    else:
        _check_call(_LIB.XGCommunicatorPrint(c_str(text)))
def get_processor_name() -> str:
    """Get the processor name.

    Returns
    -------
    name : str
        The name of the processor (host).
    """
    out = ctypes.c_char_p()
    status = _LIB.XGCommunicatorGetProcessorName(ctypes.byref(out))
    _check_call(status)
    raw = out.value
    # An empty or missing name is treated as an error.
    assert raw
    return py_str(raw)
def broadcast(data: _T, root: int) -> _T:
    """Broadcast object from one node to all other nodes.

    The object is pickled on the root. Its byte length is broadcast first so that the other
    ranks can allocate a receive buffer, then the pickled bytes themselves are broadcast.

    Parameters
    ----------
    data : any type that can be pickled
        Input data, if current rank does not equal root, this can be None
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : same type as the data passed on the root
        The result of broadcast.
    """
    rank = get_rank()
    length = ctypes.c_ulong()
    if root == rank:
        assert data is not None, 'need to pass in data when broadcasting'
        payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(payload)
    # First broadcast: the length of the pickled payload.
    # The C function takes a `size_t`; a bare Python int would be passed as a C `int` by ctypes,
    # so the sizes are wrapped explicitly.
    _check_call(_LIB.XGCommunicatorBroadcast(ctypes.byref(length),
                                             ctypes.c_size_t(ctypes.sizeof(ctypes.c_ulong)),
                                             root))
    if root != rank:
        receive_buffer = (ctypes.c_char * length.value)()
        # Second broadcast: receive the payload itself.
        _check_call(_LIB.XGCommunicatorBroadcast(ctypes.cast(receive_buffer, ctypes.c_void_p),
                                                 ctypes.c_size_t(length.value), root))
        # SECURITY NOTE: the bytes come from another process and unpickling can execute
        # arbitrary code; only use this with trusted peers.
        data = pickle.loads(receive_buffer.raw)
    else:
        # Second broadcast: send the payload from the root.
        _check_call(_LIB.XGCommunicatorBroadcast(
            ctypes.cast(ctypes.c_char_p(payload), ctypes.c_void_p),
            ctypes.c_size_t(length.value), root))
    return data
# Enumeration of dtypes: maps a numpy dtype to the integer code that `allreduce` passes to
# XGCommunicatorAllreduce as the data type. The codes follow the numbering of the
# `xgboost::collective::DataType` enum on the C++ side.
DTYPE_ENUM__ = {
    np.dtype('int8'): 0,
    np.dtype('uint8'): 1,
    np.dtype('int32'): 2,
    np.dtype('uint32'): 3,
    np.dtype('int64'): 4,
    np.dtype('uint64'): 5,
    np.dtype('float32'): 6,
    np.dtype('float64'): 7
}
@unique
class Op(IntEnum):
    """Reduction operators accepted by allreduce.

    The integer value of the member is what gets passed to the native library.
    """

    MAX = 0
    MIN = 1
    SUM = 2
def allreduce(  # pylint:disable=invalid-name
    data: np.ndarray, op: Op
) -> np.ndarray:
    """Perform allreduce, return the result.

    Parameters
    ----------
    data :
        Input data.
    op :
        Reduction operator.

    Returns
    -------
    result :
        The result of allreduce as a flattened (1-D) array with the same number of elements
        as data.

    Raises
    ------
    TypeError
        If data is not a numpy.ndarray.
    Exception
        If the dtype of data has no entry in DTYPE_ENUM__.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    buf = data.ravel()
    # NOTE(review): this check copies only when `buf` and `data` report the same base object.
    # For an array that owns its memory and is C-contiguous, `ravel()` returns a view whose base
    # is `data` itself, so no copy is made and the reduction is written into `data` in place.
    # Confirm whether that is intended before changing it.
    if buf.base is data.base:
        buf = buf.copy()
    if buf.dtype not in DTYPE_ENUM__:
        raise Exception(f"data type {buf.dtype} not supported")
    # XGCommunicatorAllreduce takes exactly four arguments:
    # (buffer, size_t count, int data_type, int op). The count is wrapped in c_size_t because
    # ctypes would otherwise pass a Python int as a C `int`.
    _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                             ctypes.c_size_t(buf.size),
                                             DTYPE_ENUM__[buf.dtype],
                                             int(op)))
    return buf
class CommunicatorContext:
    """Context manager that initializes the collective communicator on entry and finalizes it
    on exit."""

    def __init__(self, **args: Any) -> None:
        # Keyword arguments are stored and forwarded unchanged to `init` on entry.
        self.args = args

    def __enter__(self) -> None:
        init(**self.args)
        # The context is only meant for distributed runs.
        assert is_distributed()
        LOGGER.debug("-------------- communicator say hello ------------------")

    def __exit__(self, *args: List) -> None:
        # Finalize regardless of how the block was left; exceptions are not suppressed.
        finalize()
        LOGGER.debug("--------------- communicator say bye ------------------")

View File

@ -22,6 +22,7 @@
#include "c_api_error.h" #include "c_api_error.h"
#include "c_api_utils.h" #include "c_api_utils.h"
#include "../collective/communicator.h"
#include "../common/io.h" #include "../common/io.h"
#include "../common/charconv.h" #include "../common/charconv.h"
#include "../data/adapter.h" #include "../data/adapter.h"
@ -1370,6 +1371,62 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
API_END(); API_END();
} }
using xgboost::collective::Communicator;
// Initialize the communicator: the JSON text in `json_config` is parsed with Json::Load and the
// resulting document is handed to Communicator::Init.
// NOTE(review): `json_config` is used without a null check -- confirm callers never pass NULL.
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  Json config { Json::Load(StringView{json_config}) };
  Communicator::Init(config);
  API_END();
}
// Finalize the communicator by forwarding to Communicator::Finalize.
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  Communicator::Finalize();
  API_END();
}
// Return the rank reported by the current communicator. There is no API_BEGIN/API_END wrapper
// here, so the return value is the rank itself rather than a status code.
XGB_DLL int XGCommunicatorGetRank(void) {
  return Communicator::Get()->GetRank();
}
// Return the world size reported by the current communicator. There is no API_BEGIN/API_END
// wrapper here, so the return value is the world size itself rather than a status code.
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  return Communicator::Get()->GetWorldSize();
}
// Return the result of IsDistributed() of the current communicator converted to int. There is
// no API_BEGIN/API_END wrapper here, so the return value is not a status code.
XGB_DLL int XGCommunicatorIsDistributed(void) {
  return Communicator::Get()->IsDistributed();
}
// Forward `message` to the Print method of the current communicator.
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  Communicator::Get()->Print(message);
  API_END();
}
// Write a pointer to the processor name into `*name_str`.
// The name is stored in `ret_str` of the object returned by GlobalConfigAPIThreadLocalStore::Get
// so that the pointer handed back stays valid after this function returns; the caller does not
// own the string.
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  auto& local = *GlobalConfigAPIThreadLocalStore::Get();
  local.ret_str = Communicator::Get()->GetProcessorName();
  *name_str = local.ret_str.c_str();
  API_END();
}
// Forward the buffer, its size and the root rank to the Broadcast method of the current
// communicator.
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  Communicator::Get()->Broadcast(send_receive_buffer, size, root);
  API_END();
}
// Forward an in-place allreduce to the current communicator.
// `enum_dtype` and `enum_op` are converted with static_cast to collective::DataType and
// collective::Operation; no range check is performed here.
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  Communicator::Get()->AllReduce(
      send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
      static_cast<xgboost::collective::Operation>(enum_op));
  API_END();
}
#if defined(XGBOOST_USE_FEDERATED) #if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path, XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) { char const *server_cert_path, char const *client_cert_path) {

View File

@ -0,0 +1,59 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local CommunicatorType Communicator::type_{};
// Create the communicator described by the environment and by `config`.
// A type named in the configuration takes precedence over one taken from the environment; when
// neither names a type, Rabit is used. Calling this while a communicator already exists is a
// fatal error.
void Communicator::Init(Json const& config) {
  if (communicator_) {
    LOG(FATAL) << "Communicator can only be initialized once.";
  }

  auto const from_env = GetTypeFromEnv();
  auto const from_config = GetTypeFromConfig(config);
  auto selected = (from_config != CommunicatorType::kUnknown) ? from_config : from_env;
  if (selected == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    selected = CommunicatorType::kRabit;
  }
  type_ = selected;

  switch (selected) {
    case CommunicatorType::kRabit: {
      communicator_.reset(RabitCommunicator::Create(config));
      break;
    }
    case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
      communicator_.reset(FederatedCommunicator::Create(config));
#else
      // The federated plugin was not compiled in.
      LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
      break;
    }
    case CommunicatorType::kUnknown:
      LOG(FATAL) << "Unknown communicator type.";
  }
}
#ifndef XGBOOST_USE_CUDA
// Shut down and destroy the communicator (build without CUDA).
// When no communicator exists (Init was never called, or Finalize was already called) this is a
// no-op instead of dereferencing a null pointer.
void Communicator::Finalize() {
  if (communicator_) {
    communicator_->Shutdown();
    communicator_.reset(nullptr);
  }
}
#endif
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,41 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
// Per-thread device state: the ordinal the cached device communicator was created for (-1 when
// nothing is cached) and the cached device communicator itself (initially empty).
thread_local int Communicator::device_ordinal_{-1};
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
/**
 * @brief Shut down the calling thread's communicator and drop the cached device communicator.
 *
 * The host communicator is only shut down when one exists, so calling this without a prior
 * Init() (or calling it twice) no longer dereferences a null pointer. The device-side cache is
 * always cleared.
 */
void Communicator::Finalize() {
  if (communicator_) {
    communicator_->Shutdown();
    communicator_.reset(nullptr);
  }
  device_ordinal_ = -1;
  device_communicator_.reset(nullptr);
}
/**
 * @brief Return the device communicator for `device_ordinal`, creating it on first use.
 *
 * One device communicator is cached per thread; asking for a different ordinal replaces it.
 * With NCCL enabled, the NCCL implementation is used unless the host communicator is
 * federated, in which case (and in non-NCCL builds) the host-staging adapter is used.
 *
 * @param device_ordinal ID of the device.
 * @return Non-owning pointer to the cached device communicator.
 */
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  if (!device_communicator_ || device_ordinal_ != device_ordinal) {
#ifdef XGBOOST_USE_NCCL
    if (type_ != CommunicatorType::kFederated) {
      device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
    } else {
      device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
    }
#else
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
#endif
    // Record the ordinal only after construction succeeded. If a constructor throws, the
    // previously cached communicator must not be reported as belonging to the new ordinal.
    device_ordinal_ = device_ordinal;
  }
  return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,218 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <memory>
#include <string>
namespace xgboost {
namespace collective {
/**
 * @brief Defines the integral and floating data types.
 *
 * The enumerators carry explicit numeric values; the C API converts a plain integer to this
 * enum with a static_cast, so the numbering must stay stable.
 */
enum class DataType {
  kInt8 = 0,    // signed 8-bit integer
  kUInt8 = 1,   // unsigned 8-bit integer
  kInt32 = 2,   // signed 32-bit integer
  kUInt32 = 3,  // unsigned 32-bit integer
  kInt64 = 4,   // signed 64-bit integer
  kUInt64 = 5,  // unsigned 64-bit integer
  kFloat = 6,   // float
  kDouble = 7   // double
};
/**
 * @brief Width in bytes of a single element of `data_type`.
 *
 * Logs a fatal error for a value that is not one of the DataType enumerators; 0 is the result
 * should control continue past that log statement.
 */
inline std::size_t GetTypeSize(DataType data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return sizeof(std::int8_t);
    case DataType::kUInt8:
      return sizeof(std::uint8_t);
    case DataType::kInt32:
      return sizeof(std::int32_t);
    case DataType::kUInt32:
      return sizeof(std::uint32_t);
    case DataType::kInt64:
      return sizeof(std::int64_t);
    case DataType::kUInt64:
      return sizeof(std::uint64_t);
    case DataType::kFloat:
      return sizeof(float);
    case DataType::kDouble:
      return sizeof(double);
    default:
      break;
  }
  LOG(FATAL) << "Unknown data type.";
  return std::size_t{0};
}
/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated };
/**
 * \brief Compare two C strings while ignoring letter case.
 *
 * \return The result of the platform routine (`_stricmp` with MSVC, `strcasecmp` otherwise):
 *         zero when the strings match, non-zero otherwise.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if !defined(_MSC_VER)
  int const ordering = strcasecmp(s1, s2);
#else   // !defined(_MSC_VER)
  int const ordering = _stricmp(s1, s2);
#endif  // !defined(_MSC_VER)
  return ordering;
}
/**
 * @brief Abstract base for host-side collective communication.
 *
 * The static part manages one communicator (and, in CUDA builds, one device communicator) per
 * thread; the virtual part is the interface concrete communicators implement.
 */
class Communicator {
 public:
  /**
   * @brief Create the communicator for the calling thread. May only be done once.
   *
   * @param config JSON configuration for the communicator.
   */
  static void Init(Json const &config);

  /** @brief Tear down the communicator of the calling thread. */
  static void Finalize();

  /** @brief Non-owning pointer to the calling thread's communicator (may be null). */
  static Communicator *Get() { return communicator_.get(); }

#if defined(XGBOOST_USE_CUDA)
  /**
   * @brief Fetch the device communicator for a device.
   *
   * @param device_ordinal ID of the device.
   * @return An instance of device communicator.
   */
  static DeviceCommunicator *GetDevice(int device_ordinal);
#endif

  virtual ~Communicator() = default;

  /** @brief Total number of participating processes. */
  int GetWorldSize() const { return world_size_; }

  /** @brief Rank of this process. */
  int GetRank() const { return rank_; }

  /** @brief Whether the communicator is running in distributed mode. */
  virtual bool IsDistributed() const = 0;

  /**
   * @brief Reduce values across all processes and hand the result back to every process.
   *
   * @param send_receive_buffer Buffer holding the input and receiving the output.
   * @param count Number of elements in the buffer.
   * @param data_type Element type stored in the buffer.
   * @param op Reduction to apply.
   */
  virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                         Operation op) = 0;

  /**
   * @brief Send a message from the process with rank `root` to every other process.
   *
   * @param send_receive_buffer Buffer holding the data.
   * @param size Size of the data in bytes.
   * @param root Rank of the broadcasting process.
   */
  virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;

  /** @brief Name of the processor. */
  virtual std::string GetProcessorName() = 0;

  /** @brief Print the message. */
  virtual void Print(std::string const &message) = 0;

  /**
   * @brief Read the communicator type from the `XGBOOST_COMMUNICATOR` environment variable.
   *
   * Yields kUnknown when the variable is not set. Visible for testing.
   */
  static CommunicatorType GetTypeFromEnv() {
    char const *value = std::getenv("XGBOOST_COMMUNICATOR");
    return value == nullptr ? CommunicatorType::kUnknown : StringToType(value);
  }

  /**
   * @brief Read the communicator type from the runtime configuration.
   *
   * The upper-case key is consulted first, then the lower-case one; the first entry holding a
   * string decides. Yields kUnknown when neither does. Visible for testing.
   */
  static CommunicatorType GetTypeFromConfig(Json const &config) {
    char const *const keys[] = {"XGBOOST_COMMUNICATOR", "xgboost_communicator"};
    for (char const *key : keys) {
      auto const &entry = config[key];
      if (IsA<String const>(entry)) {
        return StringToType(get<String const>(entry).c_str());
      }
    }
    return CommunicatorType::kUnknown;
  }

 protected:
  /**
   * @brief Construct a communicator after validating the topology.
   *
   * Logs a fatal error unless `world_size >= 1` and `0 <= rank < world_size`.
   *
   * @param world_size Total number of processes.
   * @param rank Rank of the current process.
   */
  Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
    if (world_size <= 0) {
      LOG(FATAL) << "World size " << world_size << " is less than 1.";
    }
    if (0 > rank) {
      LOG(FATAL) << "Rank " << rank << " is less than 0.";
    }
    if (!(rank < world_size)) {
      LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
    }
  }

  /** @brief Shut down the communicator. Invoked by Finalize(). */
  virtual void Shutdown() = 0;

 private:
  /** @brief Map a (case-insensitive) type name to its enum; fatal on an unrecognized name. */
  static CommunicatorType StringToType(char const *str) {
    if (CompareStringsCaseInsensitive("rabit", str) == 0) {
      return CommunicatorType::kRabit;
    }
    if (CompareStringsCaseInsensitive("federated", str) == 0) {
      return CommunicatorType::kFederated;
    }
    LOG(FATAL) << "Unknown communicator type " << str;
    return CommunicatorType::kUnknown;
  }

  static thread_local std::unique_ptr<Communicator> communicator_;
  static thread_local CommunicatorType type_;
#if defined(XGBOOST_USE_CUDA)
  static thread_local int device_ordinal_;
  static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif

  int const world_size_;
  int const rank_;
};
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,42 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <vector>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Collective communicator for device buffers.
 *
 * Pure interface; implementations decide how the device data is exchanged between processes.
 */
class DeviceCommunicator {
 public:
  virtual ~DeviceCommunicator() = default;

  /**
   * @brief Sum values from all processes and distribute the result back to all processes.
   * @param send_receive_buffer Buffer storing the data; it is overwritten with the result.
   * @param count Number of elements (doubles) in the buffer.
   */
  virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;

  /**
   * @brief Gather variable-length values from all processes.
   * @param send_buffer Buffer storing the input data.
   * @param length_bytes Length in bytes of the input data.
   * @param segments Size of each segment.
   * @param receive_buffer Buffer storing the output data.
   */
  virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
                          std::vector<size_t> *segments,
                          dh::caching_device_vector<char> *receive_buffer) = 0;

  /** @brief Synchronize device operations. */
  virtual void Synchronize() = 0;
};
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,76 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Device communicator that stages device buffers through host memory and delegates the
 * actual exchange to a host Communicator.
 */
class DeviceCommunicatorAdapter : public DeviceCommunicator {
 public:
  /**
   * @brief Construct the adapter.
   *
   * Logs a fatal error when `device_ordinal` is negative or `communicator` is null.
   *
   * @param device_ordinal ID of the device whose buffers are exchanged.
   * @param communicator Non-owning pointer to the host communicator; must outlive the adapter.
   */
  DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
      : device_ordinal_{device_ordinal}, communicator_{communicator} {
    if (device_ordinal_ < 0) {
      LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
    }
    if (communicator_ == nullptr) {
      LOG(FATAL) << "Communicator cannot be null.";
    }
  }

  ~DeviceCommunicatorAdapter() override = default;

  /**
   * @brief Sum `count` doubles across all processes; the result overwrites the device buffer.
   */
  void AllReduceSum(double *send_receive_buffer, int count) override {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    auto size = count * sizeof(double);
    // resize(), not reserve(): the bytes are accessed through data(), which is only valid for
    // the first size() elements of the vector.
    host_buffer_.resize(size);
    dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
    communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
    dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
  }

  /**
   * @brief Gather variable-length byte ranges from every process.
   *
   * @param send_buffer Device buffer holding this process's contribution.
   * @param length_bytes Size of the contribution in bytes.
   * @param segments Output: resized to the world size; entry i is the byte count of rank i.
   * @param receive_buffer Output: resized to the total and filled with all contributions in
   *                       rank order.
   */
  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                  dh::caching_device_vector<char> *receive_buffer) override {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    int const world_size = communicator_->GetWorldSize();
    int const rank = communicator_->GetRank();

    // Every rank fills in only its own slot; a max-reduction then shares all lengths.
    segments->clear();
    segments->resize(world_size, 0);
    segments->at(rank) = length_bytes;
    communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
                             Operation::kMax);
    // Seed with std::size_t so the sum is not narrowed where unsigned long is 32 bits wide.
    auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), std::size_t{0});
    receive_buffer->resize(total_bytes);

    // resize(), not reserve(): see AllReduceSum.
    host_buffer_.resize(total_bytes);
    size_t offset = 0;
    for (int32_t i = 0; i < world_size; ++i) {
      size_t as_bytes = segments->at(i);
      if (i == rank) {
        // Stage the local contribution at its final position before broadcasting it.
        dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
                                 cudaMemcpyDefault));
      }
      communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
      offset += as_bytes;
    }
    dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
                             cudaMemcpyDefault));
  }

  void Synchronize() override {
    // Noop.
  }

 private:
  int const device_ordinal_;
  Communicator *communicator_;
  /// Host buffer used to call communicator functions.
  std::vector<char> host_buffer_{};
};
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,149 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include "../common/device_helpers.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
/**
 * @brief Device communicator backed by NCCL.
 *
 * The host Communicator is used for bootstrapping (exchanging device UUIDs and the NCCL unique
 * id) and for sharing segment lengths; the bulk data goes through NCCL on a dedicated stream.
 */
class NcclDeviceCommunicator : public DeviceCommunicator {
 public:
  /**
   * @brief Set up the NCCL communicator and its CUDA stream.
   *
   * Logs a fatal error when `device_ordinal` is negative, `communicator` is null, or two
   * processes in the group report the same CUDA device UUID.
   *
   * @param device_ordinal ID of the device used by this process.
   * @param communicator Non-owning pointer to the host communicator; must outlive this object.
   */
  NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
      : device_ordinal_{device_ordinal}, communicator_{communicator} {
    if (device_ordinal_ < 0) {
      LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
    }
    if (communicator_ == nullptr) {
      LOG(FATAL) << "Communicator cannot be null.";
    }

    int32_t const rank = communicator_->GetRank();
    int32_t const world = communicator_->GetWorldSize();

    // Each rank writes its UUID into its own slot; all other slots stay zero, so a sum
    // reduction leaves every rank with the complete table.
    std::vector<uint64_t> uuids(world * kUuidLength, 0);
    auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
    auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
    GetCudaUUID(s_this_uuid);

    // TODO(rongou): replace this with allgather.
    communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

    std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
    size_t j = 0;
    for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
      converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
      j++;
    }

    // std::unique only collapses *adjacent* duplicates, so order the views first; otherwise
    // two non-neighbouring ranks sharing a device would go unnoticed. Only the views are
    // reordered, the underlying UUID table is left untouched.
    std::sort(converted.begin(), converted.end(),
              [](xgboost::common::Span<uint64_t, kUuidLength> const &lhs,
                 xgboost::common::Span<uint64_t, kUuidLength> const &rhs) {
                return std::lexicographical_compare(lhs.begin(), lhs.end(), rhs.begin(),
                                                    rhs.end());
              });
    auto iter = std::unique(converted.begin(), converted.end());
    auto n_uniques = std::distance(converted.begin(), iter);

    CHECK_EQ(n_uniques, world)
        << "Multiple processes within communication group running on same CUDA "
        << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

    nccl_unique_id_ = GetUniqueId();
    // The NCCL communicator and the stream are bound to the device that is current when they
    // are created, so select the requested device first.
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
    dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
  }

  /** @brief Release the stream and the NCCL communicator; print statistics at debug level. */
  ~NcclDeviceCommunicator() override {
    dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
    ncclCommDestroy(nccl_comm_);
    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }

  /**
   * @brief Enqueue an in-place sum reduction of `count` doubles on the internal stream.
   *
   * The call does not wait for completion; see Synchronize().
   */
  void AllReduceSum(double *send_receive_buffer, int count) override {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
                                ncclSum, nccl_comm_, cuda_stream_));
    allreduce_bytes_ += count * sizeof(double);
    allreduce_calls_ += 1;
  }

  /**
   * @brief Gather variable-length byte ranges from every process.
   *
   * @param send_buffer Device buffer holding this process's contribution.
   * @param length_bytes Size of the contribution in bytes.
   * @param segments Output: resized to the world size; entry i is the byte count of rank i.
   * @param receive_buffer Output: resized to the total; contributions are laid out in rank
   *                       order by broadcasts enqueued on the internal stream.
   */
  void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
                  dh::caching_device_vector<char> *receive_buffer) override {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    int const world_size = communicator_->GetWorldSize();
    int const rank = communicator_->GetRank();

    // Every rank fills in only its own slot; a max-reduction then shares all lengths.
    segments->clear();
    segments->resize(world_size, 0);
    segments->at(rank) = length_bytes;
    communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
                             Operation::kMax);
    // Seed with std::size_t so the sum is not narrowed where unsigned long is 32 bits wide.
    auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), std::size_t{0});
    receive_buffer->resize(total_bytes);

    size_t offset = 0;
    dh::safe_nccl(ncclGroupStart());
    for (int32_t i = 0; i < world_size; ++i) {
      size_t as_bytes = segments->at(i);
      dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
                                  ncclChar, i, nccl_comm_, cuda_stream_));
      offset += as_bytes;
    }
    dh::safe_nccl(ncclGroupEnd());
  }

  /** @brief Block until all work enqueued on the internal stream has finished. */
  void Synchronize() override {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
  }

 private:
  /// Number of uint64_t words needed to hold a CUDA device UUID.
  static constexpr std::size_t kUuidLength =
      sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);

  /** @brief Copy the UUID of this object's device into `uuid`. */
  void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
    cudaDeviceProp prob{};
    dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
    std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
  }

  /** @brief Render a UUID as concatenated hexadecimal words (for error messages). */
  static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
    std::stringstream ss;
    for (auto v : uuid) {
      ss << std::hex << v;
    }
    return ss.str();
  }

  /**
   * \fn ncclUniqueId GetUniqueId()
   *
   * \brief Gets the Unique ID from NCCL to be used in setting up interprocess
   * communication
   *
   * Rank 0 generates the id and broadcasts it through the host communicator.
   *
   * \return the Unique ID
   */
  ncclUniqueId GetUniqueId() {
    static const int kRootRank = 0;
    ncclUniqueId id;
    if (communicator_->GetRank() == kRootRank) {
      dh::safe_nccl(ncclGetUniqueId(&id));
    }
    communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
                             static_cast<int>(kRootRank));
    return id;
  }

  int const device_ordinal_;
  Communicator *communicator_;
  ncclComm_t nccl_comm_{};
  cudaStream_t cuda_stream_{};
  ncclUniqueId nccl_unique_id_{};
  size_t allreduce_bytes_{0};  // Keep statistics of the number of bytes communicated.
  size_t allreduce_calls_{0};  // Keep statistics of the number of reduce calls.
};
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,120 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
/**
 * @brief Communicator implemented on top of Rabit.
 */
class RabitCommunicator : public Communicator {
 public:
  /**
   * @brief Initialize Rabit from a JSON configuration and create the communicator.
   *
   * Every string, integer and boolean entry of the configuration object is forwarded to Rabit
   * as a `key=value` argument (booleans become 1/0); entries of other kinds are ignored.
   *
   * @param config JSON object with the Rabit parameters.
   * @return A newly allocated communicator; the caller takes ownership.
   */
  static Communicator *Create(Json const &config) {
    std::vector<std::string> args_str;
    for (auto &items : get<Object const>(config)) {
      switch (items.second.GetValue().Type()) {
        case xgboost::Value::ValueKind::kString: {
          args_str.push_back(items.first + "=" + get<String const>(items.second));
          break;
        }
        case xgboost::Value::ValueKind::kInteger: {
          args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
          break;
        }
        case xgboost::Value::ValueKind::kBoolean: {
          if (get<Boolean const>(items.second)) {
            args_str.push_back(items.first + "=1");
          } else {
            args_str.push_back(items.first + "=0");
          }
          break;
        }
        default:
          break;
      }
    }
    // Rabit takes argv-style mutable C strings; the pointers refer into `args_str`, which
    // stays alive until Init() has returned.
    std::vector<char *> args;
    for (auto &key_value : args_str) {
      args.push_back(&key_value[0]);
    }
    // data() rather than &args[0]: indexing an empty vector (no forwarded entries) is undefined.
    if (!rabit::Init(static_cast<int>(args.size()), args.data())) {
      LOG(FATAL) << "Failed to initialize Rabit";
    }
    return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
  }

  /** @brief Construct with an explicit topology; validated by the base class. */
  RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}

  bool IsDistributed() const override { return rabit::IsDistributed(); }

  /** @brief Dispatch on the element type and run the reduction through Rabit. */
  void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                 Operation op) override {
    switch (data_type) {
      case DataType::kInt8:
        DoAllReduce<char>(send_receive_buffer, count, op);
        break;
      case DataType::kUInt8:
        DoAllReduce<unsigned char>(send_receive_buffer, count, op);
        break;
      case DataType::kInt32:
        DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
        break;
      case DataType::kUInt32:
        DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
        break;
      case DataType::kInt64:
        DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
        break;
      case DataType::kUInt64:
        DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
        break;
      case DataType::kFloat:
        DoAllReduce<float>(send_receive_buffer, count, op);
        break;
      case DataType::kDouble:
        DoAllReduce<double>(send_receive_buffer, count, op);
        break;
      default:
        LOG(FATAL) << "Unknown data type";
    }
  }

  void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
    rabit::Broadcast(send_receive_buffer, size, root);
  }

  std::string GetProcessorName() override { return rabit::GetProcessorName(); }

  void Print(const std::string &message) override { rabit::TrackerPrint(message); }

 protected:
  void Shutdown() override {
    rabit::Finalize();
  }

 private:
  /** @brief Run the Rabit allreduce matching `op` on a buffer of `count` DType elements. */
  template <typename DType>
  void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
    switch (op) {
      case Operation::kMax:
        rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
        break;
      case Operation::kMin:
        rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
        break;
      case Operation::kSum:
        rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
        break;
      default:
        LOG(FATAL) << "Unknown allreduce operation";
    }
  }
};
} // namespace collective
} // namespace xgboost

View File

@ -12,6 +12,7 @@
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <fstream>
#include "common.h" #include "common.h"
@ -111,6 +112,22 @@ inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) {
} }
return buffer; return buffer;
} }
/**
 * \brief Load the entire content of the file at `path` into a string.
 *
 * Logs a fatal error when the file cannot be opened or when nothing could be read from it.
 */
inline std::string ReadAll(std::string const &path) {
  std::ifstream input{path};
  if (!input.is_open()) {
    LOG(FATAL) << "Could not open file " << path;
  }

  std::istreambuf_iterator<char> const first{input};
  std::istreambuf_iterator<char> const last{};
  std::string content;
  content.assign(first, last);

  if (content.size() == 0) {
    LOG(FATAL) << "Empty file " << path;
  }
  return content;
}
} // namespace common } // namespace common
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_ #endif // XGBOOST_COMMON_IO_H_

View File

@ -22,7 +22,7 @@ if (PLUGIN_FEDERATED)
target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated) target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated)
target_link_libraries(testxgboost PRIVATE federated_client) target_link_libraries(testxgboost PRIVATE federated_client)
else (PLUGIN_FEDERATED) else (PLUGIN_FEDERATED)
file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc") file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.*")
list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES}) list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES})
endif (PLUGIN_FEDERATED) endif (PLUGIN_FEDERATED)

View File

@ -0,0 +1,54 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include "../../../src/collective/communicator.h"
namespace xgboost {
namespace collective {
// The type is read from XGBOOST_COMMUNICATOR: unset -> kUnknown, known names are matched
// case-insensitively, anything else raises an error.
// NOTE(review): the variable is left set to "foo" when the test ends - confirm no later test
// in the same process depends on it being unset.
TEST(CommunicatorFactory, TypeFromEnv) {
  auto set_communicator = [](std::string const &value) {
    dmlc::SetEnv<std::string>("XGBOOST_COMMUNICATOR", value);
  };

  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());

  set_communicator("rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());

  set_communicator("Federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());

  set_communicator("foo");
  EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
}
// The lower-case configuration key selects the communicator type; an unrecognized name throws.
TEST(CommunicatorFactory, TypeFromArgs) {
  std::string const key{"xgboost_communicator"};
  Json config{JsonObject()};
  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));

  config[key] = String("rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));

  config[key] = String("federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

  config[key] = String("foo");
  EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
// Same as TypeFromArgs, but through the upper-case configuration key.
TEST(CommunicatorFactory, TypeFromArgsUpperCase) {
  std::string const key{"XGBOOST_COMMUNICATOR"};
  Json config{JsonObject()};
  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));

  config[key] = String("rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));

  config[key] = String("federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

  config[key] = String("foo");
  EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,26 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#ifdef XGBOOST_USE_NCCL
#include <gtest/gtest.h>
#include "../../../src/collective/nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
// A negative device ordinal is rejected by the constructor.
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
  EXPECT_THROW(NcclDeviceCommunicator(-1, nullptr), dmlc::Error);
}
// A null host communicator is rejected by the constructor, even with a valid ordinal.
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
  EXPECT_THROW(NcclDeviceCommunicator(0, nullptr), dmlc::Error);
}
} // namespace collective
} // namespace xgboost
#endif

View File

@ -0,0 +1,39 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/collective/rabit_communicator.h"
namespace xgboost {
namespace collective {
// A world size below 1 is rejected.
TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
  EXPECT_THROW(RabitCommunicator(0, 0), dmlc::Error);
}
// A negative rank is rejected.
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) {
  EXPECT_THROW(RabitCommunicator(1, -1), dmlc::Error);
}
// A rank equal to the world size is out of range and rejected.
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) {
  EXPECT_THROW(RabitCommunicator(1, 1), dmlc::Error);
}
// The accessors report the topology passed to the constructor.
TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) {
  int const world_size = 6;
  int const rank = 3;
  RabitCommunicator comm(world_size, rank);
  EXPECT_EQ(comm.GetWorldSize(), 6);
  EXPECT_EQ(comm.GetRank(), 3);
}
// Rabit is only distributed with a tracker, and none is started here.
TEST(RabitCommunicatorSimpleTest, IsNotDistributed) {
  RabitCommunicator comm(2, 1);
  bool const distributed = comm.IsDistributed();
  EXPECT_FALSE(distributed);
}
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,105 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <thrust/host_vector.h>
#include <thread>
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../plugin/federated/federated_server.h"
#include "../../../src/collective/device_communicator_adapter.cuh"
namespace xgboost {
namespace collective {
// Address the in-process federated test server listens on and the test clients connect to.
std::string const kServerAddress{"localhost:56789"};  // NOLINT(cert-err58-cpp)
class FederatedAdapterTest : public ::testing::Test {
protected:
void SetUp() override {
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static int const kWorldSize{2};
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
// A negative device ordinal is rejected by the adapter's constructor.
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
  EXPECT_THROW(DeviceCommunicatorAdapter(-1, nullptr), dmlc::Error);
}
// A null host communicator is rejected by the adapter's constructor.
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
  EXPECT_THROW(DeviceCommunicatorAdapter(0, nullptr), dmlc::Error);
}
// One client thread per rank sums the sequence 0, 1, 2 across kWorldSize (2) ranks and expects
// every element to come back doubled.
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
  auto worker = [](int rank) {
    FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
    // The rank doubles as the device ordinal.
    DeviceCommunicatorAdapter adapter{rank, &comm};

    int const count = 3;
    thrust::device_vector<double> buffer(count, 0);
    thrust::sequence(buffer.begin(), buffer.end());
    adapter.AllReduceSum(buffer.data().get(), count);

    thrust::host_vector<double> host_buffer = buffer;
    EXPECT_EQ(host_buffer.size(), count);
    for (auto i = 0; i < count; i++) {
      EXPECT_EQ(host_buffer[i], i * 2);
    }
  };

  std::vector<std::thread> threads;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    threads.emplace_back(worker, rank);
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
// Rank r contributes the bytes 0 .. r+1; with two ranks the gathered result is
// {0, 1} followed by {0, 1, 2}, and the segment sizes are 2 and 3.
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
  auto worker = [](int rank) {
    FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
    // The rank doubles as the device ordinal.
    DeviceCommunicatorAdapter adapter{rank, &comm};

    int const count = rank + 2;
    thrust::device_vector<char> buffer(count, 0);
    thrust::sequence(buffer.begin(), buffer.end());

    std::vector<std::size_t> segments(kWorldSize);
    dh::caching_device_vector<char> receive_buffer{};
    adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer);

    EXPECT_EQ(segments[0], 2);
    EXPECT_EQ(segments[1], 3);

    thrust::host_vector<char> host_buffer = receive_buffer;
    EXPECT_EQ(host_buffer.size(), 5);
    int expected[] = {0, 1, 0, 1, 2};
    for (auto i = 0; i < 5; i++) {
      EXPECT_EQ(host_buffer[i], expected[i]);
    }
  };

  std::vector<std::thread> threads;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    threads.emplace_back(worker, rank);
  }
  for (auto& thread : threads) {
    thread.join();
  }
}
} // namespace collective
} // namespace xgboost

View File

@ -0,0 +1,119 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include <dmlc/parameter.h>
#include <grpcpp/server_builder.h>
#include <gtest/gtest.h>
#include <thread>
#include "../../../plugin/federated/federated_communicator.h"
#include "../../../plugin/federated/federated_server.h"
namespace xgboost {
namespace collective {
// Address the in-process federated test server listens on and the test clients connect to.
std::string const kServerAddress{"localhost:56789"};  // NOLINT(cert-err58-cpp)
class FederatedCommunicatorTest : public ::testing::Test {
public:
static void VerifyAllreduce(int rank) {
FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
CheckAllreduce(comm);
}
static void VerifyBroadcast(int rank) {
FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
CheckBroadcast(comm, rank);
}
protected:
void SetUp() override {
server_thread_.reset(new std::thread([this] {
grpc::ServerBuilder builder;
federated::FederatedService service{kWorldSize};
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
server_->Wait();
}));
}
void TearDown() override {
server_->Shutdown();
server_thread_->join();
}
static void CheckAllreduce(FederatedCommunicator &comm) {
int buffer[] = {1, 2, 3, 4, 5};
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
int expected[] = {3, 6, 9, 12, 15};
for (auto i = 0; i < 5; i++) {
EXPECT_EQ(buffer[i], expected[i]);
}
}
static void CheckBroadcast(FederatedCommunicator &comm, int rank) {
if (rank == 0) {
std::string buffer{"hello"};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
} else {
std::string buffer{" "};
comm.Broadcast(&buffer[0], buffer.size(), 0);
EXPECT_EQ(buffer, "hello");
}
}
static int const kWorldSize{3};
std::unique_ptr<std::thread> server_thread_;
std::unique_ptr<grpc::Server> server_;
};
// A world size below 1 is rejected.
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
  EXPECT_THROW(FederatedCommunicator(0, 0, kServerAddress, "", "", ""), dmlc::Error);
}
// A negative rank is rejected.
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) {
  EXPECT_THROW(FederatedCommunicator(1, -1, kServerAddress, "", "", ""), dmlc::Error);
}
// A rank equal to the world size is out of range and rejected.
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
  EXPECT_THROW(FederatedCommunicator(1, 1, kServerAddress, "", "", ""), dmlc::Error);
}
// The accessors report the topology passed to the constructor.
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
  int const world_size = 6;
  int const rank = 3;
  FederatedCommunicator comm(world_size, rank, kServerAddress);
  EXPECT_EQ(comm.GetWorldSize(), 6);
  EXPECT_EQ(comm.GetRank(), 3);
}
// A federated communicator reports itself as distributed.
TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
  FederatedCommunicator comm(2, 1, kServerAddress);
  bool const distributed = comm.IsDistributed();
  EXPECT_TRUE(distributed);
}
// Run the allreduce check concurrently, one client thread per rank.
TEST_F(FederatedCommunicatorTest, Allreduce) {
  std::vector<std::thread> threads;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    threads.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank);
  }
  for (auto &worker : threads) {
    worker.join();
  }
}
// Run the broadcast check concurrently, one client thread per rank.
TEST_F(FederatedCommunicatorTest, Broadcast) {
  std::vector<std::thread> threads;
  for (int rank = 0; rank < kWorldSize; ++rank) {
    threads.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank);
  }
  for (auto &worker : threads) {
    worker.join();
  }
}
} // namespace collective
} // namespace xgboost

View File

@ -62,7 +62,7 @@ class FederatedServerTest : public ::testing::Test {
static void CheckAllreduce(federated::FederatedClient& client) { static void CheckAllreduce(federated::FederatedClient& client) {
int data[] = {1, 2, 3, 4, 5}; int data[] = {1, 2, 3, 4, 5};
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data)); std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM); auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
auto const* result = reinterpret_cast<int const*>(reply.data()); auto const* result = reinterpret_cast<int const*>(reply.data());
int expected[] = {3, 6, 9, 12, 15}; int expected[] = {3, 6, 9, 12, 15};
for (auto i = 0; i < 5; i++) { for (auto i = 0; i < 5; i++) {

View File

@ -22,6 +22,7 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None: def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
rabit_env = [ rabit_env = [
'xgboost_communicator=federated',
f'federated_server_address=localhost:{port}', f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}', f'federated_world_size={world_size}',
f'federated_rank={rank}' f'federated_rank={rank}'

View File

@ -0,0 +1,39 @@
import multiprocessing
import socket
import sys
import numpy as np
import pytest
import xgboost as xgb
from xgboost import RabitTracker
from xgboost import collective
# Skip every test in this module when running on Windows.
if sys.platform.startswith("win"):
    pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
def run_rabit_worker(rabit_env, world_size):
    """Worker body: join the communicator group and check the basic collective calls.

    ``rabit_env`` is passed as keyword arguments to ``CommunicatorContext``; ``world_size`` is
    the number of workers expected in the group.
    """
    coll = xgb.collective
    with coll.CommunicatorContext(**rabit_env):
        assert coll.get_world_size() == world_size
        assert coll.is_distributed()
        assert coll.get_processor_name() == socket.gethostname()

        # Broadcast from rank 0: every worker must see the same payload.
        received = coll.broadcast('test1234', 0)
        assert str(received) == 'test1234'

        # Each worker contributes [1, 2, 3]; the expected sum matches two workers.
        summed = coll.allreduce(np.asarray([1, 2, 3]), coll.Op.SUM)
        assert np.array_equal(summed, np.asarray([2, 4, 6]))
def test_rabit_communicator():
    """Start a Rabit tracker and two worker processes, then require every worker to exit cleanly."""
    world_size = 2
    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
    tracker.start(world_size)

    workers = []
    rank = 0
    while rank < world_size:
        process = multiprocessing.Process(
            target=run_rabit_worker,
            args=(tracker.worker_envs(), world_size),
        )
        workers.append(process)
        process.start()
        rank += 1

    for process in workers:
        process.join()
        assert process.exitcode == 0