Merge branch 'master'

This commit is contained in:
Hui Liu 2023-10-31 15:59:31 -07:00
commit 123af45327
26 changed files with 846 additions and 49 deletions

View File

@ -87,7 +87,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure
}
# also aggregate, just in case when the values were not yet summed up by feature
importance_matrix <- importance_matrix[, Importance := sum(get(measure)), by = Feature]
importance_matrix <- importance_matrix[
, lapply(.SD, sum)
, .SDcols = setdiff(names(importance_matrix), "Feature")
, by = Feature
][
, Importance := get(measure)
]
# make sure it's ordered
importance_matrix <- importance_matrix[order(-abs(Importance))]

View File

@ -382,6 +382,9 @@ test_that("xgb.importance works with GLM model", {
expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance"))
xgb.ggplot.importance(importance.GLM)
# check that the input is not modified in-place
expect_false("Importance" %in% names(importance.GLM))
# for multiclass
imp.GLM <- xgb.importance(model = mbst.GLM)
expect_equal(dim(imp.GLM), c(12, 3))
@ -400,6 +403,16 @@ test_that("xgb.model.dt.tree and xgb.importance work with a single split model",
expect_equal(imp$Gain, 1)
})
test_that("xgb.plot.importance de-duplicates features", {
importances <- data.table(
Feature = c("col1", "col2", "col2"),
Gain = c(0.4, 0.3, 0.3)
)
imp2plot <- xgb.plot.importance(importances)
expect_equal(nrow(imp2plot), 2L)
expect_equal(imp2plot$Feature, c("col2", "col1"))
})
test_that("xgb.plot.tree works with and without feature names", {
.skip_if_vcd_not_available()
expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree))

View File

@ -22,12 +22,35 @@ protobuf_generate(
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
add_library(federated_old_proto STATIC federated.old.proto)
target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
xgboost_target_properties(federated_old_proto)
protobuf_generate(
TARGET federated_old_proto
LANGUAGE cpp
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
protobuf_generate(
TARGET federated_old_proto
LANGUAGE grpc
GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
# Wrapper for the gRPC client.
add_library(federated_client INTERFACE)
target_sources(federated_client INTERFACE federated_client.h)
target_link_libraries(federated_client INTERFACE federated_proto)
target_link_libraries(federated_client INTERFACE federated_old_proto)
# Rabit engine for Federated Learning.
target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
target_sources(
objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
)
if(USE_CUDA)
target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
endif()
target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)

View File

@ -0,0 +1,81 @@
/*!
* Copyright 2022 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}
enum DataType {
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
}
enum ReduceOperation {
MAX = 0;
MIN = 1;
SUM = 2;
BITWISE_AND = 3;
BITWISE_OR = 4;
BITWISE_XOR = 5;
}
message AllgatherRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherReply {
bytes receive_buffer = 1;
}
message AllgatherVRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
}
message AllgatherVReply {
bytes receive_buffer = 1;
}
message AllreduceRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
DataType data_type = 4;
ReduceOperation reduce_operation = 5;
}
message AllreduceReply {
bytes receive_buffer = 1;
}
message BroadcastRequest {
// An incrementing counter that is unique to each round to operations.
uint64 sequence_number = 1;
int32 rank = 2;
bytes send_buffer = 3;
// The root rank to broadcast from.
int32 root = 4;
}
message BroadcastReply {
bytes receive_buffer = 1;
}

View File

@ -1,9 +1,9 @@
/*!
* Copyright 2022 XGBoost contributors
* Copyright 2022-2023 XGBoost contributors
*/
syntax = "proto3";
package xgboost.federated;
package xgboost.collective.federated;
service Federated {
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
@ -13,14 +13,18 @@ service Federated {
}
enum DataType {
INT8 = 0;
UINT8 = 1;
INT32 = 2;
UINT32 = 3;
INT64 = 4;
UINT64 = 5;
FLOAT = 6;
DOUBLE = 7;
HALF = 0;
FLOAT = 1;
DOUBLE = 2;
LONG_DOUBLE = 3;
INT8 = 4;
INT16 = 5;
INT32 = 6;
INT64 = 7;
UINT8 = 8;
UINT16 = 9;
UINT32 = 10;
UINT64 = 11;
}
enum ReduceOperation {

View File

@ -2,8 +2,8 @@
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <federated.old.grpc.pb.h>
#include <federated.old.pb.h>
#include <grpcpp/grpcpp.h>
#include <cstdio>

View File

@ -0,0 +1,155 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "federated_coll.h"
#include <federated.grpc.pb.h>
#include <federated.pb.h>
#include <algorithm> // for copy_n
#include "../../src/collective/allgather.h"
#include "../../src/common/common.h" // for AssertGPUSupport
#include "federated_comm.h" // for FederatedComm
#include "xgboost/collective/result.h" // for Result
namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
return Fail(name + " RPC failed. " + std::to_string(status.error_code()) + ": " +
status.error_message());
}
[[nodiscard]] Result BroadcastImpl(Comm const &comm, std::uint64_t *sequence_number,
common::Span<std::int8_t> data, std::int32_t root) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
BroadcastRequest request;
request.set_sequence_number(*sequence_number++);
request.set_rank(comm.Rank());
if (comm.Rank() != root) {
request.set_send_buffer(nullptr, 0);
} else {
request.set_send_buffer(data.data(), data.size());
}
request.set_root(root);
BroadcastReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub->Broadcast(&context, request, &reply);
if (!status.ok()) {
return GetGRPCResult("Broadcast", status);
}
if (comm.Rank() != root) {
auto const &r = reply.receive_buffer();
std::copy_n(r.cbegin(), r.size(), data.data());
}
return Success();
}
} // namespace
#if !defined(XGBOOST_USE_CUDA)
Coll *FederatedColl::MakeCUDAVar() {
common::AssertGPUSupport();
return nullptr;
}
#endif
[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
AllreduceRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(comm.Rank());
request.set_send_buffer(data.data(), data.size());
request.set_data_type(static_cast<::xgboost::collective::federated::DataType>(type));
request.set_reduce_operation(static_cast<::xgboost::collective::federated::ReduceOperation>(op));
AllreduceReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub->Allreduce(&context, request, &reply);
if (!status.ok()) {
return GetGRPCResult("Allreduce", status);
}
auto const &r = reply.receive_buffer();
std::copy_n(r.cbegin(), r.size(), data.data());
return Success();
}
[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
if (comm.Rank() == root) {
return BroadcastImpl(comm, &sequence_number_, data, root);
} else {
return BroadcastImpl(comm, &sequence_number_, data, root);
}
}
[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
auto offset = comm.Rank() * size;
auto segment = data.subspan(offset, size);
AllgatherRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(comm.Rank());
request.set_send_buffer(segment.data(), segment.size());
AllgatherReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub->Allgather(&context, request, &reply);
if (!status.ok()) {
return GetGRPCResult("Allgather", status);
}
auto const &r = reply.receive_buffer();
std::copy_n(r.cbegin(), r.size(), data.begin());
return Success();
}
[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
common::Span<std::int8_t const> data,
common::Span<std::int64_t const>,
common::Span<std::int64_t>,
common::Span<std::int8_t> recv, AllgatherVAlgo) {
using namespace federated; // NOLINT
auto fed = dynamic_cast<FederatedComm const *>(&comm);
CHECK(fed);
auto stub = fed->Handle();
AllgatherVRequest request;
request.set_sequence_number(sequence_number_++);
request.set_rank(comm.Rank());
request.set_send_buffer(data.data(), data.size());
AllgatherVReply reply;
grpc::ClientContext context;
context.set_wait_for_ready(true);
grpc::Status status = stub->AllgatherV(&context, request, &reply);
if (!status.ok()) {
return GetGRPCResult("AllgatherV", status);
}
std::string const &r = reply.receive_buffer();
CHECK_EQ(r.size(), recv.size());
std::copy_n(r.cbegin(), r.size(), recv.begin());
return Success();
}
} // namespace xgboost::collective

View File

@ -0,0 +1,92 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <cstdint> // for int8_t, int32_t
#include <memory> // for dynamic_pointer_cast
#include <vector> // for vector
#include "../../src/collective/comm.cuh"
#include "../../src/common/cuda_context.cuh" // for CUDAContext
#include "../../src/data/array_interface.h" // for ArrayInterfaceHandler::Type
#include "federated_coll.cuh"
#include "federated_comm.cuh"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
Coll *FederatedColl::MakeCUDAVar() {
return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
}
[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
return Success() << [&] {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}
[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
return Success() << [&] {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}
[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
std::int64_t size) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
return Success() << [&] {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}
[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
CHECK(cufed);
std::vector<std::int8_t> h_data(data.size());
std::vector<std::int8_t> h_recv(recv.size());
return Success() << [&] {
return GetCUDAResult(
cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
} << [&] {
return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
} << [&] {
return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
cudaMemcpyHostToDevice, cufed->Stream()));
};
}
} // namespace xgboost::collective

View File

@ -0,0 +1,26 @@
/**
* Copyright 2023, XGBoost contributors
*/
#include "../../src/collective/comm.h" // for Comm, Coll
#include "federated_coll.h" // for FederatedColl
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span
namespace xgboost::collective {
class CUDAFederatedColl : public Coll {
std::shared_ptr<FederatedColl> p_impl_;
public:
explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
[[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t size) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
} // namespace xgboost::collective

View File

@ -0,0 +1,30 @@
/**
* Copyright 2023, XGBoost contributors
*/
#pragma once
#include "../../src/collective/coll.h" // for Coll
#include "../../src/collective/comm.h" // for Comm
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for OptionalArg
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
class FederatedColl : public Coll {
private:
std::uint64_t sequence_number_{0};
public:
Coll *MakeCUDAVar() override;
[[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
ArrayInterfaceHandler::Type type, Op op) override;
[[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
std::int32_t root) override;
[[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
std::int64_t) override;
[[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
common::Span<std::int64_t const> sizes,
common::Span<std::int64_t> recv_segments,
common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
} // namespace xgboost::collective

View File

@ -7,6 +7,7 @@
#include <cstdint> // for int32_t
#include <cstdlib> // for getenv
#include <limits> // for numeric_limits
#include <string> // for string, stoi
#include "../../src/common/common.h" // for Split
@ -29,12 +30,18 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
CHECK_GE(rank, 0) << "Invalid worker rank.";
CHECK_LT(rank, world) << "Invalid worker rank.";
auto certs = {server_cert, client_cert, client_cert};
auto is_empty = [](auto const& s) { return s.empty(); };
bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
std::none_of(certs.begin(), certs.end(), is_empty);
CHECK(valid) << "Invalid arguments for certificates.";
if (server_cert.empty()) {
stub_ = [&] {
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
return federated::Federated::NewStub(
grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
return federated::Federated::NewStub(grpc::CreateCustomChannel(
host + ":" + std::to_string(port), grpc::InsecureChannelCredentials(), args));
}();
} else {
stub_ = [&] {
@ -43,8 +50,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
options.pem_private_key = client_key;
options.pem_cert_chain = client_cert;
grpc::ChannelArguments args;
args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
grpc::SslCredentials(options), args);
channel->WaitForConnected(
gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
return federated::Federated::NewStub(channel);
@ -79,7 +87,7 @@ FederatedComm::FederatedComm(Json const& config) {
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
auto parsed = common::Split(server_address, ':');
CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;
CHECK_NE(rank, -1) << "Parameter `federated_rank` is required";
CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
@ -111,4 +119,11 @@ FederatedComm::FederatedComm(Json const& config) {
this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
client_cert);
}
#if !defined(XGBOOST_USE_CUDA)
Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace xgboost::collective

View File

@ -0,0 +1,20 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <memory> // for shared_ptr
#include "../../src/common/cuda_context.cuh"
#include "federated_comm.cuh"
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
: FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
CHECK(impl);
}
Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
return new CUDAFederatedComm{
ctx, std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this())};
}
} // namespace xgboost::collective

View File

@ -0,0 +1,20 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <memory> // for shared_ptr
#include "../../src/common/device_helpers.cuh" // for CUDAStreamView
#include "federated_comm.h" // for FederatedComm
#include "xgboost/context.h" // for Context
namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {
dh::CUDAStreamView stream_;
public:
explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
[[nodiscard]] auto Stream() const { return stream_; }
};
} // namespace xgboost::collective

View File

@ -16,12 +16,20 @@
namespace xgboost::collective {
class FederatedComm : public Comm {
std::unique_ptr<federated::Federated::Stub> stub_;
std::shared_ptr<federated::Federated::Stub> stub_;
void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
std::string const& server_cert, std::string const& client_key,
std::string const& client_cert);
protected:
explicit FederatedComm(std::shared_ptr<FederatedComm const> that) : stub_{that->stub_} {
this->rank_ = that->Rank();
this->world_ = that->World();
this->tracker_ = that->TrackerInfo();
}
public:
/**
* @param config
@ -49,5 +57,8 @@ class FederatedComm : public Comm {
return Success();
}
[[nodiscard]] bool IsFederated() const override { return true; }
[[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
};
} // namespace xgboost::collective

View File

@ -3,7 +3,7 @@
*/
#pragma once
#include <federated.grpc.pb.h>
#include <federated.old.grpc.pb.h>
#include <cstdint> // for int32_t
#include <future> // for future

View File

@ -16,9 +16,41 @@
#include "../../src/common/io.h" // for ReadAll
#include "../../src/common/json_utils.h" // for RequiredArg
#include "../../src/common/timer.h" // for Timer
#include "federated_server.h" // for FederatedService
namespace xgboost::collective {
namespace federated {
grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
AllgatherReply* reply) {
handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
AllgatherVReply* reply) {
handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
return grpc::Status::OK;
}
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
AllreduceReply* reply) {
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
static_cast<xgboost::collective::DataType>(request->data_type()),
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
return grpc::Status::OK;
}
grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
BroadcastReply* reply) {
handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
request->root());
return grpc::Status::OK;
}
} // namespace federated
FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
if (is_secure) {
@ -31,7 +63,8 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
std::future<Result> FederatedTracker::Run() {
return std::async([this]() {
std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
xgboost::collective::federated::FederatedService service{
static_cast<std::int32_t>(this->n_workers_)};
grpc::ServerBuilder builder;
if (this->server_cert_file_.empty()) {
@ -42,7 +75,6 @@ std::future<Result> FederatedTracker::Run() {
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< this->n_workers_;
} else {
@ -60,12 +92,12 @@ std::future<Result> FederatedTracker::Run() {
builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
}
builder.RegisterService(&service);
server_ = builder.BuildAndStart();
LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
<< n_workers_;
}
try {
server_ = builder.BuildAndStart();
server_->Wait();
} catch (std::exception const& e) {
return collective::Fail(std::string{e.what()});

View File

@ -8,11 +8,35 @@
#include <memory> // for unique_ptr
#include <string> // for string
#include "../../src/collective/in_memory_handler.h"
#include "../../src/collective/tracker.h" // for Tracker
#include "xgboost/collective/result.h" // for Result
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace federated {
class FederatedService final : public Federated::Service {
public:
explicit FederatedService(std::int32_t world_size)
: handler_{static_cast<std::size_t>(world_size)} {}
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
AllgatherReply* reply) override;
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
AllgatherVReply* reply) override;
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
AllreduceReply* reply) override;
grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
BroadcastReply* reply) override;
private:
xgboost::collective::InMemoryHandler handler_;
};
}; // namespace federated
class FederatedTracker : public collective::Tracker {
std::unique_ptr<grpc::Server> server_;
std::string server_key_path_;

View File

@ -24,7 +24,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
Coll() = default;
virtual ~Coll() noexcept(false) {} // NOLINT
Coll* MakeCUDAVar();
virtual Coll* MakeCUDAVar();
/**
* @brief Allreduce

View File

@ -9,7 +9,8 @@
#include <string> // for string
#include <utility> // for move, forward
#include "allgather.h"
#include "../common/common.h" // for AssertGPUSupport
#include "allgather.h" // for RingAllgather
#include "protocol.h" // for kMagic
#include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE
#include "xgboost/collective/socket.h" // for TCPSocket
@ -48,6 +49,14 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
this->Rank(), this->World());
}
#if !defined(XGBOOST_USE_NCCL)
Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
common::AssertGPUSupport();
common::AssertNCCLSupport();
return nullptr;
}
#endif // !defined(XGBOOST_USE_NCCL)
[[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
proto::PeerInfo ninfo, std::chrono::seconds timeout,
std::int32_t retry,

View File

@ -20,14 +20,14 @@
namespace xgboost::collective {
namespace {
Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
static const int kRootRank = 0;
ncclUniqueId id;
if (comm.Rank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
kRootRank);
auto rc = coll->Broadcast(
comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
if (!rc.OK()) {
return rc;
}
@ -63,7 +63,7 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
}
} // namespace
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
return new NCCLComm{ctx, *this, pimpl};
}
@ -86,6 +86,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
GetCudaUUID(s_this_uuid, ctx->Device());
auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
CHECK(rc.OK()) << rc.Report();
std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
@ -103,7 +104,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
rc = GetUniqueId(root, &nccl_unique_id_);
rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
CHECK(rc.OK()) << rc.Report();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));

View File

@ -40,7 +40,7 @@ class Coll;
/**
* @brief Base communicator storing info about the tracker and other communicators.
*/
class Comm {
class Comm : public std::enable_shared_from_this<Comm> {
protected:
std::int32_t world_{-1};
std::int32_t rank_{0};
@ -87,7 +87,7 @@ class Comm {
[[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }
Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
};
class RabitComm : public Comm {

View File

@ -170,6 +170,12 @@ inline void AssertGPUSupport() {
#endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP
}
inline void AssertNCCLSupport() {
#if !defined(XGBOOST_USE_NCCL)
LOG(FATAL) << "XGBoost version not compiled with NCCL support.";
#endif // !defined(XGBOOST_USE_NCCL)
}
inline void AssertOneAPISupport() {
#ifndef XGBOOST_USE_ONEAPI
LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";

View File

@ -0,0 +1,94 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/span.h> // for Span
#include <array> // for array
#include "../../../../src/common/type.h" // for EraseType
#include "../../collective/test_worker.h" // for SocketTest
#include "federated_coll.h" // for FederatedColl
#include "federated_comm.h" // for FederatedComm
#include "test_worker.h" // for TestFederated
namespace xgboost::collective {
namespace {
class FederatedCollTest : public SocketTest {};
} // namespace
TEST_F(FederatedCollTest, Allreduce) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
std::array<std::int32_t, 5> buffer = {1, 2, 3, 4, 5};
std::array<std::int32_t, 5> expected;
std::transform(buffer.cbegin(), buffer.cend(), expected.begin(),
[=](auto i) { return i * n_workers; });
auto coll = std::make_shared<FederatedColl>();
auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
ArrayInterfaceHandler::kI4, Op::kSum);
ASSERT_TRUE(rc.OK());
for (auto i = 0; i < 5; i++) {
ASSERT_EQ(buffer[i], expected[i]);
}
});
}
TEST_F(FederatedCollTest, Broadcast) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
FederatedColl coll{};
auto rc = Success();
if (comm->Rank() == 0) {
std::string buffer{"hello"};
rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
ASSERT_EQ(buffer, "hello");
} else {
std::string buffer{" "};
rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
ASSERT_EQ(buffer, "hello");
}
ASSERT_TRUE(rc.OK());
});
}
TEST_F(FederatedCollTest, Allgather) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
FederatedColl coll{};
std::vector<std::int32_t> buffer(n_workers, 0);
buffer[comm->Rank()] = comm->Rank();
auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
sizeof(int));
ASSERT_TRUE(rc.OK());
for (auto i = 0; i < n_workers; i++) {
ASSERT_EQ(buffer[i], i);
}
});
}
TEST_F(FederatedCollTest, AllgatherV) {
std::int32_t n_workers = 2;
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
FederatedColl coll{};
std::vector<std::string_view> inputs{"Federated", " Learning!!!"};
std::vector<std::int64_t> recv_segments(inputs.size() + 1, 0);
std::string r;
std::vector<std::int64_t> sizes{static_cast<std::int64_t>(inputs[0].size()),
static_cast<std::int64_t>(inputs[1].size())};
r.resize(sizes[0] + sizes[1]);
auto rc = coll.AllgatherV(
*comm,
common::EraseType(common::Span{inputs[comm->Rank()].data(), inputs[comm->Rank()].size()}),
common::Span{sizes.data(), sizes.size()}, recv_segments,
common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing);
EXPECT_EQ(r, "Federated Learning!!!");
ASSERT_TRUE(rc.OK());
});
}
} // namespace xgboost::collective

