[coll] Add federated coll. (#9738)

- Define a new data type, the proto file is copied for now.
- Merge client and communicator into `FederatedColl`.
- Define CUDA variant.
- Migrate tests for CPU, add tests for CUDA.
This commit is contained in:
Jiaming Yuan
2023-11-01 04:06:46 +08:00
committed by GitHub
parent 6b98305db4
commit bc995a4865
24 changed files with 826 additions and 48 deletions

View File

@@ -22,12 +22,35 @@ protobuf_generate(
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
# Generated sources for the legacy ("old") federated protocol, built as a
# separate static library so the old and new protocols can coexist during
# the migration to the collective-based implementation.
add_library(federated_old_proto STATIC federated.old.proto)
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
# Generated .pb.{h,cc} files land in the build tree; expose them to dependents.
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
xgboost_target_properties(federated_old_proto)
# Plain protobuf message code generation.
protobuf_generate(
  TARGET federated_old_proto
  LANGUAGE cpp
  PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
# gRPC service stub generation via the grpc_cpp_plugin.
protobuf_generate(
  TARGET federated_old_proto
  LANGUAGE grpc
  GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
  PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
  PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
# Wrapper for the gRPC client.
add_library(federated_client INTERFACE)
target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
target_link_libraries(federated_client INTERFACE federated_old_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
endif()
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@@ -0,0 +1,81 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
// Legacy ("old") federated-learning wire protocol, copied verbatim so the old
// and new collective implementations can coexist during the migration.
syntax = "proto3";

package xgboost.federated;

// One RPC per collective operation; the server matches contributions from
// different workers via the request's sequence number.
service Federated {
  rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
  rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
  rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
  rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}

// Element type of the buffer reduced by Allreduce.
enum DataType {
  INT8 = 0;
  UINT8 = 1;
  INT32 = 2;
  UINT32 = 3;
  INT64 = 4;
  UINT64 = 5;
  FLOAT = 6;
  DOUBLE = 7;
}

// Reduction applied element-wise by Allreduce.
enum ReduceOperation {
  MAX = 0;
  MIN = 1;
  SUM = 2;
  BITWISE_AND = 3;
  BITWISE_OR = 4;
  BITWISE_XOR = 5;
}

message AllgatherRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
}

message AllgatherReply {
  // Combined buffer covering all workers.
  bytes receive_buffer = 1;
}

message AllgatherVRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  // Variable-length payload; sizes may differ between workers.
  bytes send_buffer = 3;
}

message AllgatherVReply {
  bytes receive_buffer = 1;
}

message AllreduceRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
  DataType data_type = 4;
  ReduceOperation reduce_operation = 5;
}

message AllreduceReply {
  bytes receive_buffer = 1;
}

message BroadcastRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
  // The root rank to broadcast from.
  int32 root = 4;
}

message BroadcastReply {
  bytes receive_buffer = 1;
}

View File

@@ -1,9 +1,9 @@
/*!
* Copyright 2022 XGBoost contributors
* Copyright 2022-2023 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
package xgboost.collective.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
@@ -13,14 +13,18 @@ service Federated {
}
enum DataType {
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
HALF = 0;
FLOAT = 1;
DOUBLE = 2;
LONG_DOUBLE = 3;
INT8 = 4;
INT16 = 5;
INT32 = 6;
INT64 = 7;
UINT8 = 8;
UINT16 = 9;
UINT32 = 10;
UINT64 = 11;
}
enum ReduceOperation {

View File

@@ -2,8 +2,8 @@
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <federated.old.grpc.pb.h>
#include <federated.old.pb.h>
#include <grpcpp/grpcpp.h>
#include <cstdio>

View File

@@ -0,0 +1,155 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_coll.h"
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <algorithm> // for copy_n
#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result
namespace xgboost::collective {
namespace {
// Convert a failed gRPC status into a collective `Result` carrying a
// human-readable "<op> RPC failed. <code>: <message>" description.
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
  std::string msg{name};
  msg += " RPC failed. ";
  msg += std::to_string(status.error_code());
  msg += ": ";
  msg += status.error_message();
  return Fail(msg);
}
/**
 * @brief Perform a single broadcast round over gRPC.
 *
 * @param sequence_number In/out pointer to the caller's round counter; it is
 *        consumed and advanced by one for this operation.
 * @param data   Buffer to send (on root) or to fill (on non-root ranks).
 * @param root   Rank that owns the payload.
 */
[[nodiscard]] Result BroadcastImpl(Comm const &comm, std::uint64_t *sequence_number,
                                   common::Span<std::int8_t> data, std::int32_t root) {
  using namespace federated;  // NOLINT
  auto fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);
  auto stub = fed->Handle();

  BroadcastRequest request;
  // Fix: `*sequence_number++` incremented the local pointer and left the
  // caller's counter untouched; every round must consume a fresh number.
  request.set_sequence_number((*sequence_number)++);
  request.set_rank(comm.Rank());
  if (comm.Rank() != root) {
    // Non-root workers contribute nothing; they only receive.
    request.set_send_buffer(nullptr, 0);
  } else {
    request.set_send_buffer(data.data(), data.size());
  }
  request.set_root(root);

  BroadcastReply reply;
  grpc::ClientContext context;
  context.set_wait_for_ready(true);
  grpc::Status status = stub->Broadcast(&context, request, &reply);

  if (!status.ok()) {
    return GetGRPCResult("Broadcast", status);
  }
  if (comm.Rank() != root) {
    // Only non-root ranks copy the broadcast payload back into `data`.
    auto const &r = reply.receive_buffer();
    std::copy_n(r.cbegin(), r.size(), data.data());
  }
  return Success();
}
} // namespace
#if !defined(XGBOOST_USE_CUDA)
// CPU-only build: there is no CUDA variant of the federated collective.
// AssertGPUSupport() is expected to fail with a diagnostic, so the nullptr
// return should not be observed in practice.
Coll *FederatedColl::MakeCUDAVar() {
  common::AssertGPUSupport();
  return nullptr;
}
#endif
/**
 * @brief Allreduce over gRPC: send the local buffer, receive the reduced one.
 */
