[coll] Add federated coll. (#9738)
- Define a new data type, the proto file is copied for now. - Merge client and communicator into `FederatedColl`. - Define CUDA variant. - Migrate tests for CPU, add tests for CUDA.
This commit is contained in:
@@ -22,12 +22,35 @@ protobuf_generate(
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
add_library(federated_old_proto STATIC federated.old.proto)
|
||||
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
|
||||
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
|
||||
xgboost_target_properties(federated_old_proto)
|
||||
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE cpp
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
protobuf_generate(
|
||||
TARGET federated_old_proto
|
||||
LANGUAGE grpc
|
||||
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
|
||||
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
|
||||
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
|
||||
|
||||
# Wrapper for the gRPC client.
|
||||
add_library(federated_client INTERFACE)
|
||||
target_sources(federated_client INTERFACE federated_client.h)
|
||||
target_link_libraries(federated_client INTERFACE federated_proto)
|
||||
target_link_libraries(federated_client INTERFACE federated_old_proto)
|
||||
|
||||
# Rabit engine for Federated Learning.
|
||||
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
|
||||
target_sources(
|
||||
objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
|
||||
)
|
||||
if(USE_CUDA)
|
||||
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
|
||||
endif()
|
||||
|
||||
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
|
||||
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
|
||||
|
||||
81
plugin/federated/federated.old.proto
Normal file
81
plugin/federated/federated.old.proto
Normal file
@@ -0,0 +1,81 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
package xgboost.federated;
|
||||
|
||||
service Federated {
|
||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
|
||||
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
INT8 = 0;
|
||||
UINT8 = 1;
|
||||
INT32 = 2;
|
||||
UINT32 = 3;
|
||||
INT64 = 4;
|
||||
UINT64 = 5;
|
||||
FLOAT = 6;
|
||||
DOUBLE = 7;
|
||||
}
|
||||
|
||||
enum ReduceOperation {
|
||||
MAX = 0;
|
||||
MIN = 1;
|
||||
SUM = 2;
|
||||
BITWISE_AND = 3;
|
||||
BITWISE_OR = 4;
|
||||
BITWISE_XOR = 5;
|
||||
}
|
||||
|
||||
message AllgatherRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
message AllgatherReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllgatherVRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
message AllgatherVReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllreduceRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
DataType data_type = 4;
|
||||
ReduceOperation reduce_operation = 5;
|
||||
}
|
||||
|
||||
message AllreduceReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message BroadcastRequest {
|
||||
// An incrementing counter that is unique to each round to operations.
|
||||
uint64 sequence_number = 1;
|
||||
int32 rank = 2;
|
||||
bytes send_buffer = 3;
|
||||
// The root rank to broadcast from.
|
||||
int32 root = 4;
|
||||
}
|
||||
|
||||
message BroadcastReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
syntax = "proto3";
|
||||
|
||||
package xgboost.federated;
|
||||
package xgboost.collective.federated;
|
||||
|
||||
service Federated {
|
||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
@@ -13,14 +13,18 @@ service Federated {
|
||||
}
|
||||
|
||||
enum DataType {
|
||||
INT8 = 0;
|
||||
UINT8 = 1;
|
||||
INT32 = 2;
|
||||
UINT32 = 3;
|
||||
INT64 = 4;
|
||||
UINT64 = 5;
|
||||
FLOAT = 6;
|
||||
DOUBLE = 7;
|
||||
HALF = 0;
|
||||
FLOAT = 1;
|
||||
DOUBLE = 2;
|
||||
LONG_DOUBLE = 3;
|
||||
INT8 = 4;
|
||||
INT16 = 5;
|
||||
INT32 = 6;
|
||||
INT64 = 7;
|
||||
UINT8 = 8;
|
||||
UINT16 = 9;
|
||||
UINT32 = 10;
|
||||
UINT64 = 11;
|
||||
}
|
||||
|
||||
enum ReduceOperation {
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.pb.h>
|
||||
#include <federated.old.grpc.pb.h>
|
||||
#include <federated.old.pb.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include <cstdio>
|
||||
|
||||
155
plugin/federated/federated_coll.cc
Normal file
155
plugin/federated/federated_coll.cc
Normal file
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "federated_coll.h"
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.pb.h>
|
||||
|
||||
#include <algorithm> // for copy_n
|
||||
|
||||
#include "../../src/collective/allgather.h"
|
||||
#include "../../src/common/common.h" // for AssertGPUSupport
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace {
|
||||
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
|
||||
return Fail(name + " RPC failed. " + std::to_string(status.error_code()) + ": " +
|
||||
status.error_message());
|
||||
}
|
||||
|
||||
[[nodiscard]] Result BroadcastImpl(Comm const &comm, std::uint64_t *sequence_number,
|
||||
common::Span<std::int8_t> data, std::int32_t root) {
|
||||
using namespace federated; // NOLINT
|
||||
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
|
||||
BroadcastRequest request;
|
||||
request.set_sequence_number(*sequence_number++);
|
||||
request.set_rank(comm.Rank());
|
||||
if (comm.Rank() != root) {
|
||||
request.set_send_buffer(nullptr, 0);
|
||||
} else {
|
||||
request.set_send_buffer(data.data(), data.size());
|
||||
}
|
||||
request.set_root(root);
|
||||
|
||||
BroadcastReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub->Broadcast(&context, request, &reply);
|
||||
if (!status.ok()) {
|
||||
return GetGRPCResult("Broadcast", status);
|
||||
}
|
||||
if (comm.Rank() != root) {
|
||||
auto const &r = reply.receive_buffer();
|
||||
std::copy_n(r.cbegin(), r.size(), data.data());
|
||||
}
|
||||
|
||||
return Success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
Coll *FederatedColl::MakeCUDAVar() {
|
||||
common::AssertGPUSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) {
|
||||
using namespace federated; // NOLINT
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
|
||||
AllreduceRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(comm.Rank());
|
||||
request.set_send_buffer(data.data(), data.size());
|
||||
request.set_data_type(static_cast<::xgboost::collective::federated::DataType>(type));
|
||||
request.set_reduce_operation(static_cast<::xgboost::collective::federated::ReduceOperation>(op));
|
||||
|
||||
AllreduceReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub->Allreduce(&context, request, &reply);
|
||||
if (!status.ok()) {
|
||||
return GetGRPCResult("Allreduce", status);
|
||||
}
|
||||
auto const &r = reply.receive_buffer();
|
||||
std::copy_n(r.cbegin(), r.size(), data.data());
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
if (comm.Rank() == root) {
|
||||
return BroadcastImpl(comm, &sequence_number_, data, root);
|
||||
} else {
|
||||
return BroadcastImpl(comm, &sequence_number_, data, root);
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
using namespace federated; // NOLINT
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
|
||||
auto offset = comm.Rank() * size;
|
||||
auto segment = data.subspan(offset, size);
|
||||
|
||||
AllgatherRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(comm.Rank());
|
||||
request.set_send_buffer(segment.data(), segment.size());
|
||||
|
||||
AllgatherReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub->Allgather(&context, request, &reply);
|
||||
|
||||
if (!status.ok()) {
|
||||
return GetGRPCResult("Allgather", status);
|
||||
}
|
||||
auto const &r = reply.receive_buffer();
|
||||
std::copy_n(r.cbegin(), r.size(), data.begin());
|
||||
return Success();
|
||||
}
|
||||
|
||||
[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
|
||||
common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const>,
|
||||
common::Span<std::int64_t>,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo) {
|
||||
using namespace federated; // NOLINT
|
||||
|
||||
auto fed = dynamic_cast<FederatedComm const *>(&comm);
|
||||
CHECK(fed);
|
||||
auto stub = fed->Handle();
|
||||
|
||||
AllgatherVRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(comm.Rank());
|
||||
request.set_send_buffer(data.data(), data.size());
|
||||
|
||||
AllgatherVReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub->AllgatherV(&context, request, &reply);
|
||||
if (!status.ok()) {
|
||||
return GetGRPCResult("AllgatherV", status);
|
||||
}
|
||||
std::string const &r = reply.receive_buffer();
|
||||
CHECK_EQ(r.size(), recv.size());
|
||||
std::copy_n(r.cbegin(), r.size(), recv.begin());
|
||||
return Success();
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
92
plugin/federated/federated_coll.cu
Normal file
92
plugin/federated/federated_coll.cu
Normal file
@@ -0,0 +1,92 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <cstdint> // for int8_t, int32_t
|
||||
#include <memory> // for dynamic_pointer_cast
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../src/collective/comm.cuh"
|
||||
#include "../../src/common/cuda_context.cuh" // for CUDAContext
|
||||
#include "../../src/data/array_interface.h" // for ArrayInterfaceHandler::Type
|
||||
#include "federated_coll.cuh"
|
||||
#include "federated_comm.cuh"
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
Coll *FederatedColl::MakeCUDAVar() {
|
||||
return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
|
||||
return Success() << [&] {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
|
||||
return Success() << [&] {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int64_t size) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
|
||||
return Success() << [&] {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
};
|
||||
}
|
||||
|
||||
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
|
||||
Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
|
||||
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
|
||||
CHECK(cufed);
|
||||
|
||||
std::vector<std::int8_t> h_data(data.size());
|
||||
std::vector<std::int8_t> h_recv(recv.size());
|
||||
|
||||
return Success() << [&] {
|
||||
return GetCUDAResult(
|
||||
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
|
||||
} << [&] {
|
||||
return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
|
||||
} << [&] {
|
||||
return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
|
||||
cudaMemcpyHostToDevice, cufed->Stream()));
|
||||
};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
26
plugin/federated/federated_coll.cuh
Normal file
26
plugin/federated/federated_coll.cuh
Normal file
@@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#include "../../src/collective/comm.h" // for Comm, Coll
|
||||
#include "federated_coll.h" // for FederatedColl
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedColl : public Coll {
|
||||
std::shared_ptr<FederatedColl> p_impl_;
|
||||
|
||||
public:
|
||||
explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
|
||||
[[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t size) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
30
plugin/federated/federated_coll.h
Normal file
30
plugin/federated/federated_coll.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include "../../src/collective/coll.h" // for Coll
|
||||
#include "../../src/collective/comm.h" // for Comm
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for OptionalArg
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedColl : public Coll {
|
||||
private:
|
||||
std::uint64_t sequence_number_{0};
|
||||
|
||||
public:
|
||||
Coll *MakeCUDAVar() override;
|
||||
|
||||
[[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
|
||||
ArrayInterfaceHandler::Type type, Op op) override;
|
||||
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
|
||||
std::int32_t root) override;
|
||||
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
|
||||
std::int64_t) override;
|
||||
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
|
||||
common::Span<std::int64_t const> sizes,
|
||||
common::Span<std::int64_t> recv_segments,
|
||||
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for getenv
|
||||
#include <limits> // for numeric_limits
|
||||
#include <string> // for string, stoi
|
||||
|
||||
#include "../../src/common/common.h" // for Split
|
||||
@@ -29,12 +30,18 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
|
||||
CHECK_GE(rank, 0) << "Invalid worker rank.";
|
||||
CHECK_LT(rank, world) << "Invalid worker rank.";
|
||||
|
||||
auto certs = {server_cert, client_cert, client_cert};
|
||||
auto is_empty = [](auto const& s) { return s.empty(); };
|
||||
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
|
||||
std::none_of(certs.begin(), certs.end(), is_empty);
|
||||
CHECK(valid) << "Invalid arguments for certificates.";
|
||||
|
||||
if (server_cert.empty()) {
|
||||
stub_ = [&] {
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
return federated::Federated::NewStub(
|
||||
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
return federated::Federated::NewStub(grpc::CreateCustomChannel(
|
||||
host + ":" + std::to_string(port), grpc::InsecureChannelCredentials(), args));
|
||||
}();
|
||||
} else {
|
||||
stub_ = [&] {
|
||||
@@ -43,8 +50,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
|
||||
options.pem_private_key = client_key;
|
||||
options.pem_cert_chain = client_cert;
|
||||
grpc::ChannelArguments args;
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
|
||||
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
|
||||
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
|
||||
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
|
||||
grpc::SslCredentials(options), args);
|
||||
channel->WaitForConnected(
|
||||
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
|
||||
return federated::Federated::NewStub(channel);
|
||||
@@ -79,7 +87,7 @@ FederatedComm::FederatedComm(Json const& config) {
|
||||
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
|
||||
|
||||
auto parsed = common::Split(server_address, ':');
|
||||
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
|
||||
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
|
||||
|
||||
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
|
||||
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
|
||||
@@ -111,4 +119,11 @@ FederatedComm::FederatedComm(Json const& config) {
|
||||
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
|
||||
client_cert);
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
|
||||
common::AssertGPUSupport();
|
||||
return nullptr;
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace xgboost::collective
|
||||
|
||||
20
plugin/federated/federated_comm.cu
Normal file
20
plugin/federated/federated_comm.cu
Normal file
@@ -0,0 +1,20 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../src/common/cuda_context.cuh"
|
||||
#include "federated_comm.cuh"
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
|
||||
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
|
||||
CHECK(impl);
|
||||
}
|
||||
|
||||
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
|
||||
return new CUDAFederatedComm{
|
||||
ctx, std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this())};
|
||||
}
|
||||
} // namespace xgboost::collective
|
||||
20
plugin/federated/federated_comm.cuh
Normal file
20
plugin/federated/federated_comm.cuh
Normal file
@@ -0,0 +1,20 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
|
||||
#include "federated_comm.h" // for FederatedComm
|
||||
#include "xgboost/context.h" // for Context
|
||||
|
||||
namespace xgboost::collective {
|
||||
class CUDAFederatedComm : public FederatedComm {
|
||||
dh::CUDAStreamView stream_;
|
||||
|
||||
public:
|
||||
explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
|
||||
[[nodiscard]] auto Stream() const { return stream_; }
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
@@ -16,12 +16,20 @@
|
||||
|
||||
namespace xgboost::collective {
|
||||
class FederatedComm : public Comm {
|
||||
std::unique_ptr<federated::Federated::Stub> stub_;
|
||||
std::shared_ptr<federated::Federated::Stub> stub_;
|
||||
|
||||
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
|
||||
std::string const& server_cert, std::string const& client_key,
|
||||
std::string const& client_cert);
|
||||
|
||||
protected:
|
||||
explicit FederatedComm(std::shared_ptr<FederatedComm const> that) : stub_{that->stub_} {
|
||||
this->rank_ = that->Rank();
|
||||
this->world_ = that->World();
|
||||
|
||||
this->tracker_ = that->TrackerInfo();
|
||||
}
|
||||
|
||||
public:
|
||||
/**
|
||||
* @param config
|
||||
@@ -49,5 +57,8 @@ class FederatedComm : public Comm {
|
||||
return Success();
|
||||
}
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
|
||||
|
||||
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
|
||||
};
|
||||
} // namespace xgboost::collective
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <federated.grpc.pb.h>
|
||||
#include <federated.old.grpc.pb.h>
|
||||
|
||||
#include <cstdint> // for int32_t
|
||||
#include <future> // for future
|
||||
|
||||
@@ -16,9 +16,41 @@
|
||||
#include "../../src/common/io.h" // for ReadAll
|
||||
#include "../../src/common/json_utils.h" // for RequiredArg
|
||||
#include "../../src/common/timer.h" // for Timer
|
||||
#include "federated_server.h" // for FederatedService
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace federated {
|
||||
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) {
|
||||
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) {
|
||||
handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) {
|
||||
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||
static_cast<xgboost::collective::DataType>(request->data_type()),
|
||||
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
|
||||
BroadcastReply* reply) {
|
||||
handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
|
||||
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||
request->root());
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
} // namespace federated
|
||||
|
||||
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
|
||||
if (is_secure) {
|
||||
@@ -31,7 +63,8 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
|
||||
std::future<Result> FederatedTracker::Run() {
|
||||
return std::async([this]() {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
|
||||
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
|
||||
xgboost::collective::federated::FederatedService service{
|
||||
static_cast<std::int32_t>(this->n_workers_)};
|
||||
grpc::ServerBuilder builder;
|
||||
|
||||
if (this->server_cert_file_.empty()) {
|
||||
@@ -42,7 +75,6 @@ std::future<Result> FederatedTracker::Run() {
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
|
||||
<< this->n_workers_;
|
||||
} else {
|
||||
@@ -60,12 +92,12 @@ std::future<Result> FederatedTracker::Run() {
|
||||
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
|
||||
}
|
||||
builder.RegisterService(&service);
|
||||
server_ = builder.BuildAndStart();
|
||||
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
|
||||
<< n_workers_;
|
||||
}
|
||||
|
||||
try {
|
||||
server_ = builder.BuildAndStart();
|
||||
server_->Wait();
|
||||
} catch (std::exception const& e) {
|
||||
return collective::Fail(std::string{e.what()});
|
||||
|
||||
@@ -8,11 +8,35 @@
|
||||
#include <memory> // for unique_ptr
|
||||
#include <string> // for string
|
||||
|
||||
#include "../../src/collective/in_memory_handler.h"
|
||||
#include "../../src/collective/tracker.h" // for Tracker
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
#include "xgboost/json.h" // for Json
|
||||
|
||||
namespace xgboost::collective {
|
||||
namespace federated {
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(std::int32_t world_size)
|
||||
: handler_{static_cast<std::size_t>(world_size)} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) override;
|
||||
|
||||
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) override;
|
||||
|
||||
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
|
||||
BroadcastReply* reply) override;
|
||||
|
||||
private:
|
||||
xgboost::collective::InMemoryHandler handler_;
|
||||
};
|
||||
}; // namespace federated
|
||||
|
||||
class FederatedTracker : public collective::Tracker {
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
std::string server_key_path_;
|
||||
|
||||
Reference in New Issue
Block a user