Common interface for collective communication (#8057)
* implement broadcast for federated communicator * implement allreduce * add communicator factory * add device adapter * add device communicator to factory * add rabit communicator * add rabit communicator to the factory * add nccl device communicator * add synchronize to device communicator * add back print and getprocessorname * add python wrapper and c api * clean up types * fix non-gpu build * try to fix ci * fix std::size_t * portable string compare ignore case * c style size_t * fix lint errors * cross platform setenv * fix memory leak * fix lint errors * address review feedback * add python test for rabit communicator * fix failing gtest * use json to configure communicators * fix lint error * get rid of factories * fix cpu build * fix include * fix python import * don't export collective.py yet * skip collective communicator pytest on windows * add review feedback * update documentation * remove mpi communicator type * fix tests * shutdown the communicator separately Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
bc818316f2
commit
a2686543a9
@ -71,6 +71,9 @@
|
|||||||
#include "../src/logging.cc"
|
#include "../src/logging.cc"
|
||||||
#include "../src/global_config.cc"
|
#include "../src/global_config.cc"
|
||||||
|
|
||||||
|
// collective
|
||||||
|
#include "../src/collective/communicator.cc"
|
||||||
|
|
||||||
// common
|
// common
|
||||||
#include "../src/common/common.cc"
|
#include "../src/common/common.cc"
|
||||||
#include "../src/common/column_matrix.cc"
|
#include "../src/common/column_matrix.cc"
|
||||||
|
|||||||
@ -9,10 +9,12 @@
|
|||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
#define XGB_EXTERN_C extern "C"
|
#define XGB_EXTERN_C extern "C"
|
||||||
|
#include <cstddef>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#else
|
#else
|
||||||
#define XGB_EXTERN_C
|
#define XGB_EXTERN_C
|
||||||
|
#include <stddef.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
#endif // __cplusplus
|
#endif // __cplusplus
|
||||||
@ -1386,4 +1388,135 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
|
|||||||
bst_ulong *out_dim,
|
bst_ulong *out_dim,
|
||||||
bst_ulong const **out_shape,
|
bst_ulong const **out_shape,
|
||||||
float const **out_scores);
|
float const **out_scores);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Initialize the collective communicator.
|
||||||
|
*
|
||||||
|
* Currently the communicator API is experimental, function signatures may change in the future
|
||||||
|
* without notice.
|
||||||
|
*
|
||||||
|
* Call this once before using anything.
|
||||||
|
*
|
||||||
|
* The additional configuration is not required. Usually the communicator will detect settings
|
||||||
|
* from environment variables.
|
||||||
|
*
|
||||||
|
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||||
|
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||||
|
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||||
|
* * mpi: Use MPI.
|
||||||
|
* * federated: Use the gRPC interface for Federated Learning.
|
||||||
|
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||||
|
* - rabit_tracker_uri: Hostname of the tracker.
|
||||||
|
* - rabit_tracker_port: Port number of the tracker.
|
||||||
|
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||||
|
* - rabit_world_size: Total number of workers.
|
||||||
|
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||||
|
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||||
|
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||||
|
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||||
|
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||||
|
* - rabit_debug: Enable debugging.
|
||||||
|
* - rabit_timeout: Enable timeout.
|
||||||
|
* - rabit_timeout_sec: Timeout in seconds.
|
||||||
|
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||||
|
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||||
|
* environment variables):
|
||||||
|
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||||
|
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||||
|
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||||
|
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||||
|
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||||
|
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||||
|
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||||
|
* lower case for runtime configuration):
|
||||||
|
* - federated_server_address: Address of the federated server.
|
||||||
|
* - federated_world_size: Number of federated workers.
|
||||||
|
* - federated_rank: Rank of the current worker.
|
||||||
|
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||||
|
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||||
|
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorInit(char const* json_config);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Finalize the collective communicator.
|
||||||
|
*
|
||||||
|
* Call this function after you have finished all jobs.
|
||||||
|
*
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorFinalize(void);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get rank of current process.
|
||||||
|
*
|
||||||
|
* \return Rank of the worker.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorGetRank(void);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get total number of processes.
|
||||||
|
*
|
||||||
|
* \return Total world size.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorGetWorldSize(void);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get if the communicator is distributed.
|
||||||
|
*
|
||||||
|
* \return True if the communicator is distributed.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorIsDistributed(void);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Print the message to the communicator.
|
||||||
|
*
|
||||||
|
* This function can be used to communicate the information of the progress to the user who monitors
|
||||||
|
* the communicator.
|
||||||
|
*
|
||||||
|
* \param message The message to be printed.
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorPrint(char const *message);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Get the name of the processor.
|
||||||
|
*
|
||||||
|
* \param name_str Pointer that receives the returned processor name.
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorGetProcessorName(const char** name_str);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||||
|
*
|
||||||
|
* Example:
|
||||||
|
* int a = 1;
|
||||||
|
* Broadcast(&a, sizeof(a), root);
|
||||||
|
*
|
||||||
|
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||||
|
* \param size Size of the data.
|
||||||
|
* \param root The process rank to broadcast from.
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root);
|
||||||
|
|
||||||
|
/*!
|
||||||
|
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||||
|
*
|
||||||
|
* Example Usage: the following code gives sum of the result
|
||||||
|
* vector<int> data(10);
|
||||||
|
* ...
|
||||||
|
* Allreduce(&data[0], data.size(), DataType::kInt32, Operation::kSum);
|
||||||
|
* ...
|
||||||
|
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||||
|
* \param count Number of elements to be reduced.
|
||||||
|
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||||
|
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||||
|
* \return 0 for success, -1 for failure.
|
||||||
|
*/
|
||||||
|
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int data_type, int op);
|
||||||
|
|
||||||
|
|
||||||
#endif // XGBOOST_C_API_H_
|
#endif // XGBOOST_C_API_H_
|
||||||
|
|||||||
@ -71,7 +71,9 @@ class FederatedEngine : public IEngine {
|
|||||||
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
|
void Allreduce(void *sendrecvbuf, size_t size, mpi::DataType dtype, mpi::OpType op) {
|
||||||
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
auto *buffer = reinterpret_cast<char *>(sendrecvbuf);
|
||||||
std::string const send_buffer(buffer, size);
|
std::string const send_buffer(buffer, size);
|
||||||
auto const receive_buffer = client_->Allreduce(send_buffer, GetDataType(dtype), GetOp(op));
|
auto const receive_buffer =
|
||||||
|
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(dtype),
|
||||||
|
static_cast<xgboost::federated::ReduceOperation>(op));
|
||||||
receive_buffer.copy(buffer, size);
|
receive_buffer.copy(buffer, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,51 +115,6 @@ class FederatedEngine : public IEngine {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
/** @brief Transform mpi::DataType to xgboost::federated::DataType. */
|
|
||||||
static xgboost::federated::DataType GetDataType(mpi::DataType data_type) {
|
|
||||||
switch (data_type) {
|
|
||||||
case mpi::kChar:
|
|
||||||
return xgboost::federated::CHAR;
|
|
||||||
case mpi::kUChar:
|
|
||||||
return xgboost::federated::UCHAR;
|
|
||||||
case mpi::kInt:
|
|
||||||
return xgboost::federated::INT;
|
|
||||||
case mpi::kUInt:
|
|
||||||
return xgboost::federated::UINT;
|
|
||||||
case mpi::kLong:
|
|
||||||
return xgboost::federated::LONG;
|
|
||||||
case mpi::kULong:
|
|
||||||
return xgboost::federated::ULONG;
|
|
||||||
case mpi::kFloat:
|
|
||||||
return xgboost::federated::FLOAT;
|
|
||||||
case mpi::kDouble:
|
|
||||||
return xgboost::federated::DOUBLE;
|
|
||||||
case mpi::kLongLong:
|
|
||||||
return xgboost::federated::LONGLONG;
|
|
||||||
case mpi::kULongLong:
|
|
||||||
return xgboost::federated::ULONGLONG;
|
|
||||||
}
|
|
||||||
utils::Error("unknown mpi::DataType");
|
|
||||||
return xgboost::federated::CHAR;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** @brief Transform mpi::OpType to enum to MPI OP */
|
|
||||||
static xgboost::federated::ReduceOperation GetOp(mpi::OpType op_type) {
|
|
||||||
switch (op_type) {
|
|
||||||
case mpi::kMax:
|
|
||||||
return xgboost::federated::MAX;
|
|
||||||
case mpi::kMin:
|
|
||||||
return xgboost::federated::MIN;
|
|
||||||
case mpi::kSum:
|
|
||||||
return xgboost::federated::SUM;
|
|
||||||
case mpi::kBitwiseOR:
|
|
||||||
utils::Error("Bitwise OR is not supported");
|
|
||||||
return xgboost::federated::MAX;
|
|
||||||
}
|
|
||||||
utils::Error("unknown mpi::OpType");
|
|
||||||
return xgboost::federated::MAX;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetParam(std::string const &name, std::string const &val) {
|
void SetParam(std::string const &name, std::string const &val) {
|
||||||
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
|
if (!strcasecmp(name.c_str(), "FEDERATED_SERVER_ADDRESS")) {
|
||||||
server_address_ = val;
|
server_address_ = val;
|
||||||
|
|||||||
@ -12,16 +12,14 @@ service Federated {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enum DataType {
|
enum DataType {
|
||||||
CHAR = 0;
|
INT8 = 0;
|
||||||
UCHAR = 1;
|
UINT8 = 1;
|
||||||
INT = 2;
|
INT32 = 2;
|
||||||
UINT = 3;
|
UINT32 = 3;
|
||||||
LONG = 4;
|
INT64 = 4;
|
||||||
ULONG = 5;
|
UINT64 = 5;
|
||||||
FLOAT = 6;
|
FLOAT = 6;
|
||||||
DOUBLE = 7;
|
DOUBLE = 7;
|
||||||
LONGLONG = 8;
|
|
||||||
ULONGLONG = 9;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum ReduceOperation {
|
enum ReduceOperation {
|
||||||
|
|||||||
192
plugin/federated/federated_communicator.h
Normal file
192
plugin/federated/federated_communicator.h
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
|
||||||
|
#include "../../src/collective/communicator.h"
|
||||||
|
#include "../../src/common/io.h"
|
||||||
|
#include "federated_client.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A Federated Learning communicator class that handles collective communication.
|
||||||
|
*/
|
||||||
|
class FederatedCommunicator : public Communicator {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Create a new communicator based on JSON configuration.
|
||||||
|
* @param config JSON configuration.
|
||||||
|
* @return Communicator as specified by the JSON configuration.
|
||||||
|
*/
|
||||||
|
static Communicator *Create(Json const &config) {
|
||||||
|
std::string server_address{};
|
||||||
|
int world_size{0};
|
||||||
|
int rank{-1};
|
||||||
|
std::string server_cert{};
|
||||||
|
std::string client_key{};
|
||||||
|
std::string client_cert{};
|
||||||
|
|
||||||
|
// Parse environment variables first.
|
||||||
|
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
|
||||||
|
if (value != nullptr) {
|
||||||
|
server_address = value;
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_WORLD_SIZE");
|
||||||
|
if (value != nullptr) {
|
||||||
|
world_size = std::stoi(value);
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_RANK");
|
||||||
|
if (value != nullptr) {
|
||||||
|
rank = std::stoi(value);
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_SERVER_CERT");
|
||||||
|
if (value != nullptr) {
|
||||||
|
server_cert = value;
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_CLIENT_KEY");
|
||||||
|
if (value != nullptr) {
|
||||||
|
client_key = value;
|
||||||
|
}
|
||||||
|
value = getenv("FEDERATED_CLIENT_CERT");
|
||||||
|
if (value != nullptr) {
|
||||||
|
client_cert = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runtime configuration overrides.
|
||||||
|
auto const &j_server_address = config["federated_server_address"];
|
||||||
|
if (IsA<String const>(j_server_address)) {
|
||||||
|
server_address = get<String const>(j_server_address);
|
||||||
|
}
|
||||||
|
auto const &j_world_size = config["federated_world_size"];
|
||||||
|
if (IsA<Integer const>(j_world_size)) {
|
||||||
|
world_size = static_cast<int>(get<Integer const>(j_world_size));
|
||||||
|
}
|
||||||
|
auto const &j_rank = config["federated_rank"];
|
||||||
|
if (IsA<Integer const>(j_rank)) {
|
||||||
|
rank = static_cast<int>(get<Integer const>(j_rank));
|
||||||
|
}
|
||||||
|
auto const &j_server_cert = config["federated_server_cert"];
|
||||||
|
if (IsA<String const>(j_server_cert)) {
|
||||||
|
server_cert = get<String const>(j_server_cert);
|
||||||
|
}
|
||||||
|
auto const &j_client_key = config["federated_client_key"];
|
||||||
|
if (IsA<String const>(j_client_key)) {
|
||||||
|
client_key = get<String const>(j_client_key);
|
||||||
|
}
|
||||||
|
auto const &j_client_cert = config["federated_client_cert"];
|
||||||
|
if (IsA<String const>(j_client_cert)) {
|
||||||
|
client_cert = get<String const>(j_client_cert);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (server_address.empty()) {
|
||||||
|
LOG(FATAL) << "Federated server address must be set.";
|
||||||
|
}
|
||||||
|
if (world_size == 0) {
|
||||||
|
LOG(FATAL) << "Federated world size must be set.";
|
||||||
|
}
|
||||||
|
if (rank == -1) {
|
||||||
|
LOG(FATAL) << "Federated rank must be set.";
|
||||||
|
}
|
||||||
|
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
|
||||||
|
client_cert);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new federated communicator.
|
||||||
|
*
|
||||||
|
* @param world_size Total number of processes.
|
||||||
|
* @param rank Rank of the current process.
|
||||||
|
* @param server_address Address of the federated server (host:port).
|
||||||
|
* @param server_cert_path Path to the server cert file.
|
||||||
|
* @param client_key_path Path to the client key file.
|
||||||
|
* @param client_cert_path Path to the client cert file.
|
||||||
|
*/
|
||||||
|
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
|
||||||
|
std::string const &server_cert_path, std::string const &client_key_path,
|
||||||
|
std::string const &client_cert_path)
|
||||||
|
: Communicator{world_size, rank} {
|
||||||
|
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
|
||||||
|
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||||
|
} else {
|
||||||
|
client_.reset(new xgboost::federated::FederatedClient(
|
||||||
|
server_address, rank, xgboost::common::ReadAll(server_cert_path),
|
||||||
|
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct an insecure federated communicator without using SSL.
|
||||||
|
* @param world_size Total number of processes.
|
||||||
|
* @param rank Rank of the current process.
|
||||||
|
* @param server_address Address of the federated server (host:port).
|
||||||
|
*/
|
||||||
|
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
|
||||||
|
: Communicator{world_size, rank} {
|
||||||
|
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
~FederatedCommunicator() override { client_.reset(); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Get if the communicator is distributed.
|
||||||
|
* \return True.
|
||||||
|
*/
|
||||||
|
bool IsDistributed() const override { return true; }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Perform in-place allreduce.
|
||||||
|
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||||
|
* \param count Number of elements to be reduced.
|
||||||
|
* \param data_type Enumeration of data type.
|
||||||
|
* \param op Enumeration of operation type.
|
||||||
|
*/
|
||||||
|
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
|
Operation op) override {
|
||||||
|
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
|
||||||
|
count * GetTypeSize(data_type));
|
||||||
|
auto const received =
|
||||||
|
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
|
||||||
|
static_cast<xgboost::federated::ReduceOperation>(op));
|
||||||
|
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Broadcast a memory region to all others from root.
|
||||||
|
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||||
|
* \param size Size of the data.
|
||||||
|
* \param root The process rank to broadcast from.
|
||||||
|
*/
|
||||||
|
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||||
|
if (GetWorldSize() == 1) return;
|
||||||
|
if (GetRank() == root) {
|
||||||
|
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
|
||||||
|
client_->Broadcast(send_buffer, root);
|
||||||
|
} else {
|
||||||
|
auto const received = client_->Broadcast("", root);
|
||||||
|
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Get the name of the processor.
|
||||||
|
* \return Name of the processor.
|
||||||
|
*/
|
||||||
|
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Print the message to the communicator.
|
||||||
|
* \param message The message to be printed.
|
||||||
|
*/
|
||||||
|
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Shutdown() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||||
|
};
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -10,6 +10,8 @@
|
|||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
|
#include "../../src/common/io.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace federated {
|
namespace federated {
|
||||||
|
|
||||||
@ -71,32 +73,35 @@ class AllreduceFunctor {
|
|||||||
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
|
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
|
||||||
ReduceOperation reduce_operation) const {
|
ReduceOperation reduce_operation) const {
|
||||||
switch (data_type) {
|
switch (data_type) {
|
||||||
case DataType::CHAR:
|
case DataType::INT8:
|
||||||
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
|
Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
|
||||||
|
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
|
||||||
reduce_operation);
|
reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::UCHAR:
|
case DataType::UINT8:
|
||||||
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
|
Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
|
||||||
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
|
reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
|
||||||
reduce_operation);
|
reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::INT:
|
case DataType::INT32:
|
||||||
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
|
Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
|
||||||
buffer.size() / sizeof(int), reduce_operation);
|
reinterpret_cast<std::int32_t const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::UINT:
|
case DataType::UINT32:
|
||||||
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
|
Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
|
||||||
reinterpret_cast<unsigned int const*>(input.data()),
|
reinterpret_cast<std::uint32_t const*>(input.data()),
|
||||||
buffer.size() / sizeof(unsigned int), reduce_operation);
|
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::LONG:
|
case DataType::INT64:
|
||||||
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
|
Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
|
||||||
buffer.size() / sizeof(long), reduce_operation);
|
reinterpret_cast<std::int64_t const*>(input.data()),
|
||||||
|
buffer.size() / sizeof(std::int64_t), reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::ULONG:
|
case DataType::UINT64:
|
||||||
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
|
Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
|
||||||
reinterpret_cast<unsigned long const*>(input.data()),
|
reinterpret_cast<std::uint64_t const*>(input.data()),
|
||||||
buffer.size() / sizeof(unsigned long), reduce_operation);
|
buffer.size() / sizeof(std::uint64_t), reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::FLOAT:
|
case DataType::FLOAT:
|
||||||
Accumulate(reinterpret_cast<float*>(&buffer[0]),
|
Accumulate(reinterpret_cast<float*>(&buffer[0]),
|
||||||
@ -108,16 +113,6 @@ class AllreduceFunctor {
|
|||||||
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
|
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
|
||||||
reduce_operation);
|
reduce_operation);
|
||||||
break;
|
break;
|
||||||
case DataType::LONGLONG:
|
|
||||||
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
|
|
||||||
reinterpret_cast<long long const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(long long), reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::ULONGLONG:
|
|
||||||
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
|
|
||||||
reinterpret_cast<unsigned long long const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(unsigned long long), reduce_operation);
|
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
throw std::invalid_argument("Invalid data type");
|
throw std::invalid_argument("Invalid data type");
|
||||||
}
|
}
|
||||||
@ -201,13 +196,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
|
|||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string ReadFile(char const* path) {
|
|
||||||
auto stream = std::ifstream(path);
|
|
||||||
std::ostringstream out;
|
|
||||||
out << stream.rdbuf();
|
|
||||||
return out.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||||
char const* client_cert_file) {
|
char const* client_cert_file) {
|
||||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||||
@ -216,10 +204,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
|||||||
grpc::ServerBuilder builder;
|
grpc::ServerBuilder builder;
|
||||||
auto options =
|
auto options =
|
||||||
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||||
options.pem_root_certs = ReadFile(client_cert_file);
|
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
|
||||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||||
key.private_key = ReadFile(server_key_file);
|
key.private_key = xgboost::common::ReadAll(server_key_file);
|
||||||
key.cert_chain = ReadFile(server_cert_file);
|
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
|
||||||
options.pem_key_cert_pairs.push_back(key);
|
options.pem_key_cert_pairs.push_back(key);
|
||||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||||
|
|||||||
243
python-package/xgboost/collective.py
Normal file
243
python-package/xgboost/collective.py
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
"""XGBoost collective communication related API."""
|
||||||
|
import ctypes
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
from enum import IntEnum, unique
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ._typing import _T
|
||||||
|
from .core import _LIB, _check_call, c_str, py_str, from_pystr_to_cstr
|
||||||
|
|
||||||
|
LOGGER = logging.getLogger("[xgboost.collective]")
|
||||||
|
|
||||||
|
|
||||||
|
def init(**args: Any) -> None:
|
||||||
|
"""Initialize the collective library with arguments.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
args: Dict[str, Any]
|
||||||
|
Keyword arguments representing the parameters and their values.
|
||||||
|
|
||||||
|
Accepted parameters:
|
||||||
|
- xgboost_communicator: The type of the communicator. Can be set as an environment
|
||||||
|
variable.
|
||||||
|
* rabit: Use Rabit. This is the default if the type is unspecified.
|
||||||
|
* federated: Use the gRPC interface for Federated Learning.
|
||||||
|
Only applicable to the Rabit communicator (these are case-sensitive):
|
||||||
|
-- rabit_tracker_uri: Hostname of the tracker.
|
||||||
|
-- rabit_tracker_port: Port number of the tracker.
|
||||||
|
-- rabit_task_id: ID of the current task, can be used to obtain deterministic rank
|
||||||
|
assignment.
|
||||||
|
-- rabit_world_size: Total number of workers.
|
||||||
|
-- rabit_hadoop_mode: Enable Hadoop support.
|
||||||
|
-- rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||||
|
-- rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||||
|
-- rabit_reduce_buffer: Size of the reduce buffer.
|
||||||
|
-- rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||||
|
-- rabit_debug: Enable debugging.
|
||||||
|
-- rabit_timeout: Enable timeout.
|
||||||
|
-- rabit_timeout_sec: Timeout in seconds.
|
||||||
|
-- rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||||
|
Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||||
|
environment variables):
|
||||||
|
-- DMLC_TRACKER_URI: Hostname of the tracker.
|
||||||
|
-- DMLC_TRACKER_PORT: Port number of the tracker.
|
||||||
|
-- DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank
|
||||||
|
assignment.
|
||||||
|
-- DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||||
|
-- DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||||
|
-- DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||||
|
Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||||
|
lower case for runtime configuration):
|
||||||
|
-- federated_server_address: Address of the federated server.
|
||||||
|
-- federated_world_size: Number of federated workers.
|
||||||
|
-- federated_rank: Rank of the current worker.
|
||||||
|
-- federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||||
|
-- federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||||
|
-- federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||||
|
"""
|
||||||
|
config = from_pystr_to_cstr(json.dumps(args))
|
||||||
|
_check_call(_LIB.XGCommunicatorInit(config))
|
||||||
|
|
||||||
|
|
||||||
|
def finalize() -> None:
|
||||||
|
"""Finalize the communicator."""
|
||||||
|
_check_call(_LIB.XGCommunicatorFinalize())
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank() -> int:
|
||||||
|
"""Get rank of current process.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
rank : int
|
||||||
|
Rank of current process.
|
||||||
|
"""
|
||||||
|
ret = _LIB.XGCommunicatorGetRank()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size() -> int:
|
||||||
|
"""Get total number of workers.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
n : int
|
||||||
|
Total number of processes.
|
||||||
|
"""
|
||||||
|
ret = _LIB.XGCommunicatorGetWorldSize()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def is_distributed() -> int:
|
||||||
|
"""If the collective communicator is distributed."""
|
||||||
|
is_dist = _LIB.XGCommunicatorIsDistributed()
|
||||||
|
return is_dist
|
||||||
|
|
||||||
|
|
||||||
|
def communicator_print(msg: Any) -> None:
    """Print a message through the communicator.

    Use this to report progress information to the communicator.  When the
    communicator is not distributed the message is stripped of surrounding
    whitespace and written with the builtin ``print`` instead.

    Parameters
    ----------
    msg : str
        The message to print.  Non-string values are converted with ``str``.
    """
    text = msg if isinstance(msg, str) else str(msg)
    if _LIB.XGCommunicatorIsDistributed() == 0:
        # Not distributed: print locally and flush immediately.
        print(text.strip(), flush=True)
        return
    _check_call(_LIB.XGCommunicatorPrint(c_str(text)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_processor_name() -> str:
    """Return the processor (host) name reported by the communicator.

    Returns
    -------
    name : str
        The name of the processor (host).
    """
    out = ctypes.c_char_p()
    _check_call(_LIB.XGCommunicatorGetProcessorName(ctypes.byref(out)))
    raw = out.value
    # An empty or missing name is treated as a programming error.
    assert raw
    return py_str(raw)
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast(data: _T, root: int) -> _T:
    """Broadcast a picklable object from one node to all other nodes.

    The object is pickled on the root, its byte length is broadcast first, and
    the pickled payload is broadcast second.

    Parameters
    ----------
    data : any type that can be pickled
        Input data.  If the current rank does not equal ``root`` this can be
        None; it is replaced by the object received from the root.
    root : int
        Rank of the node to broadcast data from.

    Returns
    -------
    object : same type as the data passed on the root
        The result of the broadcast.
    """
    rank = get_rank()
    length = ctypes.c_ulong()
    if root == rank:
        assert data is not None, 'need to pass in data when broadcasting'
        s = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        length.value = len(s)

    # First broadcast: the payload length.  Sizes are wrapped in c_size_t
    # because the C API takes ``size_t`` and ctypes would otherwise marshal a
    # plain Python int as a C ``int``.
    _check_call(_LIB.XGCommunicatorBroadcast(
        ctypes.byref(length),
        ctypes.c_size_t(ctypes.sizeof(ctypes.c_ulong)),
        root))

    n_bytes = ctypes.c_size_t(length.value)
    if root != rank:
        dptr = (ctypes.c_char * length.value)()
        # Second broadcast: receive the pickled payload.
        _check_call(_LIB.XGCommunicatorBroadcast(
            ctypes.cast(dptr, ctypes.c_void_p), n_bytes, root))
        # NOTE: this unpickles bytes received from a peer; pickle is not safe
        # against untrusted senders, so only use this among trusted workers.
        data = pickle.loads(dptr.raw)
    else:
        # Second broadcast: send the pickled payload.
        _check_call(_LIB.XGCommunicatorBroadcast(
            ctypes.cast(ctypes.c_char_p(s), ctypes.c_void_p), n_bytes, root))
    return data
|
||||||
|
|
||||||
|
|
||||||
|
# Integer code for each supported numpy dtype; the code is the position of the
# dtype name in the tuple below (int8 -> 0 ... float64 -> 7).
DTYPE_ENUM__ = {
    np.dtype(type_name): code
    for code, type_name in enumerate(
        (
            'int8',
            'uint8',
            'int32',
            'uint32',
            'int64',
            'uint64',
            'float32',
            'float64',
        )
    )
}
|
||||||
|
|
||||||
|
|
||||||
|
@unique
class Op(IntEnum):
    """Reduction operators accepted by ``allreduce``.

    The integer values are passed to the native library as-is.
    """

    MAX = 0  # element-wise maximum
    MIN = 1  # element-wise minimum
    SUM = 2  # element-wise sum
|
||||||
|
|
||||||
|
|
||||||
|
def allreduce(  # pylint:disable=invalid-name
    data: np.ndarray, op: Op
) -> np.ndarray:
    """Perform allreduce, return the result.

    Parameters
    ----------
    data :
        Input data.  It is never modified; the reduction runs on a copy.
    op :
        Reduction operator.

    Returns
    -------
    result :
        The result of allreduce as a flattened (1-D, C-order) array with the
        same dtype and number of elements as ``data``.

    Raises
    ------
    TypeError
        If ``data`` is not a ``numpy.ndarray``.
    Exception
        If the dtype of ``data`` has no entry in ``DTYPE_ENUM__``.

    Notes
    -----
    This function is not thread-safe.
    """
    if not isinstance(data, np.ndarray):
        raise TypeError('allreduce only takes in numpy.ndarray')
    if data.dtype not in DTYPE_ENUM__:
        raise Exception(f"data type {data.dtype} not supported")
    # ``flatten`` always returns a fresh contiguous copy, so the in-place
    # reduction performed by the native library cannot alias the caller's
    # array (the previous ``ravel`` + ``base`` comparison let it through for
    # arrays that own their memory).
    buf = data.flatten()
    # The native function takes (buffer, size_t count, int dtype, int op); the
    # count is wrapped so ctypes does not marshal it as a C ``int``.
    _check_call(_LIB.XGCommunicatorAllreduce(buf.ctypes.data_as(ctypes.c_void_p),
                                             ctypes.c_size_t(buf.size),
                                             DTYPE_ENUM__[buf.dtype],
                                             int(op)))
    return buf
|
||||||
|
|
||||||
|
|
||||||
|
class CommunicatorContext:
    """Context manager that initializes the collective communicator on entry
    and finalizes it on exit."""

    def __init__(self, **args: Any) -> None:
        # Keyword arguments are stored untouched and forwarded to ``init``.
        self.args = args

    def __enter__(self) -> None:
        init(**self.args)
        distributed = is_distributed()
        # Entering the context is only valid for a distributed communicator.
        assert distributed
        LOGGER.debug("-------------- communicator say hello ------------------")

    def __exit__(self, *args: List) -> None:
        # Exception details are ignored; the communicator is always finalized.
        finalize()
        LOGGER.debug("--------------- communicator say bye ------------------")
|
||||||
@ -22,6 +22,7 @@
|
|||||||
|
|
||||||
#include "c_api_error.h"
|
#include "c_api_error.h"
|
||||||
#include "c_api_utils.h"
|
#include "c_api_utils.h"
|
||||||
|
#include "../collective/communicator.h"
|
||||||
#include "../common/io.h"
|
#include "../common/io.h"
|
||||||
#include "../common/charconv.h"
|
#include "../common/charconv.h"
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
@ -1370,6 +1371,62 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
|
|||||||
API_END();
|
API_END();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using xgboost::collective::Communicator;
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorInit(char const* json_config) {
  API_BEGIN();
  // Parse the caller-supplied JSON text, then hand it to the communicator.
  auto const parsed = Json::Load(StringView{json_config});
  Communicator::Init(parsed);
  API_END();
}
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorFinalize(void) {
  API_BEGIN();
  // Shut down and release the current communicator.
  Communicator::Finalize();
  API_END();
}
|
||||||
|
|
||||||
|
/*
 * Returns the rank of the current process, or 0 when no communicator has been
 * initialized.  Communicator::Get() is null before Init(), and this function
 * has no exception guard, so the null case is handled explicitly.
 */
XGB_DLL int XGCommunicatorGetRank(void) {
  auto const *comm = Communicator::Get();
  return comm != nullptr ? comm->GetRank() : 0;
}
|
||||||
|
|
||||||
|
/*
 * Returns the number of processes in the group, or 1 when no communicator has
 * been initialized.  Communicator::Get() is null before Init(), and this
 * function has no exception guard, so the null case is handled explicitly.
 */
XGB_DLL int XGCommunicatorGetWorldSize(void) {
  auto const *comm = Communicator::Get();
  return comm != nullptr ? comm->GetWorldSize() : 1;
}
|
||||||
|
|
||||||
|
/*
 * Returns non-zero when the communicator runs in distributed mode, and 0 when
 * it does not or when no communicator has been initialized.
 * Communicator::Get() is null before Init(); calling the virtual
 * IsDistributed() through it would crash, so the null case is handled here.
 */
XGB_DLL int XGCommunicatorIsDistributed(void) {
  auto const *comm = Communicator::Get();
  return (comm != nullptr && comm->IsDistributed()) ? 1 : 0;
}
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorPrint(char const *message) {
  API_BEGIN();
  // Forward the message to the current communicator.
  auto *comm = Communicator::Get();
  comm->Print(message);
  API_END();
}
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
  API_BEGIN();
  // The name is stored in the thread-local store so the returned pointer
  // stays valid after this function returns.
  auto *store = GlobalConfigAPIThreadLocalStore::Get();
  store->ret_str = Communicator::Get()->GetProcessorName();
  *name_str = store->ret_str.c_str();
  API_END();
}
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
  API_BEGIN();
  // `size` is the number of bytes in the buffer; `root` is the sending rank.
  auto *comm = Communicator::Get();
  comm->Broadcast(send_receive_buffer, size, root);
  API_END();
}
|
||||||
|
|
||||||
|
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
                                    int enum_op) {
  API_BEGIN();
  // Convert the integer codes from the C interface into the typed enums.
  auto const dtype = static_cast<xgboost::collective::DataType>(enum_dtype);
  auto const operation = static_cast<xgboost::collective::Operation>(enum_op);
  Communicator::Get()->AllReduce(send_receive_buffer, count, dtype, operation);
  API_END();
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_FEDERATED)
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
|
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
|
||||||
char const *server_cert_path, char const *client_cert_path) {
|
char const *server_cert_path, char const *client_cert_path) {
|
||||||
|
|||||||
59
src/collective/communicator.cc
Normal file
59
src/collective/communicator.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "communicator.h"
|
||||||
|
|
||||||
|
#include "rabit_communicator.h"
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_FEDERATED)
|
||||||
|
#include "../../plugin/federated/federated_communicator.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
|
||||||
|
thread_local CommunicatorType Communicator::type_{};
|
||||||
|
|
||||||
|
/*
 * Creates the thread-local communicator.
 *
 * The communicator type is taken from the runtime configuration when it names
 * one, otherwise from the environment, otherwise it defaults to Rabit.
 * Calling Init() while a communicator already exists is a fatal error.
 */
void Communicator::Init(Json const& config) {
  if (communicator_) {
    LOG(FATAL) << "Communicator can only be initialized once.";
  }

  // The environment is read first, then the configuration, which wins.
  auto const from_env = GetTypeFromEnv();
  auto const from_config = GetTypeFromConfig(config);
  auto selected = (from_config != CommunicatorType::kUnknown) ? from_config : from_env;
  if (selected == CommunicatorType::kUnknown) {
    // Default to Rabit if unspecified.
    selected = CommunicatorType::kRabit;
  }
  type_ = selected;

  if (selected == CommunicatorType::kRabit) {
    communicator_.reset(RabitCommunicator::Create(config));
  } else if (selected == CommunicatorType::kFederated) {
#if defined(XGBOOST_USE_FEDERATED)
    communicator_.reset(FederatedCommunicator::Create(config));
#else
    LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
  } else {
    LOG(FATAL) << "Unknown communicator type.";
  }
}
|
||||||
|
|
||||||
|
#ifndef XGBOOST_USE_CUDA
|
||||||
|
void Communicator::Finalize() {
|
||||||
|
communicator_->Shutdown();
|
||||||
|
communicator_.reset(nullptr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
41
src/collective/communicator.cu
Normal file
41
src/collective/communicator.cu
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "communicator.h"
|
||||||
|
#include "device_communicator.cuh"
|
||||||
|
#include "device_communicator_adapter.cuh"
|
||||||
|
#ifdef XGBOOST_USE_NCCL
|
||||||
|
#include "nccl_device_communicator.cuh"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
thread_local int Communicator::device_ordinal_{-1};
|
||||||
|
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||||
|
|
||||||
|
/*
 * Shuts down the thread-local communicator and drops the cached device
 * communicator (CUDA build).
 *
 * The shutdown step is skipped when no communicator exists, so calling
 * Finalize() without a prior Init(), or calling it twice, no longer
 * dereferences a null pointer.  The device state is reset in either case.
 */
void Communicator::Finalize() {
  if (communicator_) {
    communicator_->Shutdown();
    communicator_.reset(nullptr);
  }
  device_ordinal_ = -1;
  device_communicator_.reset(nullptr);
}
|
||||||
|
|
||||||
|
/*
 * Returns the device communicator for `device_ordinal`.
 *
 * The instance is cached per thread and rebuilt whenever a different ordinal
 * is requested.  With NCCL compiled in, the NCCL implementation is used
 * unless the communicator type is federated; in every other case the
 * host-staging adapter is used.
 */
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
  bool const cached = device_communicator_ && device_ordinal_ == device_ordinal;
  if (cached) {
    return device_communicator_.get();
  }

  device_ordinal_ = device_ordinal;
#ifdef XGBOOST_USE_NCCL
  if (type_ == CommunicatorType::kFederated) {
    device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
  } else {
    device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
  }
#else
  device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
#endif
  return device_communicator_.get();
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
218
src/collective/communicator.h
Normal file
218
src/collective/communicator.h
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
 * @brief Element types supported by the collective operations.
 *
 * The numeric values are part of the interface: callers pass them as plain
 * integers that are cast to this enum.
 */
enum class DataType {
  kInt8 = 0,    // signed 8-bit integer
  kUInt8 = 1,   // unsigned 8-bit integer
  kInt32 = 2,   // signed 32-bit integer
  kUInt32 = 3,  // unsigned 32-bit integer
  kInt64 = 4,   // signed 64-bit integer
  kUInt64 = 5,  // unsigned 64-bit integer
  kFloat = 6,   // float
  kDouble = 7   // double
};
|
||||||
|
|
||||||
|
/** @brief Get the size of the data type. */
|
||||||
|
inline std::size_t GetTypeSize(DataType data_type) {
|
||||||
|
std::size_t size{0};
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::kInt8:
|
||||||
|
size = sizeof(std::int8_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt8:
|
||||||
|
size = sizeof(std::uint8_t);
|
||||||
|
break;
|
||||||
|
case DataType::kInt32:
|
||||||
|
size = sizeof(std::int32_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt32:
|
||||||
|
size = sizeof(std::uint32_t);
|
||||||
|
break;
|
||||||
|
case DataType::kInt64:
|
||||||
|
size = sizeof(std::int64_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt64:
|
||||||
|
size = sizeof(std::uint64_t);
|
||||||
|
break;
|
||||||
|
case DataType::kFloat:
|
||||||
|
size = sizeof(float);
|
||||||
|
break;
|
||||||
|
case DataType::kDouble:
|
||||||
|
size = sizeof(double);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown data type.";
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Reduction operators supported by AllReduce.
 *
 * The numeric values are part of the interface: callers pass them as plain
 * integers that are cast to this enum.
 */
enum class Operation {
  kMax = 0,  // maximum
  kMin = 1,  // minimum
  kSum = 2   // sum
};
|
||||||
|
|
||||||
|
class DeviceCommunicator;
|
||||||
|
|
||||||
|
/** @brief Communicator back ends; kUnknown means no type has been specified. */
enum class CommunicatorType { kUnknown, kRabit, kFederated };
|
||||||
|
|
||||||
|
/**
 * \brief Compare two C strings while ignoring letter case.
 *
 * Wraps `_stricmp` on MSVC and `strcasecmp` elsewhere; the result is zero
 * when the strings match and non-zero otherwise.
 */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#if defined(_MSC_VER)
  int const order = _stricmp(s1, s2);
#else   // defined(_MSC_VER)
  int const order = strcasecmp(s1, s2);
#endif  // defined(_MSC_VER)
  return order;
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A communicator class that handles collective communication.
|
||||||
|
*/
|
||||||
|
class Communicator {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Initialize the communicator. This can only be done once.
|
||||||
|
*
|
||||||
|
* @param config JSON configuration for the communicator.
|
||||||
|
*/
|
||||||
|
static void Init(Json const &config);
|
||||||
|
|
||||||
|
/** @brief Finalize the communicator. */
|
||||||
|
static void Finalize();
|
||||||
|
|
||||||
|
/** @brief Get the communicator instance. */
|
||||||
|
static Communicator *Get() { return communicator_.get(); }
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
/**
|
||||||
|
* @brief Get the device communicator.
|
||||||
|
*
|
||||||
|
* @param device_ordinal ID of the device.
|
||||||
|
* @return An instance of device communicator.
|
||||||
|
*/
|
||||||
|
static DeviceCommunicator *GetDevice(int device_ordinal);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
virtual ~Communicator() = default;
|
||||||
|
|
||||||
|
/** @brief Get the total number of processes. */
|
||||||
|
int GetWorldSize() const { return world_size_; }
|
||||||
|
|
||||||
|
/** @brief Get the rank of the current processes. */
|
||||||
|
int GetRank() const { return rank_; }
|
||||||
|
|
||||||
|
/** @brief Whether the communicator is running in distributed mode. */
|
||||||
|
virtual bool IsDistributed() const = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||||
|
*
|
||||||
|
* @param send_receive_buffer Buffer storing the data.
|
||||||
|
* @param count Number of elements in the buffer.
|
||||||
|
* @param data_type Data type stored in the buffer.
|
||||||
|
* @param op The operation to perform.
|
||||||
|
*/
|
||||||
|
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
|
Operation op) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
|
||||||
|
* group.
|
||||||
|
*
|
||||||
|
* @param send_receive_buffer Buffer storing the data.
|
||||||
|
* @param size Size of the data in bytes.
|
||||||
|
* @param root Rank of broadcast root.
|
||||||
|
*/
|
||||||
|
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Gets the name of the processor.
|
||||||
|
*/
|
||||||
|
virtual std::string GetProcessorName() = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Prints the message.
|
||||||
|
*/
|
||||||
|
virtual void Print(std::string const &message) = 0;
|
||||||
|
|
||||||
|
/** @brief Get the communicator type from environment variables. Visible for testing. */
|
||||||
|
static CommunicatorType GetTypeFromEnv() {
|
||||||
|
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
|
||||||
|
if (env != nullptr) {
|
||||||
|
return StringToType(env);
|
||||||
|
} else {
|
||||||
|
return CommunicatorType::kUnknown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
|
||||||
|
static CommunicatorType GetTypeFromConfig(Json const &config) {
|
||||||
|
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
|
||||||
|
if (IsA<String const>(j_upper)) {
|
||||||
|
return StringToType(get<String const>(j_upper).c_str());
|
||||||
|
}
|
||||||
|
auto const &j_lower = config["xgboost_communicator"];
|
||||||
|
if (IsA<String const>(j_lower)) {
|
||||||
|
return StringToType(get<String const>(j_lower).c_str());
|
||||||
|
}
|
||||||
|
return CommunicatorType::kUnknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/**
|
||||||
|
* @brief Construct a new communicator.
|
||||||
|
*
|
||||||
|
* @param world_size Total number of processes.
|
||||||
|
* @param rank Rank of the current process.
|
||||||
|
*/
|
||||||
|
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
|
||||||
|
if (world_size < 1) {
|
||||||
|
LOG(FATAL) << "World size " << world_size << " is less than 1.";
|
||||||
|
}
|
||||||
|
if (rank < 0) {
|
||||||
|
LOG(FATAL) << "Rank " << rank << " is less than 0.";
|
||||||
|
}
|
||||||
|
if (rank >= world_size) {
|
||||||
|
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Shuts down the communicator.
|
||||||
|
*/
|
||||||
|
virtual void Shutdown() = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
static CommunicatorType StringToType(char const *str) {
|
||||||
|
CommunicatorType result = CommunicatorType::kUnknown;
|
||||||
|
if (!CompareStringsCaseInsensitive("rabit", str)) {
|
||||||
|
result = CommunicatorType::kRabit;
|
||||||
|
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||||
|
result = CommunicatorType::kFederated;
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "Unknown communicator type " << str;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||||
|
static thread_local CommunicatorType type_;
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
static thread_local int device_ordinal_;
|
||||||
|
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
int const world_size_;
|
||||||
|
int const rank_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
42
src/collective/device_communicator.cuh
Normal file
42
src/collective/device_communicator.cuh
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Collective communicator for device buffers.
|
||||||
|
*/
|
||||||
|
class DeviceCommunicator {
|
||||||
|
public:
|
||||||
|
virtual ~DeviceCommunicator() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||||
|
* @param send_receive_buffer Buffer storing the data.
|
||||||
|
* @param count Number of elements in the buffer.
|
||||||
|
*/
|
||||||
|
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Gather variable-length values from all processes.
|
||||||
|
* @param send_buffer Buffer storing the input data.
|
||||||
|
* @param length_bytes Length in bytes of the input data.
|
||||||
|
* @param segments Size of each segment.
|
||||||
|
* @param receive_buffer Buffer storing the output data.
|
||||||
|
*/
|
||||||
|
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||||
|
std::vector<size_t> *segments,
|
||||||
|
dh::caching_device_vector<char> *receive_buffer) = 0;
|
||||||
|
|
||||||
|
/** @brief Synchronize device operations. */
|
||||||
|
virtual void Synchronize() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
76
src/collective/device_communicator_adapter.cuh
Normal file
76
src/collective/device_communicator_adapter.cuh
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "communicator.h"
|
||||||
|
#include "device_communicator.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||||
|
public:
|
||||||
|
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
|
||||||
|
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||||
|
if (device_ordinal_ < 0) {
|
||||||
|
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||||
|
}
|
||||||
|
if (communicator_ == nullptr) {
|
||||||
|
LOG(FATAL) << "Communicator cannot be null.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
~DeviceCommunicatorAdapter() override = default;
|
||||||
|
|
||||||
|
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
auto size = count * sizeof(double);
|
||||||
|
host_buffer_.reserve(size);
|
||||||
|
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||||
|
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
|
||||||
|
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||||
|
dh::caching_device_vector<char> *receive_buffer) override {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
int const world_size = communicator_->GetWorldSize();
|
||||||
|
int const rank = communicator_->GetRank();
|
||||||
|
|
||||||
|
segments->clear();
|
||||||
|
segments->resize(world_size, 0);
|
||||||
|
segments->at(rank) = length_bytes;
|
||||||
|
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||||
|
Operation::kMax);
|
||||||
|
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||||
|
receive_buffer->resize(total_bytes);
|
||||||
|
|
||||||
|
host_buffer_.reserve(total_bytes);
|
||||||
|
size_t offset = 0;
|
||||||
|
for (int32_t i = 0; i < world_size; ++i) {
|
||||||
|
size_t as_bytes = segments->at(i);
|
||||||
|
if (i == rank) {
|
||||||
|
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||||
|
cudaMemcpyDefault));
|
||||||
|
}
|
||||||
|
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||||
|
offset += as_bytes;
|
||||||
|
}
|
||||||
|
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||||
|
cudaMemcpyDefault));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Synchronize() override {
|
||||||
|
// Noop.
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int const device_ordinal_;
|
||||||
|
Communicator *communicator_;
|
||||||
|
/// Host buffer used to call communicator functions.
|
||||||
|
std::vector<char> host_buffer_{};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
149
src/collective/nccl_device_communicator.cuh
Normal file
149
src/collective/nccl_device_communicator.cuh
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "../common/device_helpers.cuh"
|
||||||
|
#include "communicator.h"
|
||||||
|
#include "device_communicator.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||||
|
public:
|
||||||
|
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
|
||||||
|
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||||
|
if (device_ordinal_ < 0) {
|
||||||
|
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||||
|
}
|
||||||
|
if (communicator_ == nullptr) {
|
||||||
|
LOG(FATAL) << "Communicator cannot be null.";
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t const rank = communicator_->GetRank();
|
||||||
|
int32_t const world = communicator_->GetWorldSize();
|
||||||
|
|
||||||
|
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||||
|
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||||
|
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||||
|
GetCudaUUID(s_this_uuid);
|
||||||
|
|
||||||
|
// TODO(rongou): replace this with allgather.
|
||||||
|
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||||
|
|
||||||
|
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
|
||||||
|
size_t j = 0;
|
||||||
|
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||||
|
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||||
|
j++;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto iter = std::unique(converted.begin(), converted.end());
|
||||||
|
auto n_uniques = std::distance(converted.begin(), iter);
|
||||||
|
|
||||||
|
CHECK_EQ(n_uniques, world)
|
||||||
|
<< "Multiple processes within communication group running on same CUDA "
|
||||||
|
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||||
|
|
||||||
|
nccl_unique_id_ = GetUniqueId();
|
||||||
|
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||||
|
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||||
|
}
|
||||||
|
|
||||||
|
~NcclDeviceCommunicator() override {
|
||||||
|
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||||
|
ncclCommDestroy(nccl_comm_);
|
||||||
|
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||||
|
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||||
|
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||||
|
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
|
||||||
|
ncclSum, nccl_comm_, cuda_stream_));
|
||||||
|
allreduce_bytes_ += count * sizeof(double);
|
||||||
|
allreduce_calls_ += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||||
|
dh::caching_device_vector<char> *receive_buffer) override {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
int const world_size = communicator_->GetWorldSize();
|
||||||
|
int const rank = communicator_->GetRank();
|
||||||
|
|
||||||
|
segments->clear();
|
||||||
|
segments->resize(world_size, 0);
|
||||||
|
segments->at(rank) = length_bytes;
|
||||||
|
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||||
|
Operation::kMax);
|
||||||
|
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||||
|
receive_buffer->resize(total_bytes);
|
||||||
|
|
||||||
|
size_t offset = 0;
|
||||||
|
dh::safe_nccl(ncclGroupStart());
|
||||||
|
for (int32_t i = 0; i < world_size; ++i) {
|
||||||
|
size_t as_bytes = segments->at(i);
|
||||||
|
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||||
|
ncclChar, i, nccl_comm_, cuda_stream_));
|
||||||
|
offset += as_bytes;
|
||||||
|
}
|
||||||
|
dh::safe_nccl(ncclGroupEnd());
|
||||||
|
}
|
||||||
|
|
||||||
|
void Synchronize() override {
|
||||||
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static constexpr std::size_t kUuidLength =
|
||||||
|
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||||
|
|
||||||
|
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||||
|
cudaDeviceProp prob{};
|
||||||
|
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||||
|
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
|
||||||
|
std::stringstream ss;
|
||||||
|
for (auto v : uuid) {
|
||||||
|
ss << std::hex << v;
|
||||||
|
}
|
||||||
|
return ss.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \fn ncclUniqueId GetUniqueId()
|
||||||
|
*
|
||||||
|
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||||
|
* communication
|
||||||
|
*
|
||||||
|
* \return the Unique ID
|
||||||
|
*/
|
||||||
|
ncclUniqueId GetUniqueId() {
|
||||||
|
static const int kRootRank = 0;
|
||||||
|
ncclUniqueId id;
|
||||||
|
if (communicator_->GetRank() == kRootRank) {
|
||||||
|
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||||
|
}
|
||||||
|
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
|
||||||
|
static_cast<int>(kRootRank));
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
int const device_ordinal_;
|
||||||
|
Communicator *communicator_;
|
||||||
|
ncclComm_t nccl_comm_{};
|
||||||
|
cudaStream_t cuda_stream_{};
|
||||||
|
ncclUniqueId nccl_unique_id_{};
|
||||||
|
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||||
|
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
120
src/collective/rabit_communicator.h
Normal file
120
src/collective/rabit_communicator.h
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <rabit/rabit.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "communicator.h"
|
||||||
|
#include "xgboost/json.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
class RabitCommunicator : public Communicator {
|
||||||
|
public:
|
||||||
|
static Communicator *Create(Json const &config) {
|
||||||
|
std::vector<std::string> args_str;
|
||||||
|
for (auto &items : get<Object const>(config)) {
|
||||||
|
switch (items.second.GetValue().Type()) {
|
||||||
|
case xgboost::Value::ValueKind::kString: {
|
||||||
|
args_str.push_back(items.first + "=" + get<String const>(items.second));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case xgboost::Value::ValueKind::kInteger: {
|
||||||
|
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case xgboost::Value::ValueKind::kBoolean: {
|
||||||
|
if (get<Boolean const>(items.second)) {
|
||||||
|
args_str.push_back(items.first + "=1");
|
||||||
|
} else {
|
||||||
|
args_str.push_back(items.first + "=0");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::vector<char *> args;
|
||||||
|
for (auto &key_value : args_str) {
|
||||||
|
args.push_back(&key_value[0]);
|
||||||
|
}
|
||||||
|
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
|
||||||
|
LOG(FATAL) << "Failed to initialize Rabit";
|
||||||
|
}
|
||||||
|
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
|
||||||
|
}
|
||||||
|
|
||||||
|
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
|
||||||
|
|
||||||
|
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||||
|
|
||||||
|
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
|
Operation op) override {
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::kInt8:
|
||||||
|
DoAllReduce<char>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt8:
|
||||||
|
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kInt32:
|
||||||
|
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt32:
|
||||||
|
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kInt64:
|
||||||
|
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt64:
|
||||||
|
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kFloat:
|
||||||
|
DoAllReduce<float>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
case DataType::kDouble:
|
||||||
|
DoAllReduce<double>(send_receive_buffer, count, op);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown data type";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||||
|
rabit::Broadcast(send_receive_buffer, size, root);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
|
||||||
|
|
||||||
|
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Shutdown() override {
|
||||||
|
rabit::Finalize();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename DType>
|
||||||
|
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||||
|
switch (op) {
|
||||||
|
case Operation::kMax:
|
||||||
|
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||||
|
break;
|
||||||
|
case Operation::kMin:
|
||||||
|
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||||
|
break;
|
||||||
|
case Operation::kSum:
|
||||||
|
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown allreduce operation";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -12,6 +12,7 @@
|
|||||||
#include <rabit/rabit.h>
|
#include <rabit/rabit.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
@ -111,6 +112,22 @@ inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) {
|
|||||||
}
|
}
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Read the whole file content into a string.
|
||||||
|
*/
|
||||||
|
inline std::string ReadAll(std::string const &path) {
|
||||||
|
std::ifstream stream(path);
|
||||||
|
if (!stream.is_open()) {
|
||||||
|
LOG(FATAL) << "Could not open file " << path;
|
||||||
|
}
|
||||||
|
std::string content{std::istreambuf_iterator<char>(stream), std::istreambuf_iterator<char>()};
|
||||||
|
if (content.empty()) {
|
||||||
|
LOG(FATAL) << "Empty file " << path;
|
||||||
|
}
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_COMMON_IO_H_
|
#endif // XGBOOST_COMMON_IO_H_
|
||||||
|
|||||||
@ -22,7 +22,7 @@ if (PLUGIN_FEDERATED)
|
|||||||
target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated)
|
target_include_directories(testxgboost PRIVATE ${xgboost_SOURCE_DIR}/plugin/federated)
|
||||||
target_link_libraries(testxgboost PRIVATE federated_client)
|
target_link_libraries(testxgboost PRIVATE federated_client)
|
||||||
else (PLUGIN_FEDERATED)
|
else (PLUGIN_FEDERATED)
|
||||||
file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.cc")
|
file(GLOB_RECURSE FEDERATED_TEST_SOURCES "plugin/*_federated_*.*")
|
||||||
list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES})
|
list(REMOVE_ITEM TEST_SOURCES ${FEDERATED_TEST_SOURCES})
|
||||||
endif (PLUGIN_FEDERATED)
|
endif (PLUGIN_FEDERATED)
|
||||||
|
|
||||||
|
|||||||
54
tests/cpp/collective/test_communicator.cc
Normal file
54
tests/cpp/collective/test_communicator.cc
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <dmlc/parameter.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/collective/communicator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
// Communicator type selection driven by the XGBOOST_COMMUNICATOR environment variable.
// NOTE(review): the variable is left set to "foo" when this test ends — confirm that no
// later test in the same process depends on it being unset.
TEST(CommunicatorFactory, TypeFromEnv) {
  auto constexpr kEnvName = "XGBOOST_COMMUNICATOR";

  // Before this test sets the variable, the type is expected to be unknown.
  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromEnv());

  dmlc::SetEnv<std::string>(kEnvName, "rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromEnv());

  // The value is given in mixed case here.
  dmlc::SetEnv<std::string>(kEnvName, "Federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromEnv());

  // An unrecognized name is expected to raise an error.
  dmlc::SetEnv<std::string>(kEnvName, "foo");
  EXPECT_THROW(Communicator::GetTypeFromEnv(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// Communicator type selection driven by the lower-case "xgboost_communicator" JSON key.
TEST(CommunicatorFactory, TypeFromArgs) {
  auto constexpr kKey = "xgboost_communicator";
  Json config{JsonObject()};

  // An empty configuration yields the unknown type.
  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));

  config[kKey] = String("rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));

  config[kKey] = String("federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

  // An unrecognized name is expected to raise an error.
  config[kKey] = String("foo");
  EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
|
||||||
|
|
||||||
|
// Same checks as the lower-case variant, but using the upper-case "XGBOOST_COMMUNICATOR" key.
TEST(CommunicatorFactory, TypeFromArgsUpperCase) {
  auto constexpr kKey = "XGBOOST_COMMUNICATOR";
  Json config{JsonObject()};

  // An empty configuration yields the unknown type.
  EXPECT_EQ(CommunicatorType::kUnknown, Communicator::GetTypeFromConfig(config));

  config[kKey] = String("rabit");
  EXPECT_EQ(CommunicatorType::kRabit, Communicator::GetTypeFromConfig(config));

  config[kKey] = String("federated");
  EXPECT_EQ(CommunicatorType::kFederated, Communicator::GetTypeFromConfig(config));

  // An unrecognized name is expected to raise an error.
  config[kKey] = String("foo");
  EXPECT_THROW(Communicator::GetTypeFromConfig(config), dmlc::Error);
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
26
tests/cpp/collective/test_nccl_device_communicator.cu
Normal file
26
tests/cpp/collective/test_nccl_device_communicator.cu
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#ifdef XGBOOST_USE_NCCL
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/collective/nccl_device_communicator.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
// Constructing with device ordinal -1 is expected to throw dmlc::Error.
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidDeviceOrdinal) {
  auto const make_communicator = [] { NcclDeviceCommunicator comm{-1, nullptr}; };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// Constructing with a null communicator pointer is expected to throw dmlc::Error.
TEST(NcclDeviceCommunicatorSimpleTest, ThrowOnInvalidCommunicator) {
  auto const make_communicator = [] { NcclDeviceCommunicator comm{0, nullptr}; };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
#endif
|
||||||
39
tests/cpp/collective/test_rabit_communicator.cc
Normal file
39
tests/cpp/collective/test_rabit_communicator.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "../../../src/collective/rabit_communicator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
// A world size of 0 is expected to be rejected with dmlc::Error.
TEST(RabitCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
  auto const make_communicator = [] { RabitCommunicator comm{0, 0}; };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// A rank of -1 is expected to be rejected with dmlc::Error.
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooSmall) {
  auto const make_communicator = [] { RabitCommunicator comm{1, -1}; };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// A rank equal to the world size (1 of 1) is expected to be rejected with dmlc::Error.
TEST(RabitCommunicatorSimpleTest, ThrowOnRankTooBig) {
  auto const make_communicator = [] { RabitCommunicator comm{1, 1}; };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// The world size and rank given to the constructor are returned by the accessors.
TEST(RabitCommunicatorSimpleTest, GetWorldSizeAndRank) {
  int const world_size = 6;
  int const rank = 3;
  RabitCommunicator comm{world_size, rank};
  EXPECT_EQ(comm.GetWorldSize(), world_size);
  EXPECT_EQ(comm.GetRank(), rank);
}
|
||||||
|
|
||||||
|
// A communicator constructed directly is not expected to report distributed mode.
TEST(RabitCommunicatorSimpleTest, IsNotDistributed) {
  RabitCommunicator comm{2, 1};
  // Rabit is only distributed with a tracker.
  bool const distributed = comm.IsDistributed();
  EXPECT_FALSE(distributed);
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
105
tests/cpp/plugin/test_federated_adapter.cu
Normal file
105
tests/cpp/plugin/test_federated_adapter.cu
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <grpcpp/server_builder.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <thrust/host_vector.h>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "../../../plugin/federated/federated_communicator.h"
|
||||||
|
#include "../../../plugin/federated/federated_server.h"
|
||||||
|
#include "../../../src/collective/device_communicator_adapter.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp)
|
||||||
|
|
||||||
|
class FederatedAdapterTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
server_thread_.reset(new std::thread([this] {
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
federated::FederatedService service{kWorldSize};
|
||||||
|
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
server_->Wait();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
server_->Shutdown();
|
||||||
|
server_thread_->join();
|
||||||
|
}
|
||||||
|
|
||||||
|
static int const kWorldSize{2};
|
||||||
|
std::unique_ptr<std::thread> server_thread_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Constructing the adapter with device ordinal -1 is expected to throw dmlc::Error.
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidDeviceOrdinal) {
  auto const make_adapter = [] { DeviceCommunicatorAdapter adapter{-1, nullptr}; };
  EXPECT_THROW(make_adapter(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// Constructing the adapter with a null communicator is expected to throw dmlc::Error.
TEST(FederatedAdapterSimpleTest, ThrowOnInvalidCommunicator) {
  auto const make_adapter = [] { DeviceCommunicatorAdapter adapter{0, nullptr}; };
  EXPECT_THROW(make_adapter(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// Each of the kWorldSize workers contributes {0, 1, 2}; the expected sum is i * 2 per element.
TEST_F(FederatedAdapterTest, DeviceAllReduceSum) {
  // Work done by a single rank; the rank is also used as the device ordinal.
  auto worker = [](int rank) {
    FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
    DeviceCommunicatorAdapter adapter{rank, &comm};
    int const count = 3;
    thrust::device_vector<double> buffer(count, 0);
    thrust::sequence(buffer.begin(), buffer.end());
    adapter.AllReduceSum(buffer.data().get(), count);
    // Copy the result back to the host for checking.
    thrust::host_vector<double> host_buffer = buffer;
    EXPECT_EQ(host_buffer.size(), count);
    for (auto i = 0; i < count; i++) {
      EXPECT_EQ(host_buffer[i], i * 2);
    }
  };

  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
    threads.emplace_back(worker, rank);
  }
  for (auto& t : threads) {
    t.join();
  }
}
|
||||||
|
|
||||||
|
// Rank r sends r + 2 bytes (a 0-based sequence); the gathered result is {0, 1, 0, 1, 2}.
TEST_F(FederatedAdapterTest, DeviceAllGatherV) {
  // Work done by a single rank; the rank is also used as the device ordinal.
  auto worker = [](int rank) {
    FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
    DeviceCommunicatorAdapter adapter{rank, &comm};

    int const count = rank + 2;
    thrust::device_vector<char> buffer(count, 0);
    thrust::sequence(buffer.begin(), buffer.end());
    std::vector<std::size_t> segments(kWorldSize);
    dh::caching_device_vector<char> receive_buffer{};

    adapter.AllGatherV(buffer.data().get(), count, &segments, &receive_buffer);

    // Segment sizes as reported per rank.
    EXPECT_EQ(segments[0], 2);
    EXPECT_EQ(segments[1], 3);
    // Copy the gathered bytes back to the host for checking.
    thrust::host_vector<char> host_buffer = receive_buffer;
    EXPECT_EQ(host_buffer.size(), 5);
    int expected[] = {0, 1, 0, 1, 2};
    for (auto i = 0; i < 5; i++) {
      EXPECT_EQ(host_buffer[i], expected[i]);
    }
  };

  std::vector<std::thread> threads;
  for (auto rank = 0; rank < kWorldSize; rank++) {
    threads.emplace_back(worker, rank);
  }
  for (auto& t : threads) {
    t.join();
  }
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
119
tests/cpp/plugin/test_federated_communicator.cc
Normal file
119
tests/cpp/plugin/test_federated_communicator.cc
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <dmlc/parameter.h>
|
||||||
|
#include <grpcpp/server_builder.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "../../../plugin/federated/federated_communicator.h"
|
||||||
|
#include "../../../plugin/federated/federated_server.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
std::string const kServerAddress{"localhost:56789"}; // NOLINT(cert-err58-cpp)
|
||||||
|
|
||||||
|
class FederatedCommunicatorTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
static void VerifyAllreduce(int rank) {
|
||||||
|
FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
|
||||||
|
CheckAllreduce(comm);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void VerifyBroadcast(int rank) {
|
||||||
|
FederatedCommunicator comm{kWorldSize, rank, kServerAddress};
|
||||||
|
CheckBroadcast(comm, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
server_thread_.reset(new std::thread([this] {
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
federated::FederatedService service{kWorldSize};
|
||||||
|
builder.AddListeningPort(kServerAddress, grpc::InsecureServerCredentials());
|
||||||
|
builder.RegisterService(&service);
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
server_->Wait();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
void TearDown() override {
|
||||||
|
server_->Shutdown();
|
||||||
|
server_thread_->join();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void CheckAllreduce(FederatedCommunicator &comm) {
|
||||||
|
int buffer[] = {1, 2, 3, 4, 5};
|
||||||
|
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||||
|
int expected[] = {3, 6, 9, 12, 15};
|
||||||
|
for (auto i = 0; i < 5; i++) {
|
||||||
|
EXPECT_EQ(buffer[i], expected[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void CheckBroadcast(FederatedCommunicator &comm, int rank) {
|
||||||
|
if (rank == 0) {
|
||||||
|
std::string buffer{"hello"};
|
||||||
|
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||||
|
EXPECT_EQ(buffer, "hello");
|
||||||
|
} else {
|
||||||
|
std::string buffer{" "};
|
||||||
|
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||||
|
EXPECT_EQ(buffer, "hello");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static int const kWorldSize{3};
|
||||||
|
std::unique_ptr<std::thread> server_thread_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A world size of 0 is expected to be rejected with dmlc::Error.
TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
  auto const make_communicator = [] {
    FederatedCommunicator comm{0, 0, kServerAddress, "", "", ""};
  };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// A rank of -1 is expected to be rejected with dmlc::Error.
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooSmall) {
  auto const make_communicator = [] {
    FederatedCommunicator comm{1, -1, kServerAddress, "", "", ""};
  };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// A rank equal to the world size (1 of 1) is expected to be rejected with dmlc::Error.
TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
  auto const make_communicator = [] {
    FederatedCommunicator comm{1, 1, kServerAddress, "", "", ""};
  };
  EXPECT_THROW(make_communicator(), dmlc::Error);
}
|
||||||
|
|
||||||
|
// The world size and rank given to the constructor are returned by the accessors.
TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
  int const world_size = 6;
  int const rank = 3;
  FederatedCommunicator comm{world_size, rank, kServerAddress};
  EXPECT_EQ(comm.GetWorldSize(), world_size);
  EXPECT_EQ(comm.GetRank(), rank);
}
|
||||||
|
|
||||||
|
// A federated communicator is expected to report distributed mode.
TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
  FederatedCommunicator comm{2, 1, kServerAddress};
  bool const distributed = comm.IsDistributed();
  EXPECT_TRUE(distributed);
}
|
||||||
|
|
||||||
|
// Run VerifyAllreduce on one thread per rank and wait for all of them.
TEST_F(FederatedCommunicatorTest, Allreduce) {
  std::vector<std::thread> workers;
  for (auto rank = 0; rank < kWorldSize; rank++) {
    workers.emplace_back(&FederatedCommunicatorTest::VerifyAllreduce, rank);
  }
  for (auto &worker : workers) {
    worker.join();
  }
}
|
||||||
|
|
||||||
|
// Run VerifyBroadcast on one thread per rank and wait for all of them.
TEST_F(FederatedCommunicatorTest, Broadcast) {
  std::vector<std::thread> workers;
  for (auto rank = 0; rank < kWorldSize; rank++) {
    workers.emplace_back(&FederatedCommunicatorTest::VerifyBroadcast, rank);
  }
  for (auto &worker : workers) {
    worker.join();
  }
}
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -62,7 +62,7 @@ class FederatedServerTest : public ::testing::Test {
|
|||||||
static void CheckAllreduce(federated::FederatedClient& client) {
|
static void CheckAllreduce(federated::FederatedClient& client) {
|
||||||
int data[] = {1, 2, 3, 4, 5};
|
int data[] = {1, 2, 3, 4, 5};
|
||||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||||
auto reply = client.Allreduce(send_buffer, federated::INT, federated::SUM);
|
auto reply = client.Allreduce(send_buffer, federated::INT32, federated::SUM);
|
||||||
auto const* result = reinterpret_cast<int const*>(reply.data());
|
auto const* result = reinterpret_cast<int const*>(reply.data());
|
||||||
int expected[] = {3, 6, 9, 12, 15};
|
int expected[] = {3, 6, 9, 12, 15};
|
||||||
for (auto i = 0; i < 5; i++) {
|
for (auto i = 0; i < 5; i++) {
|
||||||
|
|||||||
@ -22,6 +22,7 @@ def run_server(port: int, world_size: int, with_ssl: bool) -> None:
|
|||||||
|
|
||||||
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
|
||||||
rabit_env = [
|
rabit_env = [
|
||||||
|
'xgboost_communicator=federated',
|
||||||
f'federated_server_address=localhost:{port}',
|
f'federated_server_address=localhost:{port}',
|
||||||
f'federated_world_size={world_size}',
|
f'federated_world_size={world_size}',
|
||||||
f'federated_rank={rank}'
|
f'federated_rank={rank}'
|
||||||
|
|||||||
39
tests/python/test_collective.py
Normal file
39
tests/python/test_collective.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
import multiprocessing
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
from xgboost import RabitTracker
|
||||||
|
from xgboost import collective
|
||||||
|
|
||||||
|
if sys.platform.startswith("win"):
|
||||||
|
pytest.skip("Skipping collective tests on Windows", allow_module_level=True)
|
||||||
|
|
||||||
|
|
||||||
|
def run_rabit_worker(rabit_env, world_size):
    """Worker process body: enter a communicator context and exercise the collective API.

    ``rabit_env`` is expanded into keyword arguments of ``CommunicatorContext``;
    ``world_size`` is the number of workers the group is expected to report.
    """
    comm = xgb.collective
    with comm.CommunicatorContext(**rabit_env):
        assert comm.get_world_size() == world_size
        assert comm.is_distributed()
        assert comm.get_processor_name() == socket.gethostname()

        # Broadcast a string from rank 0.
        echoed = comm.broadcast('test1234', 0)
        assert str(echoed) == 'test1234'

        # Sum the same array across the workers; the expected value is for two of them.
        summed = comm.allreduce(np.asarray([1, 2, 3]), comm.Op.SUM)
        assert np.array_equal(summed, np.asarray([2, 4, 6]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_rabit_communicator():
    """Start a local tracker and check that two worker processes exit with code 0."""
    world_size = 2
    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
    tracker.start(world_size)

    # Launch one process per worker, each with the tracker's environment.
    workers = []
    for _ in range(world_size):
        proc = multiprocessing.Process(
            target=run_rabit_worker,
            args=(tracker.worker_envs(), world_size),
        )
        workers.append(proc)
        proc.start()

    # Wait for each worker in turn and check that it succeeded.
    for proc in workers:
        proc.join()
        assert proc.exitcode == 0
|
||||||
Loading…
x
Reference in New Issue
Block a user