[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                              ArrayInterfaceHandler::Type type, Op op) {
  using namespace federated;  // NOLINT
  auto const *fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);

  AllreduceRequest req;
  req.set_sequence_number(sequence_number_++);
  req.set_rank(comm.Rank());
  req.set_send_buffer(data.data(), data.size());
  req.set_data_type(static_cast<::xgboost::collective::federated::DataType>(type));
  req.set_reduce_operation(static_cast<::xgboost::collective::federated::ReduceOperation>(op));

  grpc::ClientContext ctx;
  ctx.set_wait_for_ready(true);
  AllreduceReply rep;
  grpc::Status rc = fed->Handle()->Allreduce(&ctx, req, &rep);
  if (!rc.ok()) {
    return GetGRPCResult("Allreduce", rc);
  }

  // Overwrite the input with the fully reduced result from the server.
  auto const &recved = rep.receive_buffer();
  std::copy_n(recved.cbegin(), recved.size(), data.data());
  return Success();
}
/**
 * @brief Broadcast `data` from `root` to all workers.
 */
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                              std::int32_t root) {
  // The previous if/else had two byte-identical branches; BroadcastImpl
  // already distinguishes root from non-root internally, so call it directly.
  return BroadcastImpl(comm, &sequence_number_, data, root);
}
/**
 * @brief Fixed-size allgather: each worker owns a `size`-byte slice of `data`.
 */
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
                                              std::int64_t size) {
  using namespace federated;  // NOLINT
  auto const *fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);

  // This worker contributes only its own slice of the shared buffer.
  auto self = data.subspan(comm.Rank() * size, size);

  AllgatherRequest req;
  req.set_sequence_number(sequence_number_++);
  req.set_rank(comm.Rank());
  req.set_send_buffer(self.data(), self.size());

  grpc::ClientContext ctx;
  ctx.set_wait_for_ready(true);
  AllgatherReply rep;
  grpc::Status rc = fed->Handle()->Allgather(&ctx, req, &rep);
  if (!rc.ok()) {
    return GetGRPCResult("Allgather", rc);
  }

  // The reply holds the combined buffer for all workers.
  auto const &recved = rep.receive_buffer();
  std::copy_n(recved.cbegin(), recved.size(), data.begin());
  return Success();
}
/**
 * @brief Variable-size allgather: send the whole local buffer, receive the
 *        combined buffer for all workers into `recv`.
 *
 * The size/segment spans and the algorithm selector are unused here; the
 * server determines the concatenation.
 */
[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
                                               common::Span<std::int8_t const> data,
                                               common::Span<std::int64_t const>,
                                               common::Span<std::int64_t>,
                                               common::Span<std::int8_t> recv, AllgatherVAlgo) {
  using namespace federated;  // NOLINT
  auto const *fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);

  AllgatherVRequest req;
  req.set_sequence_number(sequence_number_++);
  req.set_rank(comm.Rank());
  req.set_send_buffer(data.data(), data.size());

  grpc::ClientContext ctx;
  ctx.set_wait_for_ready(true);
  AllgatherVReply rep;
  grpc::Status rc = fed->Handle()->AllgatherV(&ctx, req, &rep);
  if (!rc.ok()) {
    return GetGRPCResult("AllgatherV", rc);
  }

  // The output buffer must exactly fit the combined payload.
  auto const &recved = rep.receive_buffer();
  CHECK_EQ(recved.size(), recv.size());
  std::copy_n(recved.cbegin(), recved.size(), recv.begin());
  return Success();
}
} // namespace xgboost::collective

