Add Allgather to collective communicator (#8765)

* Add Allgather to collective communicator

parent 48cefa012e
commit cbf98cb9c6
@@ -6,6 +6,7 @@ syntax = "proto3";
 package xgboost.federated;
 
 service Federated {
+  rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
   rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
   rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
 }
@@ -30,6 +31,17 @@ enum ReduceOperation {
   BITWISE_XOR = 5;
 }
 
+message AllgatherRequest {
+  // An incrementing counter that is unique to each round of operations.
+  uint64 sequence_number = 1;
+  int32 rank = 2;
+  bytes send_buffer = 3;
+}
+
+message AllgatherReply {
+  bytes receive_buffer = 1;
+}
+
 message AllreduceRequest {
   // An incrementing counter that is unique to each round of operations.
   uint64 sequence_number = 1;
@@ -46,6 +46,25 @@ class FederatedClient {
       }()},
       rank_{rank} {}
 
+  std::string Allgather(std::string const &send_buffer) {
+    AllgatherRequest request;
+    request.set_sequence_number(sequence_number_++);
+    request.set_rank(rank_);
+    request.set_send_buffer(send_buffer);
+
+    AllgatherReply reply;
+    grpc::ClientContext context;
+    context.set_wait_for_ready(true);
+    grpc::Status status = stub_->Allgather(&context, request, &reply);
+
+    if (status.ok()) {
+      return reply.receive_buffer();
+    } else {
+      std::cout << status.error_code() << ": " << status.error_message() << '\n';
+      throw std::runtime_error("Allgather RPC failed");
+    }
+  }
+
   std::string Allreduce(std::string const &send_buffer, DataType data_type,
                         ReduceOperation reduce_operation) {
     AllreduceRequest request;
@@ -126,6 +126,17 @@ class FederatedCommunicator : public Communicator {
    */
   bool IsFederated() const override { return true; }
 
+  /**
+   * \brief Perform in-place allgather.
+   * \param send_receive_buffer Buffer for both sending and receiving data.
+   * \param size Number of bytes to be gathered.
+   */
+  void AllGather(void *send_receive_buffer, std::size_t size) override {
+    std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
+    auto const received = client_->Allgather(send_buffer);
+    received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
+  }
+
   /**
    * \brief Perform in-place allreduce.
    * \param send_receive_buffer Buffer for both sending and receiving data.
@@ -14,6 +14,13 @@
 namespace xgboost {
 namespace federated {
 
+grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
+                                         AllgatherRequest const* request, AllgatherReply* reply) {
+  handler_.Allgather(request->send_buffer().data(), request->send_buffer().size(),
+                     reply->mutable_receive_buffer(), request->sequence_number(), request->rank());
+  return grpc::Status::OK;
+}
+
 grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
                                          AllreduceRequest const* request, AllreduceReply* reply) {
   handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
@@ -14,6 +14,9 @@ class FederatedService final : public Federated::Service {
  public:
   explicit FederatedService(int const world_size) : handler_{world_size} {}
 
+  grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
+                         AllgatherReply* reply) override;
+
   grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
                          AllreduceReply* reply) override;
 
@@ -122,6 +122,17 @@ class Communicator {
   /** @brief Whether the communicator is running in federated mode. */
   virtual bool IsFederated() const = 0;
 
+  /**
+   * @brief Gathers data from all processes and distributes it to all processes.
+   *
+   * This assumes all ranks have the same size, and input data has been sliced into the
+   * corresponding position.
+   *
+   * @param send_receive_buffer Buffer storing the data.
+   * @param size Size of the data in bytes.
+   */
+  virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
+
   /**
   * @brief Combines values from all processes and distributes the result back to all processes.
   *
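Note: the slicing contract in the doc comment above is easiest to see from the caller's side. The sketch below is illustrative only and is not part of this commit; it assumes an already-initialized Communicator (GetWorldSize()/GetRank() are used the same way elsewhere in this diff), and the helper name GatherRanks is hypothetical.

// Illustrative sketch: each rank writes only its own slice, AllGather fills in the rest.
#include <vector>
// Assumes the xgboost::collective::Communicator declaration is in scope.
void GatherRanks(xgboost::collective::Communicator* comm) {  // hypothetical helper
  auto const world_size = comm->GetWorldSize();
  auto const rank = comm->GetRank();
  std::vector<int> values(world_size, 0);  // one equal-sized slice (one int) per rank
  values[rank] = rank * 10;                // fill only this rank's position
  comm->AllGather(values.data(), values.size() * sizeof(int));
  // After the call, values[i] == i * 10 on every rank.
}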
@@ -60,6 +60,13 @@ class InMemoryCommunicator : public Communicator {
   bool IsDistributed() const override { return true; }
   bool IsFederated() const override { return false; }
 
+  void AllGather(void* in_out, std::size_t size) override {
+    std::string output;
+    handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
+                       GetRank());
+    output.copy(static_cast<char*>(in_out), size);
+  }
+
   void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
     auto const bytes = size * GetTypeSize(data_type);
     std::string output;
@@ -9,6 +9,32 @@
 namespace xgboost {
 namespace collective {
 
+/**
+ * @brief Functor for allgather.
+ */
+class AllgatherFunctor {
+ public:
+  std::string const name{"Allgather"};
+
+  AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
+
+  void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
+    if (buffer->empty()) {
+      // Copy the input if this is the first request.
+      buffer->assign(input, bytes);
+    } else {
+      // Splice the input into the common buffer.
+      auto const per_rank = bytes / world_size_;
+      auto const index = rank_ * per_rank;
+      buffer->replace(index, per_rank, input + index, per_rank);
+    }
+  }
+
+ private:
+  int world_size_;
+  int rank_;
+};
+
 /**
  * @brief Functor for allreduce.
  */
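Note: a standalone sketch (not part of the commit) of what AllgatherFunctor does when three ranks each submit a request. The first request seeds the shared buffer; every later request overwrites only its own per-rank slice, so whatever the seeding rank sent in the other slots is replaced.

// Hypothetical driver that replays the splicing logic above for world_size = 3.
#include <cassert>
#include <cstddef>
#include <string>

int main() {
  int const world_size = 3;
  std::string buffer;  // the handler's shared buffer, empty before the first request
  // Each rank sends a full-size buffer where only its own slice is meaningful.
  std::string const inputs[] = {"A..", ".B.", "..C"};
  for (int rank = 0; rank < world_size; ++rank) {
    char const* input = inputs[rank].data();
    std::size_t const bytes = inputs[rank].size();
    if (buffer.empty()) {
      buffer.assign(input, bytes);  // first request: copy everything
    } else {
      auto const per_rank = bytes / world_size;  // one byte per rank in this example
      auto const index = rank * per_rank;
      buffer.replace(index, per_rank, input + index, per_rank);  // splice this rank's slice
    }
  }
  assert(buffer == "ABC");  // every slice now comes from its owning rank
  return 0;
}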
@@ -17,7 +43,7 @@ class AllreduceFunctor {
   std::string const name{"Allreduce"};
 
   AllreduceFunctor(DataType dataType, Operation operation)
-      : data_type_(dataType), operation_(operation) {}
+      : data_type_{dataType}, operation_{operation} {}
 
   void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
     if (buffer->empty()) {
@@ -128,7 +154,7 @@ class BroadcastFunctor {
  public:
   std::string const name{"Broadcast"};
 
-  BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {}
+  BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
 
   void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
     if (rank_ == root_) {
@@ -167,6 +193,11 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
   cv_.notify_all();
 }
 
+void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
+                                std::size_t sequence_number, int rank) {
+  Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
+}
+
 void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
                                 std::size_t sequence_number, int rank, DataType data_type,
                                 Operation op) {
@@ -53,6 +53,17 @@ class InMemoryHandler {
   */
  void Shutdown(uint64_t sequence_number, int rank);
 
+  /**
+   * @brief Perform allgather.
+   * @param input The input buffer.
+   * @param bytes Number of bytes in the input buffer.
+   * @param output The output buffer.
+   * @param sequence_number Call sequence number.
+   * @param rank Index of the worker.
+   */
+  void Allgather(char const* input, std::size_t bytes, std::string* output,
+                 std::size_t sequence_number, int rank);
+
  /**
   * @brief Perform allreduce.
   * @param input The input buffer.
@@ -17,6 +17,7 @@ class NoOpCommunicator : public Communicator {
   NoOpCommunicator() : Communicator(1, 0) {}
   bool IsDistributed() const override { return false; }
   bool IsFederated() const override { return false; }
+  void AllGather(void *, std::size_t) override {}
   void AllReduce(void *, std::size_t, DataType, Operation) override {}
   void Broadcast(void *, std::size_t, int) override {}
   std::string GetProcessorName() override { return ""; }
@@ -55,6 +55,12 @@ class RabitCommunicator : public Communicator {
 
   bool IsFederated() const override { return false; }
 
+  void AllGather(void *send_receive_buffer, std::size_t size) override {
+    auto const per_rank = size / GetWorldSize();
+    auto const index = per_rank * GetRank();
+    rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
+  }
+
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
                  Operation op) override {
     switch (data_type) {
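Note: the offset arithmetic in the RabitCommunicator override follows the same equal-slice contract. A quick worked example with illustrative numbers (not taken from the commit):

// per_rank/index arithmetic for a 12-byte buffer, world size 3, rank 1.
#include <cassert>
#include <cstddef>

int main() {
  std::size_t const size = 12;              // total bytes in send_receive_buffer
  int const world_size = 3, rank = 1;
  auto const per_rank = size / world_size;  // 4 bytes contributed by each rank
  auto const index = per_rank * rank;       // rank 1's slice starts at byte 4
  assert(per_rank == 4 && index == 4);
  // rabit::Allgather(buffer, size, index, per_rank, per_rank) then exchanges the
  // remaining slices in place, leaving the full 12-byte buffer on every rank.
  return 0;
}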
@@ -24,6 +24,16 @@ class InMemoryCommunicatorTest : public ::testing::Test {
     }
   }
 
+  static void Allgather(int rank) {
+    InMemoryCommunicator comm{kWorldSize, rank};
+    char buffer[kWorldSize] = {'a', 'b', 'c'};
+    buffer[rank] = '0' + rank;
+    comm.AllGather(buffer, kWorldSize);
+    for (auto i = 0; i < kWorldSize; i++) {
+      EXPECT_EQ(buffer[i], '0' + i);
+    }
+  }
+
   static void AllreduceMax(int rank) {
     InMemoryCommunicator comm{kWorldSize, rank};
     int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
@@ -147,6 +157,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
   EXPECT_TRUE(comm.IsDistributed());
 }
 
+TEST_F(InMemoryCommunicatorTest, Allgather) { Verify(&Allgather); }
+
 TEST_F(InMemoryCommunicatorTest, AllreduceMax) { Verify(&AllreduceMax); }
 
 TEST_F(InMemoryCommunicatorTest, AllreduceMin) { Verify(&AllreduceMin); }
@@ -28,6 +28,11 @@ namespace collective {
 
 class FederatedCommunicatorTest : public ::testing::Test {
  public:
+  static void VerifyAllgather(int rank, const std::string& server_address) {
+    FederatedCommunicator comm{kWorldSize, rank, server_address};
+    CheckAllgather(comm, rank);
+  }
+
   static void VerifyAllreduce(int rank, const std::string& server_address) {
     FederatedCommunicator comm{kWorldSize, rank, server_address};
     CheckAllreduce(comm);
|||||||
server_thread_->join();
|
server_thread_->join();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
|
||||||
|
int buffer[kWorldSize] = {0, 0, 0};
|
||||||
|
buffer[rank] = rank;
|
||||||
|
comm.AllGather(buffer, sizeof(buffer));
|
||||||
|
for (auto i = 0; i < kWorldSize; i++) {
|
||||||
|
EXPECT_EQ(buffer[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void CheckAllreduce(FederatedCommunicator &comm) {
|
static void CheckAllreduce(FederatedCommunicator &comm) {
|
||||||
int buffer[] = {1, 2, 3, 4, 5};
|
int buffer[] = {1, 2, 3, 4, 5};
|
||||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||||
@@ -144,6 +158,17 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
   EXPECT_TRUE(comm.IsDistributed());
 }
 
+TEST_F(FederatedCommunicatorTest, Allgather) {
+  std::vector<std::thread> threads;
+  for (auto rank = 0; rank < kWorldSize; rank++) {
+    threads.emplace_back(
+        std::thread(&FederatedCommunicatorTest::VerifyAllgather, rank, server_address_));
+  }
+  for (auto &thread : threads) {
+    thread.join();
+  }
+}
+
 TEST_F(FederatedCommunicatorTest, Allreduce) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {
@@ -4,13 +4,13 @@
 #include <grpcpp/server_builder.h>
 #include <gtest/gtest.h>
 
+#include <ctime>
 #include <iostream>
 #include <thread>
-#include <ctime>
 
-#include "helpers.h"
 #include "federated_client.h"
 #include "federated_server.h"
+#include "helpers.h"
 
 namespace {
 
@@ -26,6 +26,11 @@ namespace xgboost {
 
 class FederatedServerTest : public ::testing::Test {
  public:
+  static void VerifyAllgather(int rank, const std::string& server_address) {
+    federated::FederatedClient client{server_address, rank};
+    CheckAllgather(client, rank);
+  }
+
   static void VerifyAllreduce(int rank, const std::string& server_address) {
     federated::FederatedClient client{server_address, rank};
     CheckAllreduce(client);
@@ -39,6 +44,7 @@ class FederatedServerTest : public ::testing::Test {
   static void VerifyMixture(int rank, const std::string& server_address) {
     federated::FederatedClient client{server_address, rank};
     for (auto i = 0; i < 10; i++) {
+      CheckAllgather(client, rank);
       CheckAllreduce(client);
       CheckBroadcast(client, rank);
     }
@@ -62,6 +68,17 @@ class FederatedServerTest : public ::testing::Test {
     server_thread_->join();
   }
 
+  static void CheckAllgather(federated::FederatedClient& client, int rank) {
+    int data[kWorldSize] = {0, 0, 0};
+    data[rank] = rank;
+    std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
+    auto reply = client.Allgather(send_buffer);
+    auto const* result = reinterpret_cast<int const*>(reply.data());
+    for (auto i = 0; i < kWorldSize; i++) {
+      EXPECT_EQ(result[i], i);
+    }
+  }
+
   static void CheckAllreduce(federated::FederatedClient& client) {
     int data[] = {1, 2, 3, 4, 5};
     std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
@@ -88,6 +105,16 @@ class FederatedServerTest : public ::testing::Test {
   std::unique_ptr<grpc::Server> server_;
 };
 
+TEST_F(FederatedServerTest, Allgather) {
+  std::vector<std::thread> threads;
+  for (auto rank = 0; rank < kWorldSize; rank++) {
+    threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
+  }
+  for (auto& thread : threads) {
+    thread.join();
+  }
+}
+
 TEST_F(FederatedServerTest, Allreduce) {
   std::vector<std::thread> threads;
   for (auto rank = 0; rank < kWorldSize; rank++) {