View File

@ -0,0 +1,131 @@
/**
* Copyright 2022-2023, XGBoost contributors
*/
#include <gtest/gtest.h>
#include <xgboost/collective/result.h> // for Result
#include "../../../../src/common/common.h" // for AllVisibleGPUs
#include "../../../../src/common/device_helpers.cuh" // for device_vector
#include "../../../../src/common/type.h" // for EraseType
#include "../../collective/test_worker.h" // for SocketTest
#include "../../helpers.h" // for MakeCUDACtx
#include "federated_coll.cuh"
#include "federated_comm.cuh"
#include "test_worker.h" // for TestFederated
namespace xgboost::collective {
namespace {
class FederatedCollTestGPU : public SocketTest {};
struct Worker {
std::shared_ptr<FederatedColl> impl;
std::shared_ptr<Comm> nccl_comm;
std::shared_ptr<CUDAFederatedColl> coll;
Worker(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
auto ctx = MakeCUDACtx(rank);
impl = std::make_shared<FederatedColl>();
nccl_comm.reset(comm->MakeCUDAVar(&ctx, impl));
coll = std::make_shared<CUDAFederatedColl>(impl);
}
};
void TestAllreduce(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
Worker w{comm, rank};
dh::device_vector<std::int32_t> buffer{std::vector<std::int32_t>{1, 2, 3, 4, 5}};
dh::device_vector<std::int32_t> expected(buffer.size());
thrust::transform(buffer.cbegin(), buffer.cend(), expected.begin(),
[=] XGBOOST_DEVICE(std::int32_t i) { return i * n_workers; });
auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)),
ArrayInterfaceHandler::kI4, Op::kSum);
ASSERT_TRUE(rc.OK());
for (auto i = 0; i < 5; i++) {
ASSERT_EQ(buffer[i], expected[i]);
}
}
void TestBroadcast(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
Worker w{comm, rank};
auto rc = Success();
std::vector<std::int32_t> expect{0, 1, 2, 3};
if (comm->Rank() == 0) {
dh::device_vector<std::int32_t> buffer{expect};
rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
std::vector<std::int32_t> expect{0, 1, 2, 3};
ASSERT_EQ(buffer, expect);
} else {
dh::device_vector<std::int32_t> buffer(std::vector<std::int32_t>{4, 5, 6, 7});
rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
ASSERT_EQ(buffer, expect);
}
ASSERT_TRUE(rc.OK());
}
void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
Worker w{comm, rank};
dh::device_vector<std::int32_t> buffer(n_workers, 0);
buffer[comm->Rank()] = comm->Rank();
auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int));
ASSERT_TRUE(rc.OK());
for (auto i = 0; i < n_workers; i++) {
ASSERT_EQ(buffer[i], i);
}
}
void TestAllgatherV(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
Worker w{comm, rank};
std::vector<dh::device_vector<std::int32_t>> inputs{std::vector<std::int32_t>{1, 2, 3},
std::vector<std::int32_t>{4, 5}};
std::vector<std::int64_t> recv_segments(inputs.size() + 1, 0);
dh::device_vector<std::int32_t> r;
std::vector<std::int64_t> sizes{static_cast<std::int64_t>(inputs[0].size()),
static_cast<std::int64_t>(inputs[1].size())};
r.resize(sizes[0] + sizes[1]);
auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])),
common::Span{sizes.data(), sizes.size()}, recv_segments,
common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing);
ASSERT_TRUE(rc.OK());
ASSERT_EQ(r[0], 1);
for (std::size_t i = 1; i < r.size(); ++i) {
ASSERT_EQ(r[i], r[i - 1] + 1);
}
}
} // namespace
TEST_F(FederatedCollTestGPU, Allreduce) {
std::int32_t n_workers = common::AllVisibleGPUs();
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
TestAllreduce(comm, rank, n_workers);
});
}
TEST_F(FederatedCollTestGPU, Broadcast) {
std::int32_t n_workers = common::AllVisibleGPUs();
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
TestBroadcast(comm, rank);
});
}
TEST_F(FederatedCollTestGPU, Allgather) {
std::int32_t n_workers = common::AllVisibleGPUs();
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
TestAllgather(comm, rank, n_workers);
});
}
TEST_F(FederatedCollTestGPU, AllgatherV) {
std::int32_t n_workers = 2;
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
TestAllgatherV(comm, rank);
});
}
} // namespace xgboost::collective