View File

@@ -0,0 +1,92 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#include <cstdint>  // for int8_t, int32_t
#include <memory>   // for dynamic_pointer_cast
#include <vector>   // for vector

#include "../../src/collective/comm.cuh"
#include "../../src/common/cuda_context.cuh"  // for CUDAContext
#include "../../src/data/array_interface.h"   // for ArrayInterfaceHandler::Type
#include "federated_coll.cuh"
#include "federated_comm.cuh"
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span

namespace xgboost::collective {
// Build the CUDA variant on top of this CPU collective. Sharing the object via
// shared_from_this keeps a single sequence-number counter for both variants.
Coll *FederatedColl::MakeCUDAVar() {
  return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
}

// Every CUDA op below uses the same staging pattern:
//   1. synchronous D2H copy of the device buffer into a host vector,
//   2. run the CPU (gRPC-based) implementation on the host copy,
//   3. async H2D copy of the result back on the communicator's stream.
// `Success() << f << g` chains the Result-returning steps (see
// xgboost/collective/result.h for the combinator).
// NOTE(review): step 3 is asynchronous; callers are assumed to synchronize
// cufed->Stream() before reading `data` — confirm against call sites.
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                                  ArrayInterfaceHandler::Type type, Op op) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());  // host staging buffer
  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

// Broadcast through a host staging buffer; delegates to the CPU impl.
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                                  std::int32_t root) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());  // host staging buffer
  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

// Fixed-size allgather through a host staging buffer.
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
                                                  std::int64_t size) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());  // host staging buffer
  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

// Variable-size allgather. `data` is staged D2H and `recv` is filled H2D;
// `sizes` and `recv_segments` are forwarded untouched to the CPU impl.
// NOTE(review): this assumes `sizes`/`recv_segments` are host-accessible
// spans — confirm with the callers.
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
    Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
    common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());  // host copy of the input
  std::vector<std::int8_t> h_recv(recv.size());  // host copy of the output
  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}
}  // namespace xgboost::collective

View File

@@ -0,0 +1,26 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#pragma once  // was missing: every sibling header in this plugin has a guard

#include <memory>  // for shared_ptr; was missing despite std::shared_ptr use

#include "../../src/collective/comm.h"  // for Comm, Coll
#include "federated_coll.h"             // for FederatedColl
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span

namespace xgboost::collective {
/**
 * @brief CUDA variant of the federated collective.
 *
 * Stages device buffers through host memory and delegates each operation to
 * the shared CPU `FederatedColl` implementation.
 */
class CUDAFederatedColl : public Coll {
  std::shared_ptr<FederatedColl> p_impl_;  // shared CPU implementation

 public:
  explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
  [[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                 ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
                                 std::int64_t size) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
                                  common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
}  // namespace xgboost::collective

View File

@@ -0,0 +1,30 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#pragma once
#include "../../src/collective/coll.h"    // for Coll
#include "../../src/collective/comm.h"    // for Comm
#include "../../src/common/io.h"          // for ReadAll
#include "../../src/common/json_utils.h"  // for OptionalArg
#include "xgboost/json.h"                 // for Json

namespace xgboost::collective {
/**
 * @brief Federated-learning implementation of the collective interface.
 *
 * Each operation is a gRPC round trip to the federated server through the
 * FederatedComm's stub. `sequence_number_` tags the requests so the server
 * can match contributions from different workers belonging to the same round.
 */
class FederatedColl : public Coll {
 private:
  std::uint64_t sequence_number_{0};  // monotonically increasing round id

