diff --git a/R-package/R/xgb.plot.importance.R b/R-package/R/xgb.plot.importance.R index 2c02d5a42..07220375d 100644 --- a/R-package/R/xgb.plot.importance.R +++ b/R-package/R/xgb.plot.importance.R @@ -87,7 +87,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure } # also aggregate, just in case when the values were not yet summed up by feature - importance_matrix <- importance_matrix[, Importance := sum(get(measure)), by = Feature] + importance_matrix <- importance_matrix[ + , lapply(.SD, sum) + , .SDcols = setdiff(names(importance_matrix), "Feature") + , by = Feature + ][ + , Importance := get(measure) + ] # make sure it's ordered importance_matrix <- importance_matrix[order(-abs(Importance))] diff --git a/R-package/tests/testthat/test_helpers.R b/R-package/tests/testthat/test_helpers.R index 04e034ce1..de6a099fc 100644 --- a/R-package/tests/testthat/test_helpers.R +++ b/R-package/tests/testthat/test_helpers.R @@ -382,6 +382,9 @@ test_that("xgb.importance works with GLM model", { expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance")) xgb.ggplot.importance(importance.GLM) + # check that the input is not modified in-place + expect_false("Importance" %in% names(importance.GLM)) + # for multiclass imp.GLM <- xgb.importance(model = mbst.GLM) expect_equal(dim(imp.GLM), c(12, 3)) @@ -400,6 +403,16 @@ test_that("xgb.model.dt.tree and xgb.importance work with a single split model", expect_equal(imp$Gain, 1) }) +test_that("xgb.plot.importance de-duplicates features", { + importances <- data.table( + Feature = c("col1", "col2", "col2"), + Gain = c(0.4, 0.3, 0.3) + ) + imp2plot <- xgb.plot.importance(importances) + expect_equal(nrow(imp2plot), 2L) + expect_equal(imp2plot$Feature, c("col2", "col1")) +}) + test_that("xgb.plot.tree works with and without feature names", { .skip_if_vcd_not_available() expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree)) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 7c2cfa6fb..c4d5ea378 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -22,12 +22,35 @@ protobuf_generate( PLUGIN "protoc-gen-grpc=\$" PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") +add_library(federated_old_proto STATIC federated.old.proto) +target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) +target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) +xgboost_target_properties(federated_old_proto) + +protobuf_generate( + TARGET federated_old_proto + LANGUAGE cpp + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") +protobuf_generate( + TARGET federated_old_proto + LANGUAGE grpc + GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc + PLUGIN "protoc-gen-grpc=\$" + PROTOC_OUT_DIR "${PROTO_BINARY_DIR}") + # Wrapper for the gRPC client. add_library(federated_client INTERFACE) target_sources(federated_client INTERFACE federated_client.h) target_link_libraries(federated_client INTERFACE federated_proto) +target_link_libraries(federated_client INTERFACE federated_old_proto) # Rabit engine for Federated Learning. -target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc) +target_sources( + objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc +) +if(USE_CUDA) + target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu) +endif() + target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL") target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1) diff --git a/plugin/federated/federated.old.proto b/plugin/federated/federated.old.proto new file mode 100644 index 000000000..8450659fd --- /dev/null +++ b/plugin/federated/federated.old.proto @@ -0,0 +1,81 @@ +/*! + * Copyright 2022 XGBoost contributors + */ +syntax = "proto3"; + +package xgboost.federated; + +service Federated { + rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} + rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} + rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} +} + +enum DataType { + INT8 = 0; + UINT8 = 1; + INT32 = 2; + UINT32 = 3; + INT64 = 4; + UINT64 = 5; + FLOAT = 6; + DOUBLE = 7; +} + +enum ReduceOperation { + MAX = 0; + MIN = 1; + SUM = 2; + BITWISE_AND = 3; + BITWISE_OR = 4; + BITWISE_XOR = 5; +} + +message AllgatherRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherReply { + bytes receive_buffer = 1; +} + +message AllgatherVRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherVReply { + bytes receive_buffer = 1; +} + +message AllreduceRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + DataType data_type = 4; + ReduceOperation reduce_operation = 5; +} + +message AllreduceReply { + bytes receive_buffer = 1; +} + +message BroadcastRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + // The root rank to broadcast from. + int32 root = 4; +} + +message BroadcastReply { + bytes receive_buffer = 1; +} diff --git a/plugin/federated/federated.proto b/plugin/federated/federated.proto index 8450659fd..fbc2adf50 100644 --- a/plugin/federated/federated.proto +++ b/plugin/federated/federated.proto @@ -1,9 +1,9 @@ /*! - * Copyright 2022 XGBoost contributors + * Copyright 2022-2023 XGBoost contributors */ syntax = "proto3"; -package xgboost.federated; +package xgboost.collective.federated; service Federated { rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} @@ -13,14 +13,18 @@ service Federated { } enum DataType { - INT8 = 0; - UINT8 = 1; - INT32 = 2; - UINT32 = 3; - INT64 = 4; - UINT64 = 5; - FLOAT = 6; - DOUBLE = 7; + HALF = 0; + FLOAT = 1; + DOUBLE = 2; + LONG_DOUBLE = 3; + INT8 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + UINT8 = 8; + UINT16 = 9; + UINT32 = 10; + UINT64 = 11; } enum ReduceOperation { diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index ac1fbd57d..0122a5cfe 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -2,8 +2,8 @@ * Copyright 2022 XGBoost contributors */ #pragma once -#include -#include +#include +#include #include #include diff --git a/plugin/federated/federated_coll.cc b/plugin/federated/federated_coll.cc new file mode 100644 index 000000000..7c25eeba5 --- /dev/null +++ b/plugin/federated/federated_coll.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include "federated_coll.h" + +#include +#include + +#include // for copy_n + +#include "../../src/collective/allgather.h" +#include "../../src/common/common.h" // for AssertGPUSupport +#include "federated_comm.h" // for FederatedComm +#include "xgboost/collective/result.h" // for Result + +namespace xgboost::collective { +namespace { +[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) { + return Fail(name + " RPC failed. " + std::to_string(status.error_code()) + ": " + + status.error_message()); +} + +[[nodiscard]] Result BroadcastImpl(Comm const &comm, std::uint64_t *sequence_number, + common::Span data, std::int32_t root) { + using namespace federated; // NOLINT + + auto fed = dynamic_cast(&comm); + CHECK(fed); + auto stub = fed->Handle(); + + BroadcastRequest request; + request.set_sequence_number(*sequence_number++); + request.set_rank(comm.Rank()); + if (comm.Rank() != root) { + request.set_send_buffer(nullptr, 0); + } else { + request.set_send_buffer(data.data(), data.size()); + } + request.set_root(root); + + BroadcastReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = stub->Broadcast(&context, request, &reply); + if (!status.ok()) { + return GetGRPCResult("Broadcast", status); + } + if (comm.Rank() != root) { + auto const &r = reply.receive_buffer(); + std::copy_n(r.cbegin(), r.size(), data.data()); + } + + return Success(); +} +} // namespace + +#if !defined(XGBOOST_USE_CUDA) +Coll *FederatedColl::MakeCUDAVar() { + common::AssertGPUSupport(); + return nullptr; +} +#endif + +[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) { + using namespace federated; // NOLINT + auto fed = dynamic_cast(&comm); + CHECK(fed); + auto stub = fed->Handle(); + + AllreduceRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(comm.Rank()); + request.set_send_buffer(data.data(), data.size()); + request.set_data_type(static_cast<::xgboost::collective::federated::DataType>(type)); + request.set_reduce_operation(static_cast<::xgboost::collective::federated::ReduceOperation>(op)); + + AllreduceReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = stub->Allreduce(&context, request, &reply); + if (!status.ok()) { + return GetGRPCResult("Allreduce", status); + } + auto const &r = reply.receive_buffer(); + std::copy_n(r.cbegin(), r.size(), data.data()); + return Success(); +} + +[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span data, + std::int32_t root) { + if (comm.Rank() == root) { + return BroadcastImpl(comm, &sequence_number_, data, root); + } else { + return BroadcastImpl(comm, &sequence_number_, data, root); + } +} + +[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span data, + std::int64_t size) { + using namespace federated; // NOLINT + auto fed = dynamic_cast(&comm); + CHECK(fed); + auto stub = fed->Handle(); + + auto offset = comm.Rank() * size; + auto segment = data.subspan(offset, size); + + AllgatherRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(comm.Rank()); + request.set_send_buffer(segment.data(), segment.size()); + + AllgatherReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = stub->Allgather(&context, request, &reply); + + if (!status.ok()) { + return GetGRPCResult("Allgather", status); + } + auto const &r = reply.receive_buffer(); + std::copy_n(r.cbegin(), r.size(), data.begin()); + return Success(); +} + +[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm, + common::Span data, + common::Span, + common::Span, + common::Span recv, AllgatherVAlgo) { + using namespace federated; // NOLINT + + auto fed = dynamic_cast(&comm); + CHECK(fed); + auto stub = fed->Handle(); + + AllgatherVRequest request; + request.set_sequence_number(sequence_number_++); + request.set_rank(comm.Rank()); + request.set_send_buffer(data.data(), data.size()); + + AllgatherVReply reply; + grpc::ClientContext context; + context.set_wait_for_ready(true); + grpc::Status status = stub->AllgatherV(&context, request, &reply); + if (!status.ok()) { + return GetGRPCResult("AllgatherV", status); + } + std::string const &r = reply.receive_buffer(); + CHECK_EQ(r.size(), recv.size()); + std::copy_n(r.cbegin(), r.size(), recv.begin()); + return Success(); +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_coll.cu b/plugin/federated/federated_coll.cu new file mode 100644 index 000000000..a922e1c11 --- /dev/null +++ b/plugin/federated/federated_coll.cu @@ -0,0 +1,92 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include // for int8_t, int32_t +#include // for dynamic_pointer_cast +#include // for vector + +#include "../../src/collective/comm.cuh" +#include "../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../src/data/array_interface.h" // for ArrayInterfaceHandler::Type +#include "federated_coll.cuh" +#include "federated_comm.cuh" +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +Coll *FederatedColl::MakeCUDAVar() { + return new CUDAFederatedColl{std::dynamic_pointer_cast(this->shared_from_this())}; +} + +[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) { + auto cufed = dynamic_cast(&comm); + CHECK(cufed); + + std::vector h_data(data.size()); + + return Success() << [&] { + return GetCUDAResult( + cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); + } << [&] { + return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op); + } << [&] { + return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), + cudaMemcpyHostToDevice, cufed->Stream())); + }; +} + +[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span data, + std::int32_t root) { + auto cufed = dynamic_cast(&comm); + CHECK(cufed); + std::vector h_data(data.size()); + + return Success() << [&] { + return GetCUDAResult( + cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); + } << [&] { + return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root); + } << [&] { + return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), + cudaMemcpyHostToDevice, cufed->Stream())); + }; +} + +[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span data, + std::int64_t size) { + auto cufed = dynamic_cast(&comm); + CHECK(cufed); + std::vector h_data(data.size()); + + return Success() << [&] { + return GetCUDAResult( + cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); + } << [&] { + return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size); + } << [&] { + return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(), + cudaMemcpyHostToDevice, cufed->Stream())); + }; +} + +[[nodiscard]] Result CUDAFederatedColl::AllgatherV( + Comm const &comm, common::Span data, common::Span sizes, + common::Span recv_segments, common::Span recv, AllgatherVAlgo algo) { + auto cufed = dynamic_cast(&comm); + CHECK(cufed); + + std::vector h_data(data.size()); + std::vector h_recv(recv.size()); + + return Success() << [&] { + return GetCUDAResult( + cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost)); + } << [&] { + return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo); + } << [&] { + return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(), + cudaMemcpyHostToDevice, cufed->Stream())); + }; +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_coll.cuh b/plugin/federated/federated_coll.cuh new file mode 100644 index 000000000..a1121d88f --- /dev/null +++ b/plugin/federated/federated_coll.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include "../../src/collective/comm.h" // for Comm, Coll +#include "federated_coll.h" // for FederatedColl +#include "xgboost/collective/result.h" // for Result +#include "xgboost/span.h" // for Span + +namespace xgboost::collective { +class CUDAFederatedColl : public Coll { + std::shared_ptr p_impl_; + + public: + explicit CUDAFederatedColl(std::shared_ptr pimpl) : p_impl_{std::move(pimpl)} {} + [[nodiscard]] Result Allreduce(Comm const &comm, common::Span data, + ArrayInterfaceHandler::Type type, Op op) override; + [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, + std::int32_t root) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data, + std::int64_t size) override; + [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) override; +}; +} // namespace xgboost::collective diff --git a/plugin/federated/federated_coll.h b/plugin/federated/federated_coll.h new file mode 100644 index 000000000..c261b01e1 --- /dev/null +++ b/plugin/federated/federated_coll.h @@ -0,0 +1,30 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#pragma once +#include "../../src/collective/coll.h" // for Coll +#include "../../src/collective/comm.h" // for Comm +#include "../../src/common/io.h" // for ReadAll +#include "../../src/common/json_utils.h" // for OptionalArg +#include "xgboost/json.h" // for Json + +namespace xgboost::collective { +class FederatedColl : public Coll { + private: + std::uint64_t sequence_number_{0}; + + public: + Coll *MakeCUDAVar() override; + + [[nodiscard]] Result Allreduce(Comm const &, common::Span data, + ArrayInterfaceHandler::Type type, Op op) override; + [[nodiscard]] Result Broadcast(Comm const &comm, common::Span data, + std::int32_t root) override; + [[nodiscard]] Result Allgather(Comm const &, common::Span data, + std::int64_t) override; + [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span data, + common::Span sizes, + common::Span recv_segments, + common::Span recv, AllgatherVAlgo algo) override; +}; +} // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.cc b/plugin/federated/federated_comm.cc index 4b51fd52d..8a649340f 100644 --- a/plugin/federated/federated_comm.cc +++ b/plugin/federated/federated_comm.cc @@ -7,6 +7,7 @@ #include // for int32_t #include // for getenv +#include // for numeric_limits #include // for string, stoi #include "../../src/common/common.h" // for Split @@ -29,12 +30,18 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_ CHECK_GE(rank, 0) << "Invalid worker rank."; CHECK_LT(rank, world) << "Invalid worker rank."; + auto certs = {server_cert, client_cert, client_cert}; + auto is_empty = [](auto const& s) { return s.empty(); }; + bool valid = std::all_of(certs.begin(), certs.end(), is_empty) || + std::none_of(certs.begin(), certs.end(), is_empty); + CHECK(valid) << "Invalid arguments for certificates."; + if (server_cert.empty()) { stub_ = [&] { grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - return federated::Federated::NewStub( - grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args)); + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + return federated::Federated::NewStub(grpc::CreateCustomChannel( + host + ":" + std::to_string(port), grpc::InsecureChannelCredentials(), args)); }(); } else { stub_ = [&] { @@ -43,8 +50,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_ options.pem_private_key = client_key; options.pem_cert_chain = client_cert; grpc::ChannelArguments args; - args.SetMaxReceiveMessageSize(std::numeric_limits::max()); - auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args); + args.SetMaxReceiveMessageSize(std::numeric_limits::max()); + auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port), + grpc::SslCredentials(options), args); channel->WaitForConnected( gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN))); return federated::Federated::NewStub(channel); @@ -79,7 +87,7 @@ FederatedComm::FederatedComm(Json const& config) { rank = OptionalArg(config, "federated_rank", static_cast(rank)); auto parsed = common::Split(server_address, ':'); - CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address; + CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address; CHECK_NE(rank, -1) << "Parameter `federated_rank` is required"; CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required."; @@ -111,4 +119,11 @@ FederatedComm::FederatedComm(Json const& config) { this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key, client_cert); } + +#if !defined(XGBOOST_USE_CUDA) +Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr) const { + common::AssertGPUSupport(); + return nullptr; +} +#endif // !defined(XGBOOST_USE_CUDA) } // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.cu b/plugin/federated/federated_comm.cu new file mode 100644 index 000000000..b05d38b1b --- /dev/null +++ b/plugin/federated/federated_comm.cu @@ -0,0 +1,20 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#include // for shared_ptr + +#include "../../src/common/cuda_context.cuh" +#include "federated_comm.cuh" +#include "xgboost/context.h" // for Context + +namespace xgboost::collective { +CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr impl) + : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} { + CHECK(impl); +} + +Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr) const { + return new CUDAFederatedComm{ + ctx, std::dynamic_pointer_cast(this->shared_from_this())}; +} +} // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.cuh b/plugin/federated/federated_comm.cuh new file mode 100644 index 000000000..df9127644 --- /dev/null +++ b/plugin/federated/federated_comm.cuh @@ -0,0 +1,20 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#pragma once + +#include // for shared_ptr + +#include "../../src/common/device_helpers.cuh" // for CUDAStreamView +#include "federated_comm.h" // for FederatedComm +#include "xgboost/context.h" // for Context + +namespace xgboost::collective { +class CUDAFederatedComm : public FederatedComm { + dh::CUDAStreamView stream_; + + public: + explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr impl); + [[nodiscard]] auto Stream() const { return stream_; } +}; +} // namespace xgboost::collective diff --git a/plugin/federated/federated_comm.h b/plugin/federated/federated_comm.h index 8e6fe7d67..fb97a78b0 100644 --- a/plugin/federated/federated_comm.h +++ b/plugin/federated/federated_comm.h @@ -16,12 +16,20 @@ namespace xgboost::collective { class FederatedComm : public Comm { - std::unique_ptr stub_; + std::shared_ptr stub_; void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank, std::string const& server_cert, std::string const& client_key, std::string const& client_cert); + protected: + explicit FederatedComm(std::shared_ptr that) : stub_{that->stub_} { + this->rank_ = that->Rank(); + this->world_ = that->World(); + + this->tracker_ = that->TrackerInfo(); + } + public: /** * @param config @@ -49,5 +57,8 @@ class FederatedComm : public Comm { return Success(); } [[nodiscard]] bool IsFederated() const override { return true; } + [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); } + + Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const override; }; } // namespace xgboost::collective diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 20f3149f9..de760d9d8 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -3,7 +3,7 @@ */ #pragma once -#include +#include #include // for int32_t #include // for future diff --git a/plugin/federated/federated_tracker.cc b/plugin/federated/federated_tracker.cc index 3dad9d7ce..aca468d32 100644 --- a/plugin/federated/federated_tracker.cc +++ b/plugin/federated/federated_tracker.cc @@ -16,9 +16,41 @@ #include "../../src/common/io.h" // for ReadAll #include "../../src/common/json_utils.h" // for RequiredArg #include "../../src/common/timer.h" // for Timer -#include "federated_server.h" // for FederatedService namespace xgboost::collective { +namespace federated { +grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request, + AllgatherReply* reply) { + handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); + return grpc::Status::OK; +} + +grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request, + AllgatherVReply* reply) { + handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank()); + return grpc::Status::OK; +} + +grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request, + AllreduceReply* reply) { + handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), + static_cast(request->data_type()), + static_cast(request->reduce_operation())); + return grpc::Status::OK; +} + +grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request, + BroadcastReply* reply) { + handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(), + reply->mutable_receive_buffer(), request->sequence_number(), request->rank(), + request->root()); + return grpc::Status::OK; +} +} // namespace federated + FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { auto is_secure = RequiredArg(config, "federated_secure", __func__); if (is_secure) { @@ -31,7 +63,8 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} { std::future FederatedTracker::Run() { return std::async([this]() { std::string const server_address = "0.0.0.0:" + std::to_string(this->port_); - federated::FederatedService service{static_cast(this->n_workers_)}; + xgboost::collective::federated::FederatedService service{ + static_cast(this->n_workers_)}; grpc::ServerBuilder builder; if (this->server_cert_file_.empty()) { @@ -42,7 +75,6 @@ std::future FederatedTracker::Run() { builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); } builder.RegisterService(&service); - server_ = builder.BuildAndStart(); LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " << this->n_workers_; } else { @@ -60,12 +92,12 @@ std::future FederatedTracker::Run() { builder.AddListeningPort(server_address, grpc::SslServerCredentials(options)); } builder.RegisterService(&service); - server_ = builder.BuildAndStart(); LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size " << n_workers_; } try { + server_ = builder.BuildAndStart(); server_->Wait(); } catch (std::exception const& e) { return collective::Fail(std::string{e.what()}); diff --git a/plugin/federated/federated_tracker.h b/plugin/federated/federated_tracker.h index 9043adb38..9ad48bee1 100644 --- a/plugin/federated/federated_tracker.h +++ b/plugin/federated/federated_tracker.h @@ -8,11 +8,35 @@ #include // for unique_ptr #include // for string +#include "../../src/collective/in_memory_handler.h" #include "../../src/collective/tracker.h" // for Tracker #include "xgboost/collective/result.h" // for Result #include "xgboost/json.h" // for Json namespace xgboost::collective { +namespace federated { +class FederatedService final : public Federated::Service { + public: + explicit FederatedService(std::int32_t world_size) + : handler_{static_cast(world_size)} {} + + grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request, + AllgatherReply* reply) override; + + grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request, + AllgatherVReply* reply) override; + + grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request, + AllreduceReply* reply) override; + + grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request, + BroadcastReply* reply) override; + + private: + xgboost::collective::InMemoryHandler handler_; +}; +}; // namespace federated + class FederatedTracker : public collective::Tracker { std::unique_ptr server_; std::string server_key_path_; diff --git a/src/collective/coll.h b/src/collective/coll.h index 0189ffd5e..1afc8ed59 100644 --- a/src/collective/coll.h +++ b/src/collective/coll.h @@ -24,7 +24,7 @@ class Coll : public std::enable_shared_from_this { Coll() = default; virtual ~Coll() noexcept(false) {} // NOLINT - Coll* MakeCUDAVar(); + virtual Coll* MakeCUDAVar(); /** * @brief Allreduce diff --git a/src/collective/comm.cc b/src/collective/comm.cc index dbd45cbb2..241dca2ce 100644 --- a/src/collective/comm.cc +++ b/src/collective/comm.cc @@ -9,7 +9,8 @@ #include // for string #include // for move, forward -#include "allgather.h" +#include "../common/common.h" // for AssertGPUSupport +#include "allgather.h" // for RingAllgather #include "protocol.h" // for kMagic #include "xgboost/base.h" // for XGBOOST_STRICT_R_MODE #include "xgboost/collective/socket.h" // for TCPSocket @@ -48,6 +49,14 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st this->Rank(), this->World()); } +#if !defined(XGBOOST_USE_NCCL) +Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr) const { + common::AssertGPUSupport(); + common::AssertNCCLSupport(); + return nullptr; +} +#endif // !defined(XGBOOST_USE_NCCL) + [[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport, proto::PeerInfo ninfo, std::chrono::seconds timeout, std::int32_t retry, diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 07dfafbef..d8fe77067 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -20,14 +20,14 @@ namespace xgboost::collective { namespace { -Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) { +Result GetUniqueId(Comm const& comm, std::shared_ptr coll, ncclUniqueId* pid) { static const int kRootRank = 0; ncclUniqueId id; if (comm.Rank() == kRootRank) { dh::safe_nccl(ncclGetUniqueId(&id)); } - auto rc = Broadcast(comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, - kRootRank); + auto rc = coll->Broadcast( + comm, common::Span{reinterpret_cast(&id), sizeof(ncclUniqueId)}, kRootRank); if (!rc.OK()) { return rc; } @@ -63,7 +63,7 @@ static std::string PrintUUID(xgboost::common::Span c } } // namespace -Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) { +Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const { return new NCCLComm{ctx, *this, pimpl}; } @@ -86,6 +86,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p GetCudaUUID(s_this_uuid, ctx->Device()); auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes()); + CHECK(rc.OK()) << rc.Report(); std::vector> converted(root.World()); @@ -103,7 +104,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr p << "Multiple processes within communication group running on same CUDA " << "device is not supported. " << PrintUUID(s_this_uuid) << "\n"; - rc = GetUniqueId(root, &nccl_unique_id_); + rc = GetUniqueId(root, pimpl, &nccl_unique_id_); CHECK(rc.OK()) << rc.Report(); dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank())); diff --git a/src/collective/comm.h b/src/collective/comm.h index afb543c46..76ab479d7 100644 --- a/src/collective/comm.h +++ b/src/collective/comm.h @@ -40,7 +40,7 @@ class Coll; /** * @brief Base communicator storing info about the tracker and other communicators. */ -class Comm { +class Comm : public std::enable_shared_from_this { protected: std::int32_t world_{-1}; std::int32_t rank_{0}; @@ -87,7 +87,7 @@ class Comm { [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); } - Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl); + virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr pimpl) const; }; class RabitComm : public Comm { diff --git a/src/common/common.h b/src/common/common.h index 7cea0591f..8263283f3 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -170,6 +170,12 @@ inline void AssertGPUSupport() { #endif // XGBOOST_USE_CUDA && XGBOOST_USE_HIP } +inline void AssertNCCLSupport() { +#if !defined(XGBOOST_USE_NCCL) + LOG(FATAL) << "XGBoost version not compiled with NCCL support."; +#endif // !defined(XGBOOST_USE_NCCL) +} + inline void AssertOneAPISupport() { #ifndef XGBOOST_USE_ONEAPI LOG(FATAL) << "XGBoost version not compiled with OneAPI support."; diff --git a/tests/cpp/plugin/federated/test_federated_coll.cc b/tests/cpp/plugin/federated/test_federated_coll.cc new file mode 100644 index 000000000..ad053f286 --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_coll.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ +#include +#include // for Span + +#include // for array + +#include "../../../../src/common/type.h" // for EraseType +#include "../../collective/test_worker.h" // for SocketTest +#include "federated_coll.h" // for FederatedColl +#include "federated_comm.h" // for FederatedComm +#include "test_worker.h" // for TestFederated + +namespace xgboost::collective { +namespace { +class FederatedCollTest : public SocketTest {}; +} // namespace + +TEST_F(FederatedCollTest, Allreduce) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { + std::array buffer = {1, 2, 3, 4, 5}; + std::array expected; + std::transform(buffer.cbegin(), buffer.cend(), expected.begin(), + [=](auto i) { return i * n_workers; }); + + auto coll = std::make_shared(); + auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), + ArrayInterfaceHandler::kI4, Op::kSum); + ASSERT_TRUE(rc.OK()); + for (auto i = 0; i < 5; i++) { + ASSERT_EQ(buffer[i], expected[i]); + } + }); +} + +TEST_F(FederatedCollTest, Broadcast) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { + FederatedColl coll{}; + auto rc = Success(); + if (comm->Rank() == 0) { + std::string buffer{"hello"}; + rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); + ASSERT_EQ(buffer, "hello"); + } else { + std::string buffer{" "}; + rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0); + ASSERT_EQ(buffer, "hello"); + } + ASSERT_TRUE(rc.OK()); + }); +} + +TEST_F(FederatedCollTest, Allgather) { + std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { + FederatedColl coll{}; + + std::vector buffer(n_workers, 0); + buffer[comm->Rank()] = comm->Rank(); + auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), + sizeof(int)); + ASSERT_TRUE(rc.OK()); + for (auto i = 0; i < n_workers; i++) { + ASSERT_EQ(buffer[i], i); + } + }); +} + +TEST_F(FederatedCollTest, AllgatherV) { + std::int32_t n_workers = 2; + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t) { + FederatedColl coll{}; + + std::vector inputs{"Federated", " Learning!!!"}; + std::vector recv_segments(inputs.size() + 1, 0); + std::string r; + std::vector sizes{static_cast(inputs[0].size()), + static_cast(inputs[1].size())}; + r.resize(sizes[0] + sizes[1]); + + auto rc = coll.AllgatherV( + *comm, + common::EraseType(common::Span{inputs[comm->Rank()].data(), inputs[comm->Rank()].size()}), + common::Span{sizes.data(), sizes.size()}, recv_segments, + common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing); + + EXPECT_EQ(r, "Federated Learning!!!"); + ASSERT_TRUE(rc.OK()); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_coll.cu b/tests/cpp/plugin/federated/test_federated_coll.cu new file mode 100644 index 000000000..44211f8d7 --- /dev/null +++ b/tests/cpp/plugin/federated/test_federated_coll.cu @@ -0,0 +1,131 @@ +/** + * Copyright 2022-2023, XGBoost contributors + */ + +#include +#include // for Result + +#include "../../../../src/common/common.h" // for AllVisibleGPUs +#include "../../../../src/common/device_helpers.cuh" // for device_vector +#include "../../../../src/common/type.h" // for EraseType +#include "../../collective/test_worker.h" // for SocketTest +#include "../../helpers.h" // for MakeCUDACtx +#include "federated_coll.cuh" +#include "federated_comm.cuh" +#include "test_worker.h" // for TestFederated + +namespace xgboost::collective { +namespace { +class FederatedCollTestGPU : public SocketTest {}; + +struct Worker { + std::shared_ptr impl; + std::shared_ptr nccl_comm; + std::shared_ptr coll; + + Worker(std::shared_ptr comm, std::int32_t rank) { + auto ctx = MakeCUDACtx(rank); + impl = std::make_shared(); + nccl_comm.reset(comm->MakeCUDAVar(&ctx, impl)); + coll = std::make_shared(impl); + } +}; + +void TestAllreduce(std::shared_ptr comm, std::int32_t rank, std::int32_t n_workers) { + Worker w{comm, rank}; + + dh::device_vector buffer{std::vector{1, 2, 3, 4, 5}}; + dh::device_vector expected(buffer.size()); + thrust::transform(buffer.cbegin(), buffer.cend(), expected.begin(), + [=] XGBOOST_DEVICE(std::int32_t i) { return i * n_workers; }); + + auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), + ArrayInterfaceHandler::kI4, Op::kSum); + ASSERT_TRUE(rc.OK()); + for (auto i = 0; i < 5; i++) { + ASSERT_EQ(buffer[i], expected[i]); + } +} + +void TestBroadcast(std::shared_ptr comm, std::int32_t rank) { + Worker w{comm, rank}; + + auto rc = Success(); + std::vector expect{0, 1, 2, 3}; + + if (comm->Rank() == 0) { + dh::device_vector buffer{expect}; + rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); + std::vector expect{0, 1, 2, 3}; + ASSERT_EQ(buffer, expect); + } else { + dh::device_vector buffer(std::vector{4, 5, 6, 7}); + rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0); + ASSERT_EQ(buffer, expect); + } + ASSERT_TRUE(rc.OK()); +} + +void TestAllgather(std::shared_ptr comm, std::int32_t rank, std::int32_t n_workers) { + Worker w{comm, rank}; + + dh::device_vector buffer(n_workers, 0); + buffer[comm->Rank()] = comm->Rank(); + auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int)); + ASSERT_TRUE(rc.OK()); + for (auto i = 0; i < n_workers; i++) { + ASSERT_EQ(buffer[i], i); + } +} + +void TestAllgatherV(std::shared_ptr comm, std::int32_t rank) { + Worker w{comm, rank}; + + std::vector> inputs{std::vector{1, 2, 3}, + std::vector{4, 5}}; + std::vector recv_segments(inputs.size() + 1, 0); + dh::device_vector r; + std::vector sizes{static_cast(inputs[0].size()), + static_cast(inputs[1].size())}; + r.resize(sizes[0] + sizes[1]); + + auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])), + common::Span{sizes.data(), sizes.size()}, recv_segments, + common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing); + ASSERT_TRUE(rc.OK()); + + ASSERT_EQ(r[0], 1); + for (std::size_t i = 1; i < r.size(); ++i) { + ASSERT_EQ(r[i], r[i - 1] + 1); + } +} +} // namespace + +TEST_F(FederatedCollTestGPU, Allreduce) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { + TestAllreduce(comm, rank, n_workers); + }); +} + +TEST_F(FederatedCollTestGPU, Broadcast) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { + TestBroadcast(comm, rank); + }); +} + +TEST_F(FederatedCollTestGPU, Allgather) { + std::int32_t n_workers = common::AllVisibleGPUs(); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { + TestAllgather(comm, rank, n_workers); + }); +} + +TEST_F(FederatedCollTestGPU, AllgatherV) { + std::int32_t n_workers = 2; + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { + TestAllgatherV(comm, rank); + }); +} +} // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_federated_comm.cc b/tests/cpp/plugin/federated/test_federated_comm.cc index 5bbde1bbb..b45b00910 100644 --- a/tests/cpp/plugin/federated/test_federated_comm.cc +++ b/tests/cpp/plugin/federated/test_federated_comm.cc @@ -7,10 +7,10 @@ #include // for thread #include "../../../../plugin/federated/federated_comm.h" -#include "../../collective/net_test.h" // for SocketTest -#include "../../helpers.h" // for ExpectThrow -#include "test_worker.h" // for TestFederated -#include "xgboost/json.h" // for Json +#include "../../collective/test_worker.h" // for SocketTest +#include "../../helpers.h" // for ExpectThrow +#include "test_worker.h" // for TestFederated +#include "xgboost/json.h" // for Json namespace xgboost::collective { namespace { @@ -71,14 +71,9 @@ TEST_F(FederatedCommTest, IsDistributed) { TEST_F(FederatedCommTest, InsecureTracker) { std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u); - TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) { - Json config{Object{}}; - config["federated_world_size"] = n_workers; - config["federated_rank"] = rank; - config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); - FederatedComm comm{config}; - ASSERT_EQ(comm.Rank(), rank); - ASSERT_EQ(comm.World(), n_workers); + TestFederated(n_workers, [=](std::shared_ptr comm, std::int32_t rank) { + ASSERT_EQ(comm->Rank(), rank); + ASSERT_EQ(comm->World(), n_workers); }); } } // namespace xgboost::collective diff --git a/tests/cpp/plugin/federated/test_worker.h b/tests/cpp/plugin/federated/test_worker.h index 719b4c343..38bc32c60 100644 --- a/tests/cpp/plugin/federated/test_worker.h +++ b/tests/cpp/plugin/federated/test_worker.h @@ -9,7 +9,8 @@ #include // for thread #include "../../../../plugin/federated/federated_tracker.h" -#include "xgboost/json.h" // for Json +#include "federated_comm.h" // for FederatedComm +#include "xgboost/json.h" // for Json namespace xgboost::collective { template @@ -28,7 +29,15 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) { std::int32_t port = tracker.Port(); for (std::int32_t i = 0; i < n_workers; ++i) { - workers.emplace_back([=] { fn(port, i); }); + workers.emplace_back([=] { + Json config{Object{}}; + config["federated_world_size"] = n_workers; + config["federated_rank"] = i; + config["federated_server_address"] = "0.0.0.0:" + std::to_string(port); + auto comm = std::make_shared(config); + + fn(comm, i); + }); } for (auto& t : workers) {