View File

@ -7,10 +7,10 @@
#include <thread> // for thread
#include "../../../../plugin/federated/federated_comm.h"
#include "../../collective/net_test.h" // for SocketTest
#include "../../helpers.h" // for ExpectThrow
#include "test_worker.h" // for TestFederated
#include "xgboost/json.h" // for Json
#include "../../collective/test_worker.h" // for SocketTest
#include "../../helpers.h" // for ExpectThrow
#include "test_worker.h" // for TestFederated
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
@ -71,14 +71,9 @@ TEST_F(FederatedCommTest, IsDistributed) {
TEST_F(FederatedCommTest, InsecureTracker) {
std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
Json config{Object{}};
config["federated_world_size"] = n_workers;
config["federated_rank"] = rank;
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
FederatedComm comm{config};
ASSERT_EQ(comm.Rank(), rank);
ASSERT_EQ(comm.World(), n_workers);
TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
ASSERT_EQ(comm->Rank(), rank);
ASSERT_EQ(comm->World(), n_workers);
});
}
} // namespace xgboost::collective

View File

@ -9,7 +9,8 @@
#include <thread> // for thread
#include "../../../../plugin/federated/federated_tracker.h"
#include "xgboost/json.h" // for Json
#include "federated_comm.h" // for FederatedComm
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
template <typename WorkerFn>
@ -28,7 +29,15 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
std::int32_t port = tracker.Port();
for (std::int32_t i = 0; i < n_workers; ++i) {
workers.emplace_back([=] { fn(port, i); });
workers.emplace_back([=] {
Json config{Object{}};
config["federated_world_size"] = n_workers;
config["federated_rank"] = i;
config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
auto comm = std::make_shared<FederatedComm>(config);
fn(comm, i);
});
}
for (auto& t : workers) {