 public:
  /** @brief Create the CUDA variant (fails in CPU-only builds). */
  Coll *MakeCUDAVar() override;
  [[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
                                 ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
                                 std::int64_t) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
                                  common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
}  // namespace xgboost::collective

View File

@@ -7,6 +7,7 @@
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <string> // for string, stoi
#include "../../src/common/common.h" // for Split
@@ -29,12 +30,18 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";
auto certs = {server_cert, client_cert, client_cert};
auto is_empty = [](auto const& s) { return s.empty(); };
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
std::none_of(certs.begin(), certs.end(), is_empty);
CHECK(valid) << "Invalid arguments for certificates.";
if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
return federated::Federated::NewStub(grpc::CreateCustomChannel(
host + ":" + std::to_string(port), grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
@@ -43,8 +50,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
@@ -79,7 +87,7 @@ FederatedComm::FederatedComm(Json const& config) {
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
@@ -111,4 +119,11 @@ FederatedComm::FederatedComm(Json const& config) {
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
#if !defined(XGBOOST_USE_CUDA)
Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::collective

View File

@@ -0,0 +1,20 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#include <memory>  // for shared_ptr

#include "../../src/common/cuda_context.cuh"
#include "federated_comm.cuh"
#include "xgboost/context.h"  // for Context

namespace xgboost::collective {
// Construct the CUDA communicator on top of an existing CPU communicator,
// capturing the context's CUDA stream for the collectives' staging copies.
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
    : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
  // NOTE(review): the FederatedComm base ctor already dereferences `impl`
  // (copying its stub/rank/world), so a null `impl` would crash before this
  // CHECK fires — consider validating before delegating to the base class.
  CHECK(impl);
}

// Promote this CPU communicator to its CUDA variant; ownership is shared via
// shared_from_this so both views use the same gRPC stub.
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
  return new CUDAFederatedComm{
      ctx, std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this())};
}
}  // namespace xgboost::collective

View File

@@ -0,0 +1,20 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#pragma once

#include <memory>  // for shared_ptr

#include "../../src/common/device_helpers.cuh"  // for CUDAStreamView
#include "federated_comm.h"                     // for FederatedComm
#include "xgboost/context.h"                    // for Context

namespace xgboost::collective {
/**
 * @brief CUDA variant of the federated communicator.
 *
 * Shares the gRPC stub and rank/world information with the CPU FederatedComm
 * it is constructed from, and additionally carries the CUDA stream the
 * collectives use for device<->host staging copies.
 */
class CUDAFederatedComm : public FederatedComm {
  dh::CUDAStreamView stream_;  // stream for async copies issued by collectives

 public:
  explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
  /** @brief Stream used to order device copies for this communicator. */
  [[nodiscard]] auto Stream() const { return stream_; }
};
}  // namespace xgboost::collective

View File

@@ -16,12 +16,20 @@
namespace xgboost::collective {
class FederatedComm : public Comm {
std::unique_ptr<federated::Federated::Stub> stub_;
std::shared_ptr<federated::Federated::Stub> stub_;
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
std::string const& client_cert);
protected:
explicit FederatedComm(std::shared_ptr<FederatedComm const> that) : stub_{that->stub_} {
this->rank_ = that->Rank();
this->world_ = that->World();
this->tracker_ = that->TrackerInfo();
}
public:
/**
* @param config
@@ -49,5 +57,8 @@ class FederatedComm : public Comm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};
} // namespace xgboost::collective

View File

@@ -3,7 +3,7 @@
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future

View File

@@ -16,9 +16,41 @@
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService
namespace xgboost::collective {
namespace federated {
// Each RPC below is a thin adapter: unpack the request fields and delegate to
// the shared InMemoryHandler, which presumably synchronizes the workers that
// use the same sequence number (see in_memory_handler.h) and fills the reply.
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
                                         AllgatherReply* reply) {
  handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
  return grpc::Status::OK;
}

grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
                                          AllgatherVReply* reply) {
  handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
                      reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
  return grpc::Status::OK;
}

grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
                                         AllreduceReply* reply) {
  // The wire enums are converted to the handler's own DataType/Operation;
  // the two enumerations must stay value-compatible for the casts to be valid.
  handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
                     static_cast<xgboost::collective::DataType>(request->data_type()),
                     static_cast<xgboost::collective::Operation>(request->reduce_operation()));
  return grpc::Status::OK;
}

grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
                                         BroadcastReply* reply) {
  handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
                     request->root());
  return grpc::Status::OK;
}
}  // namespace federated
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
if (is_secure) {
@@ -31,7 +63,8 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
xgboost::collective::federated::FederatedService service{
static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;
if (this->server_cert_file_.empty()) {
@@ -42,7 +75,6 @@ std::future<Result> FederatedTracker::Run() {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
@@ -60,12 +92,12 @@ std::future<Result> FederatedTracker::Run() {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}
try {
server_ = builder.BuildAndStart();
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});

View File

@@ -8,11 +8,35 @@
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace federated {
/**
 * @brief gRPC service implementing the federated collective operations.
 *
 * Thin front-end over the shared InMemoryHandler, which is sized for
 * `world_size` workers; each RPC forwards the request payload to it.
 */
class FederatedService final : public Federated::Service {
 public:
  explicit FederatedService(std::int32_t world_size)
      : handler_{static_cast<std::size_t>(world_size)} {}

  grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
                         AllgatherReply* reply) override;

  grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
                          AllgatherVReply* reply) override;

  grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
                         AllreduceReply* reply) override;

  grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
                         BroadcastReply* reply) override;

 private:
  // Aggregation state shared by all RPCs for this server instance.
  xgboost::collective::InMemoryHandler handler_;
};
}; // namespace federated
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;