Common interface for collective communication (#8057)
* implement broadcast for federated communicator * implement allreduce * add communicator factory * add device adapter * add device communicator to factory * add rabit communicator * add rabit communicator to the factory * add nccl device communicator * add synchronize to device communicator * add back print and getprocessorname * add python wrapper and c api * clean up types * fix non-gpu build * try to fix ci * fix std::size_t * portable string compare ignore case * c style size_t * fix lint errors * cross platform setenv * fix memory leak * fix lint errors * address review feedback * add python test for rabit communicator * fix failing gtest * use json to configure communicators * fix lint error * get rid of factories * fix cpu build * fix include * fix python import * don't export collective.py yet * skip collective communicator pytest on windows * add review feedback * update documentation * remove mpi communicator type * fix tests * shutdown the communicator separately Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -10,6 +10,8 @@
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/common/io.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace federated {
|
||||
|
||||
@@ -71,32 +73,35 @@ class AllreduceFunctor {
|
||||
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
|
||||
ReduceOperation reduce_operation) const {
|
||||
switch (data_type) {
|
||||
case DataType::CHAR:
|
||||
Accumulate(&buffer[0], reinterpret_cast<char const*>(input.data()), buffer.size(),
|
||||
case DataType::INT8:
|
||||
Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
|
||||
reduce_operation);
|
||||
break;
|
||||
case DataType::UCHAR:
|
||||
Accumulate(reinterpret_cast<unsigned char*>(&buffer[0]),
|
||||
reinterpret_cast<unsigned char const*>(input.data()), buffer.size(),
|
||||
case DataType::UINT8:
|
||||
Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
|
||||
reduce_operation);
|
||||
break;
|
||||
case DataType::INT:
|
||||
Accumulate(reinterpret_cast<int*>(&buffer[0]), reinterpret_cast<int const*>(input.data()),
|
||||
buffer.size() / sizeof(int), reduce_operation);
|
||||
case DataType::INT32:
|
||||
Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::int32_t const*>(input.data()),
|
||||
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
||||
break;
|
||||
case DataType::UINT:
|
||||
Accumulate(reinterpret_cast<unsigned int*>(&buffer[0]),
|
||||
reinterpret_cast<unsigned int const*>(input.data()),
|
||||
buffer.size() / sizeof(unsigned int), reduce_operation);
|
||||
case DataType::UINT32:
|
||||
Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::uint32_t const*>(input.data()),
|
||||
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
||||
break;
|
||||
case DataType::LONG:
|
||||
Accumulate(reinterpret_cast<long*>(&buffer[0]), reinterpret_cast<long const*>(input.data()),
|
||||
buffer.size() / sizeof(long), reduce_operation);
|
||||
case DataType::INT64:
|
||||
Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::int64_t const*>(input.data()),
|
||||
buffer.size() / sizeof(std::int64_t), reduce_operation);
|
||||
break;
|
||||
case DataType::ULONG:
|
||||
Accumulate(reinterpret_cast<unsigned long*>(&buffer[0]),
|
||||
reinterpret_cast<unsigned long const*>(input.data()),
|
||||
buffer.size() / sizeof(unsigned long), reduce_operation);
|
||||
case DataType::UINT64:
|
||||
Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
|
||||
reinterpret_cast<std::uint64_t const*>(input.data()),
|
||||
buffer.size() / sizeof(std::uint64_t), reduce_operation);
|
||||
break;
|
||||
case DataType::FLOAT:
|
||||
Accumulate(reinterpret_cast<float*>(&buffer[0]),
|
||||
@@ -108,16 +113,6 @@ class AllreduceFunctor {
|
||||
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
|
||||
reduce_operation);
|
||||
break;
|
||||
case DataType::LONGLONG:
|
||||
Accumulate(reinterpret_cast<long long*>(&buffer[0]),
|
||||
reinterpret_cast<long long const*>(input.data()),
|
||||
buffer.size() / sizeof(long long), reduce_operation);
|
||||
break;
|
||||
case DataType::ULONGLONG:
|
||||
Accumulate(reinterpret_cast<unsigned long long*>(&buffer[0]),
|
||||
reinterpret_cast<unsigned long long const*>(input.data()),
|
||||
buffer.size() / sizeof(unsigned long long), reduce_operation);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("Invalid data type");
|
||||
}
|
||||
@@ -201,13 +196,6 @@ grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
std::string ReadFile(char const* path) {
|
||||
auto stream = std::ifstream(path);
|
||||
std::ostringstream out;
|
||||
out << stream.rdbuf();
|
||||
return out.str();
|
||||
}
|
||||
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
@@ -216,10 +204,10 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = ReadFile(client_cert_file);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = ReadFile(server_key_file);
|
||||
key.cert_chain = ReadFile(server_cert_file);
|
||||
key.private_key = xgboost::common::ReadAll(server_key_file);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
|
||||
Reference in New Issue
Block a user