Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features: - Federated learning for both CPU and GPU. - NCCL. - More data types. - A unified interface for all the underlying implementations. - Improved timeout handling for both tracker and workers. - Exhaustive tests with metrics (fixed a couple of bugs along the way). - A reusable tracker for Python and JVM packages.
This commit is contained in:
@@ -31,31 +31,13 @@ protobuf_generate(
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
add_library(federated_old_proto STATIC federated.old.proto)
|
||||
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
||||
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||
xgboost_target_properties(federated_old_proto)
|
||||
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE cpp
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE grpc
|
||||
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
# Wrapper for the gRPC client.
|
||||
add_library(federated_client INTERFACE)
|
||||
target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
target_link_libraries(federated_client INTERFACE federated_old_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(
|
||||
objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
|
||||
objxgboost PRIVATE federated_tracker.cc federated_comm.cc federated_coll.cc
|
||||
)
|
||||
if(USE_CUDA)
|
||||
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
package xgboost.federated;
|
||||
|
||||
// gRPC service with one unary RPC per collective operation. Each RPC takes the
// matching *Request message and answers with the matching *Reply message.
service Federated {
|
||||
// Gather: see AllgatherRequest / AllgatherReply.
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
// Variable-length gather: see AllgatherVRequest / AllgatherVReply.
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
|
||||
// Reduction: the request also carries the element type and the reduce operation.
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||
// Broadcast: the request also carries the root rank.
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||
}
|
||||
|
||||
// Element type of the payload in an AllreduceRequest.
// NOTE(review): FederatedCommunicator::AllReduce and FederatedService::Allreduce convert
// between this enum and the collective-layer data type with a plain static_cast, so the
// numeric values are presumably expected to line up with that enum -- confirm before
// renumbering.
enum DataType {
|
||||
INT8 = 0;
|
||||
UINT8 = 1;
|
||||
INT32 = 2;
|
||||
UINT32 = 3;
|
||||
INT64 = 4;
|
||||
UINT64 = 5;
|
||||
FLOAT = 6;
|
||||
DOUBLE = 7;
|
||||
}
|
||||
|
||||
// Reduction requested by an AllreduceRequest.
// NOTE(review): converted to and from the collective-layer operation enum with a plain
// static_cast on both the client and the server side, so the numeric values are
// presumably expected to line up with that enum -- confirm before renumbering.
enum ReduceOperation {
|
||||
MAX = 0;
|
||||
MIN = 1;
|
||||
SUM = 2;
|
||||
BITWISE_AND = 3;
|
||||
BITWISE_OR = 4;
|
||||
BITWISE_XOR = 5;
|
||||
}
|
||||
|
||||
// Request message of the Allgather RPC.
message AllgatherRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
// FederatedClient increments it by one for every RPC it issues.
|
||||
uint64 sequence_number = 1;
|
||||
// Rank of the sending client.
int32 rank = 2;
|
||||
// Raw bytes contributed by the sender.
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
// Reply message of the Allgather RPC.
message AllgatherReply {
|
||||
// Result bytes; the server hands this field to its handler to fill in.
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
// Request message of the AllgatherV (variable-length allgather) RPC.
message AllgatherVRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
// FederatedClient increments it by one for every RPC it issues.
|
||||
uint64 sequence_number = 1;
|
||||
// Rank of the sending client.
int32 rank = 2;
|
||||
// Raw bytes contributed by the sender.
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
// Reply message of the AllgatherV RPC.
message AllgatherVReply {
|
||||
// Result bytes; the server hands this field to its handler to fill in.
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
// Request message of the Allreduce RPC.
message AllreduceRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
// FederatedClient increments it by one for every RPC it issues.
|
||||
uint64 sequence_number = 1;
|
||||
// Rank of the sending client.
int32 rank = 2;
|
||||
// Raw bytes of the sender's elements to be reduced.
bytes send_buffer = 3;
|
||||
// Element type of send_buffer.
DataType data_type = 4;
|
||||
// Reduction to apply.
ReduceOperation reduce_operation = 5;
|
||||
}
|
||||
|
||||
// Reply message of the Allreduce RPC.
message AllreduceReply {
|
||||
// Result bytes; the server hands this field to its handler to fill in.
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
// Request message of the Broadcast RPC.
message BroadcastRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
// FederatedClient increments it by one for every RPC it issues.
|
||||
uint64 sequence_number = 1;
|
||||
// Rank of the sending client.
int32 rank = 2;
|
||||
// Bytes to broadcast. FederatedCommunicator::Broadcast sends the real payload from the
// root rank and an empty string from every other rank.
bytes send_buffer = 3;
|
||||
// The root rank to broadcast from.
|
||||
int32 root = 4;
|
||||
}
|
||||
|
||||
// Reply message of the Broadcast RPC.
message BroadcastReply {
|
||||
// Result bytes; the server hands this field to its handler to fill in.
bytes receive_buffer = 1;
|
||||
}
|
||||
@@ -1,132 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.old.grpc.pb.h>
|
||||
#include <federated.old.pb.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost::federated {
|
||||
/**
|
||||
* @brief A wrapper around the gRPC client.
|
||||
*/
|
||||
class FederatedClient {
|
||||
public:
|
||||
FederatedClient(std::string const &server_address, int rank, std::string const &server_cert,
|
||||
std::string const &client_key, std::string const &client_cert)
|
||||
: stub_{[&] {
|
||||
grpc::SslCredentialsOptions options;
|
||||
options.pem_root_certs = server_cert;
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel =
|
||||
grpc::CreateCustomChannel(server_address, grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return Federated::NewStub(channel);
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
/** @brief Insecure client for connecting to localhost only. */
|
||||
FederatedClient(std::string const &server_address, int rank)
|
||||
: stub_{[&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return Federated::NewStub(
|
||||
grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args));
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
std::string Allgather(std::string_view send_buffer) {
|
||||
AllgatherRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Allgather(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Allgather RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string AllgatherV(std::string_view send_buffer) {
|
||||
AllgatherVRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherVReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->AllgatherV(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("AllgatherV RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||
ReduceOperation reduce_operation) {
|
||||
AllreduceRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_data_type(data_type);
|
||||
request.set_reduce_operation(reduce_operation);
|
||||
|
||||
AllreduceReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Allreduce(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Allreduce RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Broadcast(std::string const &send_buffer, int root) {
|
||||
BroadcastRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_root(root);
|
||||
|
||||
BroadcastReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->Broadcast(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("Broadcast RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<Federated::Stub> const stub_;
|
||||
int const rank_;
|
||||
uint64_t sequence_number_{};
|
||||
};
|
||||
} // namespace xgboost::federated
|
||||
@@ -1,195 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
|
||||
#include "../../src/c_api/c_api_utils.h"
|
||||
#include "../../src/collective/communicator.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "federated_client.h"
|
||||
|
||||
namespace xgboost::collective {
|
||||
/**
|
||||
* @brief A Federated Learning communicator class that handles collective communication.
|
||||
*/
|
||||
class FederatedCommunicator : public Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Create a new communicator based on JSON configuration.
|
||||
* @param config JSON configuration.
|
||||
* @return Communicator as specified by the JSON configuration.
|
||||
*/
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::string server_address{};
|
||||
int world_size{0};
|
||||
int rank{-1};
|
||||
std::string server_cert{};
|
||||
std::string client_key{};
|
||||
std::string client_cert{};
|
||||
|
||||
// Parse environment variables first.
|
||||
auto *value = getenv("FEDERATED_SERVER_ADDRESS");
|
||||
if (value != nullptr) {
|
||||
server_address = value;
|
||||
}
|
||||
value = getenv("FEDERATED_WORLD_SIZE");
|
||||
if (value != nullptr) {
|
||||
world_size = std::stoi(value);
|
||||
}
|
||||
value = getenv("FEDERATED_RANK");
|
||||
if (value != nullptr) {
|
||||
rank = std::stoi(value);
|
||||
}
|
||||
value = getenv("FEDERATED_SERVER_CERT");
|
||||
if (value != nullptr) {
|
||||
server_cert = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_KEY");
|
||||
if (value != nullptr) {
|
||||
client_key = value;
|
||||
}
|
||||
value = getenv("FEDERATED_CLIENT_CERT");
|
||||
if (value != nullptr) {
|
||||
client_cert = value;
|
||||
}
|
||||
|
||||
// Runtime configuration overrides, optional as users can specify them as env vars.
|
||||
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
|
||||
world_size =
|
||||
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
|
||||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
|
||||
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
|
||||
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);
|
||||
|
||||
if (server_address.empty()) {
|
||||
LOG(FATAL) << "Federated server address must be set.";
|
||||
}
|
||||
if (world_size == 0) {
|
||||
LOG(FATAL) << "Federated world size must be set.";
|
||||
}
|
||||
if (rank == -1) {
|
||||
LOG(FATAL) << "Federated rank must be set.";
|
||||
}
|
||||
return new FederatedCommunicator(world_size, rank, server_address, server_cert, client_key,
|
||||
client_cert);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Construct a new federated communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
* @param server_address Address of the federated server (host:port).
|
||||
* @param server_cert_path Path to the server cert file.
|
||||
* @param client_key_path Path to the client key file.
|
||||
* @param client_cert_path Path to the client cert file.
|
||||
*/
|
||||
FederatedCommunicator(int world_size, int rank, std::string const &server_address,
|
||||
std::string const &server_cert_path, std::string const &client_key_path,
|
||||
std::string const &client_cert_path)
|
||||
: Communicator{world_size, rank} {
|
||||
if (server_cert_path.empty() || client_key_path.empty() || client_cert_path.empty()) {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||
} else {
|
||||
client_.reset(new xgboost::federated::FederatedClient(
|
||||
server_address, rank, xgboost::common::ReadAll(server_cert_path),
|
||||
xgboost::common::ReadAll(client_key_path), xgboost::common::ReadAll(client_cert_path)));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Construct an insecure federated communicator without using SSL.
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
* @param server_address Address of the federated server (host:port).
|
||||
*/
|
||||
FederatedCommunicator(int world_size, int rank, std::string const &server_address)
|
||||
: Communicator{world_size, rank} {
|
||||
client_.reset(new xgboost::federated::FederatedClient(server_address, rank));
|
||||
}
|
||||
|
||||
~FederatedCommunicator() override { client_.reset(); }
|
||||
|
||||
/**
|
||||
* \brief Get if the communicator is distributed.
|
||||
* \return True.
|
||||
*/
|
||||
[[nodiscard]] bool IsDistributed() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Get if the communicator is federated.
|
||||
* \return True.
|
||||
*/
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Perform allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGather(std::string_view input) override {
|
||||
return client_->Allgather(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform variable-length allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
return client_->AllgatherV(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform in-place allreduce.
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type.
|
||||
* \param op Enumeration of operation type.
|
||||
*/
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer),
|
||||
count * GetTypeSize(data_type));
|
||||
auto const received =
|
||||
client_->Allreduce(send_buffer, static_cast<xgboost::federated::DataType>(data_type),
|
||||
static_cast<xgboost::federated::ReduceOperation>(op));
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), count * GetTypeSize(data_type));
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Broadcast a memory region to all others from root.
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
if (GetWorldSize() == 1) return;
|
||||
if (GetRank() == root) {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
|
||||
client_->Broadcast(send_buffer, root);
|
||||
} else {
|
||||
auto const received = client_->Broadcast("", root);
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Get the name of the processor.
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
|
||||
|
||||
/**
|
||||
* \brief Print the message to the communicator.
|
||||
* \param message The message to be printed.
|
||||
*/
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
|
||||
private:
|
||||
std::unique_ptr<xgboost::federated::FederatedClient> client_{};
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -1,86 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "federated_server.h"
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <grpcpp/server.h> // for Server
|
||||
#include <grpcpp/server_builder.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "../../src/collective/comm.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/json_utils.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
/**
 * @brief Serve an Allgather RPC by handing the request to the in-memory handler.
 *
 * The handler receives the payload bytes, the reply's receive buffer to fill, and the
 * request's sequence number and rank. Always returns grpc::Status::OK.
 */
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
                                         AllgatherReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const output = reply->mutable_receive_buffer();
  handler_.Allgather(payload.data(), payload.size(), output, request->sequence_number(),
                     request->rank());
  return grpc::Status::OK;
}
|
||||
|
||||
/**
 * @brief Serve an AllgatherV RPC by handing the request to the in-memory handler.
 *
 * The handler receives the payload bytes, the reply's receive buffer to fill, and the
 * request's sequence number and rank. Always returns grpc::Status::OK.
 */
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
                                          AllgatherVReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const output = reply->mutable_receive_buffer();
  handler_.AllgatherV(payload.data(), payload.size(), output, request->sequence_number(),
                      request->rank());
  return grpc::Status::OK;
}
|
||||
|
||||
/**
 * @brief Serve an Allreduce RPC by handing the request to the in-memory handler.
 *
 * The wire enums for the data type and the reduce operation are converted to the
 * collective-layer enums with static_cast. Always returns grpc::Status::OK.
 */
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
                                         AllreduceReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const output = reply->mutable_receive_buffer();
  auto const data_type = static_cast<xgboost::collective::DataType>(request->data_type());
  auto const operation = static_cast<xgboost::collective::Operation>(request->reduce_operation());
  handler_.Allreduce(payload.data(), payload.size(), output, request->sequence_number(),
                     request->rank(), data_type, operation);
  return grpc::Status::OK;
}
|
||||
|
||||
/**
 * @brief Serve a Broadcast RPC by handing the request to the in-memory handler.
 *
 * In addition to the payload, sequence number and rank, the handler is given the root
 * rank named in the request. Always returns grpc::Status::OK.
 */
grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
                                         BroadcastReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const output = reply->mutable_receive_buffer();
  handler_.Broadcast(payload.data(), payload.size(), output, request->sequence_number(),
                     request->rank(), request->root());
  return grpc::Status::OK;
}
|
||||
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
auto options =
|
||||
grpc::SslServerCredentialsOptions(GRPC_SSL_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY);
|
||||
options.pem_root_certs = xgboost::common::ReadAll(client_cert_file);
|
||||
auto key = grpc::SslServerCredentialsOptions::PemKeyCertPair();
|
||||
key.private_key = xgboost::common::ReadAll(server_key_file);
|
||||
key.cert_chain = xgboost::common::ReadAll(server_cert_file);
|
||||
options.pem_key_cert_pairs.push_back(key);
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< world_size;
|
||||
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{static_cast<std::int32_t>(world_size)};
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< world_size;
|
||||
|
||||
server->Wait();
|
||||
}
|
||||
} // namespace xgboost::federated
|
||||
@@ -1,37 +0,0 @@
|
||||
/**
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.old.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
|
||||
namespace xgboost::federated {
|
||||
/**
 * @brief gRPC service implementation of the generated Federated::Service interface.
 *
 * Owns one InMemoryHandler, constructed from the world size, and overrides the four RPC
 * entry points declared below.
 */
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
/** @brief Construct the service; world_size is passed on to the handler. */
explicit FederatedService(std::int32_t world_size) : handler_{world_size} {}
|
||||
|
||||
/** @brief Entry point of the Allgather RPC. */
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
/** @brief Entry point of the AllgatherV RPC. */
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) override;
|
||||
|
||||
/** @brief Entry point of the Allreduce RPC. */
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) override;
|
||||
|
||||
/** @brief Entry point of the Broadcast RPC. */
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
|
||||
BroadcastReply* reply) override;
|
||||
|
||||
private:
|
||||
// Handler the RPC entry points delegate to.
xgboost::collective::InMemoryHandler handler_;
|
||||
};
|
||||
|
||||
/**
 * @brief Run a federated server secured with SSL; see the definition in federated_server.cc.
 *
 * @param port             Port to listen on.
 * @param world_size       World size handed to the service.
 * @param server_key_file  Path of the server's private key file.
 * @param server_cert_file Path of the server's certificate file.
 * @param client_cert_file Path of the client certificate file.
 */
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
/**
 * @brief Run a federated server without SSL; see the definition in federated_server.cc.
 *
 * @param port       Port to listen on.
 * @param world_size World size handed to the service.
 */
void RunInsecureServer(int port, std::size_t world_size);
|
||||
} // namespace xgboost::federated
|
||||
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2022-2023, XGBoost contributors
|
||||
* Copyright 2022-2024, XGBoost contributors
|
||||
*/
|
||||
#include "federated_tracker.h"
|
||||
|
||||
@@ -8,13 +8,12 @@
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <exception> // for exception
|
||||
#include <future> // for future, async
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string
|
||||
#include <thread> // for sleep_for
|
||||
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace federated {
|
||||
@@ -36,8 +35,8 @@ grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest
|
||||
AllreduceReply* reply) {
|
||||
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||
static_cast<xgboost::collective::DataType>(request->data_type()),
|
||||
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
|
||||
static_cast<xgboost::ArrayInterfaceHandler::Type>(request->data_type()),
|
||||
static_cast<xgboost::collective::Op>(request->reduce_operation()));
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
@@ -53,9 +52,13 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
/**
 * @brief Construct the tracker from a JSON configuration.
 *
 * "federated_secure" is required. When it is true, "server_key_path", "server_cert_path"
 * and "client_cert_path" are required as well and each must be non-empty.
 */
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
  auto const secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
  if (!secure) {
    // Nothing else to read for an insecure tracker.
    return;
  }

  StringView empty_path_msg{"Empty certificate path."};

  server_key_path_ = RequiredArg<String const>(config, "server_key_path", __func__);
  CHECK(!server_key_path_.empty()) << empty_path_msg;

  server_cert_file_ = RequiredArg<String const>(config, "server_cert_path", __func__);
  CHECK(!server_cert_file_.empty()) << empty_path_msg;

  client_cert_file_ = RequiredArg<String const>(config, "client_cert_path", __func__);
  CHECK(!client_cert_file_.empty()) << empty_path_msg;
}
|
||||
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include "../sycl/device_manager.h"
|
||||
|
||||
#include "../../src/collective/communicator-inl.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
|
||||
@@ -21,22 +22,23 @@ namespace sycl {
|
||||
}
|
||||
|
||||
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
||||
(rabit::IsDistributed());
|
||||
(collective::IsDistributed());
|
||||
if (not_use_default_selector) {
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
||||
const int device_idx =
|
||||
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
return devices[device_idx];
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
return devices[device_idx];
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
return cpu_devices[device_idx];
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
return cpu_devices[device_idx];
|
||||
} else {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
return gpu_devices[device_idx];
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
return gpu_devices[device_idx];
|
||||
}
|
||||
} else {
|
||||
if (device_spec.IsSyclCPU()) {
|
||||
@@ -62,24 +64,25 @@ namespace sycl {
|
||||
}
|
||||
|
||||
bool not_use_default_selector = (device_spec.ordinal != kDefaultOrdinal) ||
|
||||
(rabit::IsDistributed());
|
||||
(collective::IsDistributed());
|
||||
std::lock_guard<std::mutex> guard(queue_registering_mutex);
|
||||
if (not_use_default_selector) {
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx = rabit::IsDistributed() ? rabit::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);;
|
||||
} else if (device_spec.IsSyclGPU()) {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
||||
}
|
||||
DeviceRegister& device_register = GetDevicesRegister();
|
||||
const int device_idx =
|
||||
collective::IsDistributed() ? collective::GetRank() : device_spec.ordinal;
|
||||
if (device_spec.IsSyclDefault()) {
|
||||
auto& devices = device_register.devices;
|
||||
CHECK_LT(device_idx, devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(devices[device_idx]);
|
||||
} else if (device_spec.IsSyclCPU()) {
|
||||
auto& cpu_devices = device_register.cpu_devices;
|
||||
CHECK_LT(device_idx, cpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(cpu_devices[device_idx]);
|
||||
} else if (device_spec.IsSyclGPU()) {
|
||||
auto& gpu_devices = device_register.gpu_devices;
|
||||
CHECK_LT(device_idx, gpu_devices.size());
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(gpu_devices[device_idx]);
|
||||
}
|
||||
} else {
|
||||
if (device_spec.IsSyclCPU()) {
|
||||
queue_register[device_spec.Name()] = ::sycl::queue(::sycl::cpu_selector_v);
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <vector>
|
||||
|
||||
@@ -9,7 +9,6 @@
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#pragma GCC diagnostic pop
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
Reference in New Issue
Block a user