Merge branch 'master'
commit 123af45327
@@ -87,7 +87,13 @@ xgb.plot.importance <- function(importance_matrix = NULL, top_n = NULL, measure
 }

   # also aggregate, just in case the values were not yet summed up by feature
-  importance_matrix <- importance_matrix[, Importance := sum(get(measure)), by = Feature]
+  importance_matrix <- importance_matrix[
+    , lapply(.SD, sum)
+    , .SDcols = setdiff(names(importance_matrix), "Feature")
+    , by = Feature
+  ][
+    , Importance := get(measure)
+  ]

   # make sure it's ordered
   importance_matrix <- importance_matrix[order(-abs(Importance))]
@@ -382,6 +382,9 @@ test_that("xgb.importance works with GLM model", {
   expect_equal(colnames(imp2plot), c("Feature", "Weight", "Importance"))
   xgb.ggplot.importance(importance.GLM)

+  # check that the input is not modified in-place
+  expect_false("Importance" %in% names(importance.GLM))
+
   # for multiclass
   imp.GLM <- xgb.importance(model = mbst.GLM)
   expect_equal(dim(imp.GLM), c(12, 3))
@@ -400,6 +403,16 @@ test_that("xgb.model.dt.tree and xgb.importance work with a single split model",
   expect_equal(imp$Gain, 1)
 })

+test_that("xgb.plot.importance de-duplicates features", {
+  importances <- data.table(
+    Feature = c("col1", "col2", "col2"),
+    Gain = c(0.4, 0.3, 0.3)
+  )
+  imp2plot <- xgb.plot.importance(importances)
+  expect_equal(nrow(imp2plot), 2L)
+  expect_equal(imp2plot$Feature, c("col2", "col1"))
+})
+
 test_that("xgb.plot.tree works with and without feature names", {
   .skip_if_vcd_not_available()
   expect_silent(xgb.plot.tree(feature_names = feature.names, model = bst.Tree))
@@ -22,12 +22,35 @@ protobuf_generate(
   PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
   PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")

+add_library(federated_old_proto STATIC federated.old.proto)
+target_link_libraries(federated_old_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
+target_include_directories(federated_old_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
+xgboost_target_properties(federated_old_proto)
+
+protobuf_generate(
+  TARGET federated_old_proto
+  LANGUAGE cpp
+  PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
+protobuf_generate(
+  TARGET federated_old_proto
+  LANGUAGE grpc
+  GENERATE_EXTENSIONS .grpc.pb.h .grpc.pb.cc
+  PLUGIN "protoc-gen-grpc=\$<TARGET_FILE:gRPC::grpc_cpp_plugin>"
+  PROTOC_OUT_DIR "${PROTO_BINARY_DIR}")
+
 # Wrapper for the gRPC client.
 add_library(federated_client INTERFACE)
 target_sources(federated_client INTERFACE federated_client.h)
 target_link_libraries(federated_client INTERFACE federated_proto)
+target_link_libraries(federated_client INTERFACE federated_old_proto)

 # Rabit engine for Federated Learning.
-target_sources(objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc)
+target_sources(
+  objxgboost PRIVATE federated_tracker.cc federated_server.cc federated_comm.cc federated_coll.cc
+)
+if(USE_CUDA)
+  target_sources(objxgboost PRIVATE federated_comm.cu federated_coll.cu)
+endif()

 target_link_libraries(objxgboost PRIVATE federated_client "-Wl,--exclude-libs,ALL")
 target_compile_definitions(objxgboost PUBLIC -DXGBOOST_USE_FEDERATED=1)
plugin/federated/federated.old.proto (Normal file, 81 lines)
@@ -0,0 +1,81 @@
/*!
 * Copyright 2022 XGBoost contributors
 */
syntax = "proto3";

package xgboost.federated;

service Federated {
  rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
  rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
  rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
  rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
}

enum DataType {
  INT8 = 0;
  UINT8 = 1;
  INT32 = 2;
  UINT32 = 3;
  INT64 = 4;
  UINT64 = 5;
  FLOAT = 6;
  DOUBLE = 7;
}

enum ReduceOperation {
  MAX = 0;
  MIN = 1;
  SUM = 2;
  BITWISE_AND = 3;
  BITWISE_OR = 4;
  BITWISE_XOR = 5;
}

message AllgatherRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
}

message AllgatherReply {
  bytes receive_buffer = 1;
}

message AllgatherVRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
}

message AllgatherVReply {
  bytes receive_buffer = 1;
}

message AllreduceRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
  DataType data_type = 4;
  ReduceOperation reduce_operation = 5;
}

message AllreduceReply {
  bytes receive_buffer = 1;
}

message BroadcastRequest {
  // An incrementing counter that is unique to each round of operations.
  uint64 sequence_number = 1;
  int32 rank = 2;
  bytes send_buffer = 3;
  // The root rank to broadcast from.
  int32 root = 4;
}

message BroadcastReply {
  bytes receive_buffer = 1;
}
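For orientation, here is a minimal, self-contained client sketch against the legacy service above. It is not part of the commit: the endpoint address and payload are illustrative assumptions, while the stub and message types come from the headers generated from federated.old.proto.

#include <federated.old.grpc.pb.h>
#include <grpcpp/grpcpp.h>

#include <iostream>

int main() {
  // Hypothetical endpoint; the tracker in this commit listens on "0.0.0.0:<port>".
  auto channel = grpc::CreateChannel("localhost:9091", grpc::InsecureChannelCredentials());
  auto stub = xgboost::federated::Federated::NewStub(channel);

  xgboost::federated::BroadcastRequest request;
  request.set_sequence_number(0);  // first round of operations
  request.set_rank(0);
  request.set_send_buffer("hello");
  request.set_root(0);  // this worker is the broadcast root

  xgboost::federated::BroadcastReply reply;
  grpc::ClientContext context;
  grpc::Status status = stub->Broadcast(&context, request, &reply);
  if (!status.ok()) {
    std::cerr << "Broadcast RPC failed: " << status.error_message() << "\n";
    return 1;
  }
  std::cout << "received " << reply.receive_buffer().size() << " bytes\n";
  return 0;
}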
@@ -1,9 +1,9 @@
 /*!
- * Copyright 2022 XGBoost contributors
+ * Copyright 2022-2023 XGBoost contributors
 */
 syntax = "proto3";

-package xgboost.federated;
+package xgboost.collective.federated;

 service Federated {
   rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
@@ -13,14 +13,18 @@ service Federated {
 }

 enum DataType {
-  INT8 = 0;
-  UINT8 = 1;
-  INT32 = 2;
-  UINT32 = 3;
-  INT64 = 4;
-  UINT64 = 5;
-  FLOAT = 6;
-  DOUBLE = 7;
+  HALF = 0;
+  FLOAT = 1;
+  DOUBLE = 2;
+  LONG_DOUBLE = 3;
+  INT8 = 4;
+  INT16 = 5;
+  INT32 = 6;
+  INT64 = 7;
+  UINT8 = 8;
+  UINT16 = 9;
+  UINT32 = 10;
+  UINT64 = 11;
 }

 enum ReduceOperation {
@@ -2,8 +2,8 @@
  * Copyright 2022 XGBoost contributors
  */
 #pragma once
-#include <federated.grpc.pb.h>
-#include <federated.pb.h>
+#include <federated.old.grpc.pb.h>
+#include <federated.old.pb.h>
 #include <grpcpp/grpcpp.h>

 #include <cstdio>
plugin/federated/federated_coll.cc (Normal file, 155 lines)
@@ -0,0 +1,155 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#include "federated_coll.h"

#include <federated.grpc.pb.h>
#include <federated.pb.h>

#include <algorithm>  // for copy_n

#include "../../src/collective/allgather.h"
#include "../../src/common/common.h"    // for AssertGPUSupport
#include "federated_comm.h"             // for FederatedComm
#include "xgboost/collective/result.h"  // for Result

namespace xgboost::collective {
namespace {
[[nodiscard]] Result GetGRPCResult(std::string const &name, grpc::Status const &status) {
  return Fail(name + " RPC failed. " + std::to_string(status.error_code()) + ": " +
              status.error_message());
}

[[nodiscard]] Result BroadcastImpl(Comm const &comm, std::uint64_t *sequence_number,
                                   common::Span<std::int8_t> data, std::int32_t root) {
  using namespace federated;  // NOLINT

  auto fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);
  auto stub = fed->Handle();

  BroadcastRequest request;
  // Increment the pointed-to counter, not the pointer itself.
  request.set_sequence_number((*sequence_number)++);
  request.set_rank(comm.Rank());
  if (comm.Rank() != root) {
    request.set_send_buffer(nullptr, 0);
  } else {
    request.set_send_buffer(data.data(), data.size());
  }
  request.set_root(root);

  BroadcastReply reply;
  grpc::ClientContext context;
  context.set_wait_for_ready(true);
  grpc::Status status = stub->Broadcast(&context, request, &reply);
  if (!status.ok()) {
    return GetGRPCResult("Broadcast", status);
  }
  if (comm.Rank() != root) {
    auto const &r = reply.receive_buffer();
    std::copy_n(r.cbegin(), r.size(), data.data());
  }

  return Success();
}
}  // namespace

#if !defined(XGBOOST_USE_CUDA)
Coll *FederatedColl::MakeCUDAVar() {
  common::AssertGPUSupport();
  return nullptr;
}
#endif

[[nodiscard]] Result FederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                              ArrayInterfaceHandler::Type type, Op op) {
  using namespace federated;  // NOLINT
  auto fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);
  auto stub = fed->Handle();

  AllreduceRequest request;
  request.set_sequence_number(sequence_number_++);
  request.set_rank(comm.Rank());
  request.set_send_buffer(data.data(), data.size());
  request.set_data_type(static_cast<::xgboost::collective::federated::DataType>(type));
  request.set_reduce_operation(static_cast<::xgboost::collective::federated::ReduceOperation>(op));

  AllreduceReply reply;
  grpc::ClientContext context;
  context.set_wait_for_ready(true);
  grpc::Status status = stub->Allreduce(&context, request, &reply);
  if (!status.ok()) {
    return GetGRPCResult("Allreduce", status);
  }
  auto const &r = reply.receive_buffer();
  std::copy_n(r.cbegin(), r.size(), data.data());
  return Success();
}

[[nodiscard]] Result FederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                              std::int32_t root) {
  // Root and non-root take the same path; BroadcastImpl branches on the rank internally.
  return BroadcastImpl(comm, &sequence_number_, data, root);
}

[[nodiscard]] Result FederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
                                              std::int64_t size) {
  using namespace federated;  // NOLINT
  auto fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);
  auto stub = fed->Handle();

  auto offset = comm.Rank() * size;
  auto segment = data.subspan(offset, size);

  AllgatherRequest request;
  request.set_sequence_number(sequence_number_++);
  request.set_rank(comm.Rank());
  request.set_send_buffer(segment.data(), segment.size());

  AllgatherReply reply;
  grpc::ClientContext context;
  context.set_wait_for_ready(true);
  grpc::Status status = stub->Allgather(&context, request, &reply);

  if (!status.ok()) {
    return GetGRPCResult("Allgather", status);
  }
  auto const &r = reply.receive_buffer();
  std::copy_n(r.cbegin(), r.size(), data.begin());
  return Success();
}

[[nodiscard]] Result FederatedColl::AllgatherV(Comm const &comm,
                                               common::Span<std::int8_t const> data,
                                               common::Span<std::int64_t const>,
                                               common::Span<std::int64_t>,
                                               common::Span<std::int8_t> recv, AllgatherVAlgo) {
  using namespace federated;  // NOLINT

  auto fed = dynamic_cast<FederatedComm const *>(&comm);
  CHECK(fed);
  auto stub = fed->Handle();

  AllgatherVRequest request;
  request.set_sequence_number(sequence_number_++);
  request.set_rank(comm.Rank());
  request.set_send_buffer(data.data(), data.size());

  AllgatherVReply reply;
  grpc::ClientContext context;
  context.set_wait_for_ready(true);
  grpc::Status status = stub->AllgatherV(&context, request, &reply);
  if (!status.ok()) {
    return GetGRPCResult("AllgatherV", status);
  }
  std::string const &r = reply.receive_buffer();
  CHECK_EQ(r.size(), recv.size());
  std::copy_n(r.cbegin(), r.size(), recv.begin());
  return Success();
}
}  // namespace xgboost::collective
plugin/federated/federated_coll.cu (Normal file, 92 lines)
@@ -0,0 +1,92 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#include <cstdint>  // for int8_t, int32_t
#include <memory>   // for dynamic_pointer_cast
#include <vector>   // for vector

#include "../../src/collective/comm.cuh"
#include "../../src/common/cuda_context.cuh"  // for CUDAContext
#include "../../src/data/array_interface.h"   // for ArrayInterfaceHandler::Type
#include "federated_coll.cuh"
#include "federated_comm.cuh"
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span

namespace xgboost::collective {
Coll *FederatedColl::MakeCUDAVar() {
  return new CUDAFederatedColl{std::dynamic_pointer_cast<FederatedColl>(this->shared_from_this())};
}

[[nodiscard]] Result CUDAFederatedColl::Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                                  ArrayInterfaceHandler::Type type, Op op) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);

  std::vector<std::int8_t> h_data(data.size());

  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Allreduce(comm, common::Span{h_data.data(), h_data.size()}, type, op);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

[[nodiscard]] Result CUDAFederatedColl::Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                                  std::int32_t root) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());

  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Broadcast(comm, common::Span{h_data.data(), h_data.size()}, root);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

[[nodiscard]] Result CUDAFederatedColl::Allgather(Comm const &comm, common::Span<std::int8_t> data,
                                                  std::int64_t size) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);
  std::vector<std::int8_t> h_data(data.size());

  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return p_impl_->Allgather(comm, common::Span{h_data.data(), h_data.size()}, size);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(data.data(), h_data.data(), data.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}

[[nodiscard]] Result CUDAFederatedColl::AllgatherV(
    Comm const &comm, common::Span<std::int8_t const> data, common::Span<std::int64_t const> sizes,
    common::Span<std::int64_t> recv_segments, common::Span<std::int8_t> recv, AllgatherVAlgo algo) {
  auto cufed = dynamic_cast<CUDAFederatedComm const *>(&comm);
  CHECK(cufed);

  std::vector<std::int8_t> h_data(data.size());
  std::vector<std::int8_t> h_recv(recv.size());

  return Success() << [&] {
    return GetCUDAResult(
        cudaMemcpy(h_data.data(), data.data(), data.size(), cudaMemcpyDeviceToHost));
  } << [&] {
    return this->p_impl_->AllgatherV(comm, h_data, sizes, recv_segments, h_recv, algo);
  } << [&] {
    return GetCUDAResult(cudaMemcpyAsync(recv.data(), h_recv.data(), h_recv.size(),
                                         cudaMemcpyHostToDevice, cufed->Stream()));
  };
}
}  // namespace xgboost::collective
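The CUDA file above leans on chaining Result values with `Success() << [&]{...} << [&]{...}`: each lambda runs only if every previous step succeeded, so the device-to-host copy, the host-side collective, and the copy back short-circuit on the first failure. Below is a standalone analogue of that pattern, assuming a simplified Result type rather than xgboost's actual one:

#include <iostream>
#include <string>
#include <utility>

struct Result {
  bool ok{true};
  std::string msg;
};

Result Success() { return {}; }
Result Fail(std::string m) { return {false, std::move(m)}; }

// Run the next step only when everything so far has succeeded.
template <typename Fn>
Result operator<<(Result&& prev, Fn&& fn) {
  return prev.ok ? fn() : std::move(prev);
}

int main() {
  auto rc = Success() << [] { return Success(); }
                      << [] { return Fail("step 2 failed"); }
                      << [] { return Success(); };  // never executed
  std::cout << (rc.ok ? std::string{"OK"} : rc.msg) << "\n";
}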
plugin/federated/federated_coll.cuh (Normal file, 26 lines)
@@ -0,0 +1,26 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#pragma once

#include <memory>  // for shared_ptr

#include "../../src/collective/comm.h"  // for Comm, Coll
#include "federated_coll.h"             // for FederatedColl
#include "xgboost/collective/result.h"  // for Result
#include "xgboost/span.h"               // for Span

namespace xgboost::collective {
class CUDAFederatedColl : public Coll {
  std::shared_ptr<FederatedColl> p_impl_;

 public:
  explicit CUDAFederatedColl(std::shared_ptr<FederatedColl> pimpl) : p_impl_{std::move(pimpl)} {}
  [[nodiscard]] Result Allreduce(Comm const &comm, common::Span<std::int8_t> data,
                                 ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
                                 std::int64_t size) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
                                  common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
}  // namespace xgboost::collective
plugin/federated/federated_coll.h (Normal file, 30 lines)
@@ -0,0 +1,30 @@
/**
 * Copyright 2023, XGBoost contributors
 */
#pragma once
#include "../../src/collective/coll.h"    // for Coll
#include "../../src/collective/comm.h"    // for Comm
#include "../../src/common/io.h"          // for ReadAll
#include "../../src/common/json_utils.h"  // for OptionalArg
#include "xgboost/json.h"                 // for Json

namespace xgboost::collective {
class FederatedColl : public Coll {
 private:
  std::uint64_t sequence_number_{0};

 public:
  Coll *MakeCUDAVar() override;

  [[nodiscard]] Result Allreduce(Comm const &, common::Span<std::int8_t> data,
                                 ArrayInterfaceHandler::Type type, Op op) override;
  [[nodiscard]] Result Broadcast(Comm const &comm, common::Span<std::int8_t> data,
                                 std::int32_t root) override;
  [[nodiscard]] Result Allgather(Comm const &, common::Span<std::int8_t> data,
                                 std::int64_t) override;
  [[nodiscard]] Result AllgatherV(Comm const &comm, common::Span<std::int8_t const> data,
                                  common::Span<std::int64_t const> sizes,
                                  common::Span<std::int64_t> recv_segments,
                                  common::Span<std::int8_t> recv, AllgatherVAlgo algo) override;
};
}  // namespace xgboost::collective
@@ -7,6 +7,7 @@

 #include <cstdint>  // for int32_t
 #include <cstdlib>  // for getenv
+#include <limits>   // for numeric_limits
 #include <string>   // for string, stoi

 #include "../../src/common/common.h"  // for Split
@@ -29,12 +30,18 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
   CHECK_GE(rank, 0) << "Invalid worker rank.";
   CHECK_LT(rank, world) << "Invalid worker rank.";

+  // Either all credentials are supplied (secure mode) or none are (insecure mode).
+  auto certs = {server_cert, client_key, client_cert};
+  auto is_empty = [](auto const& s) { return s.empty(); };
+  bool valid = std::all_of(certs.begin(), certs.end(), is_empty) ||
+               std::none_of(certs.begin(), certs.end(), is_empty);
+  CHECK(valid) << "Invalid arguments for certificates.";
+
   if (server_cert.empty()) {
     stub_ = [&] {
       grpc::ChannelArguments args;
-      args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
-      return federated::Federated::NewStub(
-          grpc::CreateCustomChannel(host, grpc::InsecureChannelCredentials(), args));
+      args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
+      return federated::Federated::NewStub(grpc::CreateCustomChannel(
+          host + ":" + std::to_string(port), grpc::InsecureChannelCredentials(), args));
     }();
   } else {
     stub_ = [&] {
@@ -43,8 +50,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
       options.pem_private_key = client_key;
       options.pem_cert_chain = client_cert;
       grpc::ChannelArguments args;
-      args.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
-      auto channel = grpc::CreateCustomChannel(host, grpc::SslCredentials(options), args);
+      args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
+      auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
+                                               grpc::SslCredentials(options), args);
       channel->WaitForConnected(
           gpr_time_add(gpr_now(GPR_CLOCK_REALTIME), gpr_time_from_seconds(60, GPR_TIMESPAN)));
       return federated::Federated::NewStub(channel);
@@ -79,7 +87,7 @@ FederatedComm::FederatedComm(Json const& config) {
   rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));

   auto parsed = common::Split(server_address, ':');
-  CHECK_EQ(parsed.size(), 2) << "invalid server address:" << server_address;
+  CHECK_EQ(parsed.size(), 2) << "Invalid server address:" << server_address;

   CHECK_NE(rank, -1) << "Parameter `federated_rank` is required.";
   CHECK_NE(world_size, 0) << "Parameter `federated_world_size` is required.";
@@ -111,4 +119,11 @@ FederatedComm::FederatedComm(Json const& config) {
   this->Init(parsed[0], std::stoi(parsed[1]), world_size, rank, server_cert, client_key,
              client_cert);
 }
+
+#if !defined(XGBOOST_USE_CUDA)
+Comm* FederatedComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
+  common::AssertGPUSupport();
+  return nullptr;
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
 }  // namespace xgboost::collective
plugin/federated/federated_comm.cu (Normal file, 20 lines)
@@ -0,0 +1,20 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#include <memory>  // for shared_ptr

#include "../../src/common/cuda_context.cuh"
#include "federated_comm.cuh"
#include "xgboost/context.h"  // for Context

namespace xgboost::collective {
CUDAFederatedComm::CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl)
    : FederatedComm{impl}, stream_{ctx->CUDACtx()->Stream()} {
  CHECK(impl);
}

Comm* FederatedComm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll>) const {
  return new CUDAFederatedComm{
      ctx, std::dynamic_pointer_cast<FederatedComm const>(this->shared_from_this())};
}
}  // namespace xgboost::collective
plugin/federated/federated_comm.cuh (Normal file, 20 lines)
@@ -0,0 +1,20 @@
/**
 * Copyright 2023, XGBoost Contributors
 */
#pragma once

#include <memory>  // for shared_ptr

#include "../../src/common/device_helpers.cuh"  // for CUDAStreamView
#include "federated_comm.h"                     // for FederatedComm
#include "xgboost/context.h"                    // for Context

namespace xgboost::collective {
class CUDAFederatedComm : public FederatedComm {
  dh::CUDAStreamView stream_;

 public:
  explicit CUDAFederatedComm(Context const* ctx, std::shared_ptr<FederatedComm const> impl);
  [[nodiscard]] auto Stream() const { return stream_; }
};
}  // namespace xgboost::collective
@@ -16,12 +16,20 @@

 namespace xgboost::collective {
 class FederatedComm : public Comm {
-  std::unique_ptr<federated::Federated::Stub> stub_;
+  std::shared_ptr<federated::Federated::Stub> stub_;

   void Init(std::string const& host, std::int32_t port, std::int32_t world, std::int32_t rank,
             std::string const& server_cert, std::string const& client_key,
             std::string const& client_cert);

+ protected:
+  explicit FederatedComm(std::shared_ptr<FederatedComm const> that) : stub_{that->stub_} {
+    this->rank_ = that->Rank();
+    this->world_ = that->World();
+
+    this->tracker_ = that->TrackerInfo();
+  }
+
  public:
   /**
    * @param config
@@ -49,5 +57,8 @@ class FederatedComm : public Comm {
     return Success();
   }
   [[nodiscard]] bool IsFederated() const override { return true; }
+  [[nodiscard]] federated::Federated::Stub* Handle() const { return stub_.get(); }
+
+  Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const override;
 };
 }  // namespace xgboost::collective
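A pattern worth noting here: FederatedComm's new protected constructor copies the shared stub, and Comm derives from std::enable_shared_from_this (see the comm.h hunk below), so MakeCUDAVar can hand the GPU variant a shared handle to the existing state, such as the gRPC stub, instead of opening a second channel. The following standalone sketch illustrates that ownership pattern; it uses placeholder classes, not xgboost's actual ones.

#include <iostream>
#include <memory>

class Comm : public std::enable_shared_from_this<Comm> {
 public:
  virtual ~Comm() = default;
  virtual char const* Name() const { return "host"; }
  virtual Comm* MakeCUDAVar() const;
};

class CUDAComm : public Comm {
  std::shared_ptr<Comm const> impl_;  // shares the host object's state (e.g. a gRPC stub)

 public:
  explicit CUDAComm(std::shared_ptr<Comm const> impl) : impl_{std::move(impl)} {}
  char const* Name() const override { return "cuda"; }
};

Comm* Comm::MakeCUDAVar() const {
  // shared_from_this() has a const overload returning std::shared_ptr<Comm const>,
  // so the CUDA variant keeps the original object alive rather than copying it.
  return new CUDAComm{this->shared_from_this()};
}

int main() {
  auto host = std::make_shared<Comm>();
  std::unique_ptr<Comm> gpu{host->MakeCUDAVar()};
  std::cout << host->Name() << " -> " << gpu->Name() << "\n";
}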
@@ -3,7 +3,7 @@
  */
 #pragma once

-#include <federated.grpc.pb.h>
+#include <federated.old.grpc.pb.h>

 #include <cstdint>  // for int32_t
 #include <future>   // for future
@@ -16,9 +16,41 @@
 #include "../../src/common/io.h"          // for ReadAll
 #include "../../src/common/json_utils.h"  // for RequiredArg
 #include "../../src/common/timer.h"       // for Timer
+#include "federated_server.h"             // for FederatedService

 namespace xgboost::collective {
+namespace federated {
+grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest const* request,
+                                         AllgatherReply* reply) {
+  handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
+                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
+  return grpc::Status::OK;
+}
+
+grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
+                                          AllgatherVReply* reply) {
+  handler_.AllgatherV(request->send_buffer().data(), request->send_buffer().size(),
+                      reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
+  return grpc::Status::OK;
+}
+
+grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
+                                         AllreduceReply* reply) {
+  handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
+                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
+                     static_cast<xgboost::collective::DataType>(request->data_type()),
+                     static_cast<xgboost::collective::Operation>(request->reduce_operation()));
+  return grpc::Status::OK;
+}
+
+grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest const* request,
+                                         BroadcastReply* reply) {
+  handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
+                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
+                     request->root());
+  return grpc::Status::OK;
+}
+}  // namespace federated
+
 FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
   auto is_secure = RequiredArg<Boolean const>(config, "federated_secure", __func__);
   if (is_secure) {
@@ -31,7 +63,8 @@ FederatedTracker::FederatedTracker(Json const& config) : Tracker{config} {
 std::future<Result> FederatedTracker::Run() {
   return std::async([this]() {
     std::string const server_address = "0.0.0.0:" + std::to_string(this->port_);
-    federated::FederatedService service{static_cast<std::int32_t>(this->n_workers_)};
+    xgboost::collective::federated::FederatedService service{
+        static_cast<std::int32_t>(this->n_workers_)};
     grpc::ServerBuilder builder;

     if (this->server_cert_file_.empty()) {
@@ -42,7 +75,6 @@ std::future<Result> FederatedTracker::Run() {
       builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
     }
     builder.RegisterService(&service);
-    server_ = builder.BuildAndStart();
     LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
                  << this->n_workers_;
   } else {
@@ -60,12 +92,12 @@ std::future<Result> FederatedTracker::Run() {
       builder.AddListeningPort(server_address, grpc::SslServerCredentials(options));
     }
     builder.RegisterService(&service);
-    server_ = builder.BuildAndStart();
     LOG(CONSOLE) << "Federated server listening on " << server_address << ", world size "
                  << n_workers_;
   }

   try {
+    server_ = builder.BuildAndStart();
     server_->Wait();
   } catch (std::exception const& e) {
     return collective::Fail(std::string{e.what()});
@@ -8,11 +8,35 @@
 #include <memory>  // for unique_ptr
 #include <string>  // for string

+#include "../../src/collective/in_memory_handler.h"
 #include "../../src/collective/tracker.h"  // for Tracker
 #include "xgboost/collective/result.h"     // for Result
 #include "xgboost/json.h"                  // for Json

 namespace xgboost::collective {
+namespace federated {
+class FederatedService final : public Federated::Service {
+ public:
+  explicit FederatedService(std::int32_t world_size)
+      : handler_{static_cast<std::size_t>(world_size)} {}
+
+  grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
+                         AllgatherReply* reply) override;
+
+  grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
+                          AllgatherVReply* reply) override;
+
+  grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
+                         AllreduceReply* reply) override;
+
+  grpc::Status Broadcast(grpc::ServerContext* context, BroadcastRequest const* request,
+                         BroadcastReply* reply) override;
+
+ private:
+  xgboost::collective::InMemoryHandler handler_;
+};
+}  // namespace federated
+
 class FederatedTracker : public collective::Tracker {
   std::unique_ptr<grpc::Server> server_;
   std::string server_key_path_;
@@ -24,7 +24,7 @@ class Coll : public std::enable_shared_from_this<Coll> {
   Coll() = default;
   virtual ~Coll() noexcept(false) {}  // NOLINT

-  Coll* MakeCUDAVar();
+  virtual Coll* MakeCUDAVar();

   /**
    * @brief Allreduce
@@ -9,7 +9,8 @@
 #include <string>   // for string
 #include <utility>  // for move, forward

-#include "allgather.h"
+#include "../common/common.h"           // for AssertGPUSupport
+#include "allgather.h"                  // for RingAllgather
 #include "protocol.h"                   // for kMagic
 #include "xgboost/base.h"               // for XGBOOST_STRICT_R_MODE
 #include "xgboost/collective/socket.h"  // for TCPSocket
@@ -48,6 +49,14 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
             this->Rank(), this->World());
 }

+#if !defined(XGBOOST_USE_NCCL)
+Comm* Comm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
+  common::AssertGPUSupport();
+  common::AssertNCCLSupport();
+  return nullptr;
+}
+#endif  // !defined(XGBOOST_USE_NCCL)
+
 [[nodiscard]] Result ConnectWorkers(Comm const& comm, TCPSocket* listener, std::int32_t lport,
                                     proto::PeerInfo ninfo, std::chrono::seconds timeout,
                                     std::int32_t retry,
@@ -20,14 +20,14 @@

 namespace xgboost::collective {
 namespace {
-Result GetUniqueId(Comm const& comm, ncclUniqueId* pid) {
+Result GetUniqueId(Comm const& comm, std::shared_ptr<Coll> coll, ncclUniqueId* pid) {
   static const int kRootRank = 0;
   ncclUniqueId id;
   if (comm.Rank() == kRootRank) {
     dh::safe_nccl(ncclGetUniqueId(&id));
   }
-  auto rc = Broadcast(comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)},
-                      kRootRank);
+  auto rc = coll->Broadcast(
+      comm, common::Span{reinterpret_cast<std::int8_t*>(&id), sizeof(ncclUniqueId)}, kRootRank);
   if (!rc.OK()) {
     return rc;
   }
@@ -63,7 +63,7 @@ static std::string PrintUUID(xgboost::common::Span<std::uint64_t, kUuidLength> c
 }
 }  // namespace

-Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) {
+Comm* Comm::MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const {
   return new NCCLComm{ctx, *this, pimpl};
 }

@@ -86,6 +86,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
   GetCudaUUID(s_this_uuid, ctx->Device());

   auto rc = pimpl->Allgather(root, common::EraseType(s_uuid), s_this_uuid.size_bytes());
+
   CHECK(rc.OK()) << rc.Report();

   std::vector<xgboost::common::Span<std::uint64_t, kUuidLength>> converted(root.World());
@@ -103,7 +104,7 @@ NCCLComm::NCCLComm(Context const* ctx, Comm const& root, std::shared_ptr<Coll> p
                << "Multiple processes within communication group running on same CUDA "
                << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

-  rc = GetUniqueId(root, &nccl_unique_id_);
+  rc = GetUniqueId(root, pimpl, &nccl_unique_id_);
   CHECK(rc.OK()) << rc.Report();
   dh::safe_nccl(ncclCommInitRank(&nccl_comm_, root.World(), nccl_unique_id_, root.Rank()));

@@ -40,7 +40,7 @@ class Coll;
 /**
  * @brief Base communicator storing info about the tracker and other communicators.
  */
-class Comm {
+class Comm : public std::enable_shared_from_this<Comm> {
  protected:
   std::int32_t world_{-1};
   std::int32_t rank_{0};
@@ -87,7 +87,7 @@ class Comm {

   [[nodiscard]] virtual Result SignalError(Result const&) { return Success(); }

-  Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl);
+  virtual Comm* MakeCUDAVar(Context const* ctx, std::shared_ptr<Coll> pimpl) const;
 };

 class RabitComm : public Comm {
@@ -170,6 +170,12 @@ inline void AssertGPUSupport() {
 #endif  // XGBOOST_USE_CUDA && XGBOOST_USE_HIP
 }

+inline void AssertNCCLSupport() {
+#if !defined(XGBOOST_USE_NCCL)
+  LOG(FATAL) << "XGBoost version not compiled with NCCL support.";
+#endif  // !defined(XGBOOST_USE_NCCL)
+}
+
 inline void AssertOneAPISupport() {
 #ifndef XGBOOST_USE_ONEAPI
   LOG(FATAL) << "XGBoost version not compiled with OneAPI support.";
tests/cpp/plugin/federated/test_federated_coll.cc (Normal file, 94 lines)
@@ -0,0 +1,94 @@
/**
 * Copyright 2022-2023, XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/span.h>  // for Span

#include <algorithm>    // for min, transform
#include <array>        // for array
#include <string>       // for string
#include <string_view>  // for string_view
#include <thread>       // for hardware_concurrency
#include <vector>       // for vector

#include "../../../../src/common/type.h"   // for EraseType
#include "../../collective/test_worker.h"  // for SocketTest
#include "federated_coll.h"                // for FederatedColl
#include "federated_comm.h"                // for FederatedComm
#include "test_worker.h"                   // for TestFederated

namespace xgboost::collective {
namespace {
class FederatedCollTest : public SocketTest {};
}  // namespace

TEST_F(FederatedCollTest, Allreduce) {
  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    std::array<std::int32_t, 5> buffer = {1, 2, 3, 4, 5};
    std::array<std::int32_t, 5> expected;
    std::transform(buffer.cbegin(), buffer.cend(), expected.begin(),
                   [=](auto i) { return i * n_workers; });

    auto coll = std::make_shared<FederatedColl>();
    auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
                              ArrayInterfaceHandler::kI4, Op::kSum);
    ASSERT_TRUE(rc.OK());
    for (auto i = 0; i < 5; i++) {
      ASSERT_EQ(buffer[i], expected[i]);
    }
  });
}

TEST_F(FederatedCollTest, Broadcast) {
  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};
    auto rc = Success();
    if (comm->Rank() == 0) {
      std::string buffer{"hello"};
      rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
      ASSERT_EQ(buffer, "hello");
    } else {
      std::string buffer(5, ' ');  // placeholder of the same length as "hello"
      rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
      ASSERT_EQ(buffer, "hello");
    }
    ASSERT_TRUE(rc.OK());
  });
}

TEST_F(FederatedCollTest, Allgather) {
  std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};

    std::vector<std::int32_t> buffer(n_workers, 0);
    buffer[comm->Rank()] = comm->Rank();
    auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
                             sizeof(int));
    ASSERT_TRUE(rc.OK());
    for (auto i = 0; i < n_workers; i++) {
      ASSERT_EQ(buffer[i], i);
    }
  });
}

TEST_F(FederatedCollTest, AllgatherV) {
  std::int32_t n_workers = 2;
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};

    std::vector<std::string_view> inputs{"Federated", " Learning!!!"};
    std::vector<std::int64_t> recv_segments(inputs.size() + 1, 0);
    std::string r;
    std::vector<std::int64_t> sizes{static_cast<std::int64_t>(inputs[0].size()),
                                    static_cast<std::int64_t>(inputs[1].size())};
    r.resize(sizes[0] + sizes[1]);

    auto rc = coll.AllgatherV(
        *comm,
        common::EraseType(common::Span{inputs[comm->Rank()].data(), inputs[comm->Rank()].size()}),
        common::Span{sizes.data(), sizes.size()}, recv_segments,
        common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing);

    EXPECT_EQ(r, "Federated Learning!!!");
    ASSERT_TRUE(rc.OK());
  });
}
}  // namespace xgboost::collective
tests/cpp/plugin/federated/test_federated_coll.cu (Normal file, 131 lines)
@@ -0,0 +1,131 @@
/**
 * Copyright 2022-2023, XGBoost contributors
 */

#include <gtest/gtest.h>
#include <xgboost/collective/result.h>  // for Result

#include "../../../../src/common/common.h"             // for AllVisibleGPUs
#include "../../../../src/common/device_helpers.cuh"   // for device_vector
#include "../../../../src/common/type.h"               // for EraseType
#include "../../collective/test_worker.h"              // for SocketTest
#include "../../helpers.h"                             // for MakeCUDACtx
#include "federated_coll.cuh"
#include "federated_comm.cuh"
#include "test_worker.h"  // for TestFederated

namespace xgboost::collective {
namespace {
class FederatedCollTestGPU : public SocketTest {};

struct Worker {
  std::shared_ptr<FederatedColl> impl;
  std::shared_ptr<Comm> nccl_comm;
  std::shared_ptr<CUDAFederatedColl> coll;

  Worker(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
    auto ctx = MakeCUDACtx(rank);
    impl = std::make_shared<FederatedColl>();
    nccl_comm.reset(comm->MakeCUDAVar(&ctx, impl));
    coll = std::make_shared<CUDAFederatedColl>(impl);
  }
};

void TestAllreduce(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
  Worker w{comm, rank};

  dh::device_vector<std::int32_t> buffer{std::vector<std::int32_t>{1, 2, 3, 4, 5}};
  dh::device_vector<std::int32_t> expected(buffer.size());
  thrust::transform(buffer.cbegin(), buffer.cend(), expected.begin(),
                    [=] XGBOOST_DEVICE(std::int32_t i) { return i * n_workers; });

  auto rc = w.coll->Allreduce(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)),
                              ArrayInterfaceHandler::kI4, Op::kSum);
  ASSERT_TRUE(rc.OK());
  for (auto i = 0; i < 5; i++) {
    ASSERT_EQ(buffer[i], expected[i]);
  }
}

void TestBroadcast(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
  Worker w{comm, rank};

  auto rc = Success();
  std::vector<std::int32_t> expect{0, 1, 2, 3};

  if (comm->Rank() == 0) {
    dh::device_vector<std::int32_t> buffer{expect};
    rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
    ASSERT_EQ(buffer, expect);
  } else {
    dh::device_vector<std::int32_t> buffer(std::vector<std::int32_t>{4, 5, 6, 7});
    rc = w.coll->Broadcast(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), 0);
    ASSERT_EQ(buffer, expect);
  }
  ASSERT_TRUE(rc.OK());
}

void TestAllgather(std::shared_ptr<FederatedComm> comm, std::int32_t rank, std::int32_t n_workers) {
  Worker w{comm, rank};

  dh::device_vector<std::int32_t> buffer(n_workers, 0);
  buffer[comm->Rank()] = comm->Rank();
  auto rc = w.coll->Allgather(*w.nccl_comm, common::EraseType(dh::ToSpan(buffer)), sizeof(int));
  ASSERT_TRUE(rc.OK());
  for (auto i = 0; i < n_workers; i++) {
    ASSERT_EQ(buffer[i], i);
  }
}

void TestAllgatherV(std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
  Worker w{comm, rank};

  std::vector<dh::device_vector<std::int32_t>> inputs{std::vector<std::int32_t>{1, 2, 3},
                                                      std::vector<std::int32_t>{4, 5}};
  std::vector<std::int64_t> recv_segments(inputs.size() + 1, 0);
  dh::device_vector<std::int32_t> r;
  std::vector<std::int64_t> sizes{static_cast<std::int64_t>(inputs[0].size()),
                                  static_cast<std::int64_t>(inputs[1].size())};
  r.resize(sizes[0] + sizes[1]);

  auto rc = w.coll->AllgatherV(*w.nccl_comm, common::EraseType(dh::ToSpan(inputs[comm->Rank()])),
                               common::Span{sizes.data(), sizes.size()}, recv_segments,
                               common::EraseType(dh::ToSpan(r)), AllgatherVAlgo::kRing);
  ASSERT_TRUE(rc.OK());

  ASSERT_EQ(r[0], 1);
  for (std::size_t i = 1; i < r.size(); ++i) {
    ASSERT_EQ(r[i], r[i - 1] + 1);
  }
}
}  // namespace

TEST_F(FederatedCollTestGPU, Allreduce) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
    TestAllreduce(comm, rank, n_workers);
  });
}

TEST_F(FederatedCollTestGPU, Broadcast) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
    TestBroadcast(comm, rank);
  });
}

TEST_F(FederatedCollTestGPU, Allgather) {
  std::int32_t n_workers = common::AllVisibleGPUs();
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
    TestAllgather(comm, rank, n_workers);
  });
}

TEST_F(FederatedCollTestGPU, AllgatherV) {
  std::int32_t n_workers = 2;
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
    TestAllgatherV(comm, rank);
  });
}
}  // namespace xgboost::collective
@@ -7,7 +7,7 @@
 #include <thread>  // for thread

 #include "../../../../plugin/federated/federated_comm.h"
-#include "../../collective/net_test.h"     // for SocketTest
+#include "../../collective/test_worker.h"  // for SocketTest
 #include "../../helpers.h"                 // for ExpectThrow
 #include "test_worker.h"                   // for TestFederated
 #include "xgboost/json.h"                  // for Json
@@ -71,14 +71,9 @@ TEST_F(FederatedCommTest, IsDistributed) {

 TEST_F(FederatedCommTest, InsecureTracker) {
   std::int32_t n_workers = std::min(std::thread::hardware_concurrency(), 3u);
-  TestFederated(n_workers, [=](std::int32_t port, std::int32_t rank) {
-    Json config{Object{}};
-    config["federated_world_size"] = n_workers;
-    config["federated_rank"] = rank;
-    config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
-    FederatedComm comm{config};
-    ASSERT_EQ(comm.Rank(), rank);
-    ASSERT_EQ(comm.World(), n_workers);
+  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t rank) {
+    ASSERT_EQ(comm->Rank(), rank);
+    ASSERT_EQ(comm->World(), n_workers);
   });
 }
 }  // namespace xgboost::collective
@@ -9,6 +9,7 @@
 #include <thread>  // for thread

 #include "../../../../plugin/federated/federated_tracker.h"
+#include "federated_comm.h"  // for FederatedComm
 #include "xgboost/json.h"    // for Json

 namespace xgboost::collective {
@@ -28,7 +29,15 @@ void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
   std::int32_t port = tracker.Port();

   for (std::int32_t i = 0; i < n_workers; ++i) {
-    workers.emplace_back([=] { fn(port, i); });
+    workers.emplace_back([=] {
+      Json config{Object{}};
+      config["federated_world_size"] = n_workers;
+      config["federated_rank"] = i;
+      config["federated_server_address"] = "0.0.0.0:" + std::to_string(port);
+      auto comm = std::make_shared<FederatedComm>(config);
+
+      fn(comm, i);
+    });
   }

   for (auto& t : workers) {