Add an in-memory collective communicator (#8494)
This commit is contained in:
parent
157e98edf7
commit
a8255ea678
@ -72,6 +72,8 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
|
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||||
|
$(PKGROOT)/src/collective/in_memory_handler.o \
|
||||||
$(PKGROOT)/src/collective/socket.o \
|
$(PKGROOT)/src/collective/socket.o \
|
||||||
$(PKGROOT)/src/common/charconv.o \
|
$(PKGROOT)/src/common/charconv.o \
|
||||||
$(PKGROOT)/src/common/column_matrix.o \
|
$(PKGROOT)/src/common/column_matrix.o \
|
||||||
|
|||||||
@ -72,6 +72,8 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/logging.o \
|
$(PKGROOT)/src/logging.o \
|
||||||
$(PKGROOT)/src/global_config.o \
|
$(PKGROOT)/src/global_config.o \
|
||||||
$(PKGROOT)/src/collective/communicator.o \
|
$(PKGROOT)/src/collective/communicator.o \
|
||||||
|
$(PKGROOT)/src/collective/in_memory_communicator.o \
|
||||||
|
$(PKGROOT)/src/collective/in_memory_handler.o \
|
||||||
$(PKGROOT)/src/collective/socket.o \
|
$(PKGROOT)/src/collective/socket.o \
|
||||||
$(PKGROOT)/src/common/charconv.o \
|
$(PKGROOT)/src/common/charconv.o \
|
||||||
$(PKGROOT)/src/common/column_matrix.o \
|
$(PKGROOT)/src/common/column_matrix.o \
|
||||||
|
|||||||
@ -6,7 +6,6 @@ syntax = "proto3";
|
|||||||
package xgboost.federated;
|
package xgboost.federated;
|
||||||
|
|
||||||
service Federated {
|
service Federated {
|
||||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
|
||||||
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||||
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||||
}
|
}
|
||||||
@ -28,17 +27,6 @@ enum ReduceOperation {
|
|||||||
SUM = 2;
|
SUM = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
message AllgatherRequest {
|
|
||||||
// An incrementing counter that is unique to each round to operations.
|
|
||||||
uint64 sequence_number = 1;
|
|
||||||
int32 rank = 2;
|
|
||||||
bytes send_buffer = 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
message AllgatherReply {
|
|
||||||
bytes receive_buffer = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message AllreduceRequest {
|
message AllreduceRequest {
|
||||||
// An incrementing counter that is unique to each round to operations.
|
// An incrementing counter that is unique to each round to operations.
|
||||||
uint64 sequence_number = 1;
|
uint64 sequence_number = 1;
|
||||||
|
|||||||
@ -46,25 +46,6 @@ class FederatedClient {
|
|||||||
}()},
|
}()},
|
||||||
rank_{rank} {}
|
rank_{rank} {}
|
||||||
|
|
||||||
std::string Allgather(std::string const &send_buffer) {
|
|
||||||
AllgatherRequest request;
|
|
||||||
request.set_sequence_number(sequence_number_++);
|
|
||||||
request.set_rank(rank_);
|
|
||||||
request.set_send_buffer(send_buffer);
|
|
||||||
|
|
||||||
AllgatherReply reply;
|
|
||||||
grpc::ClientContext context;
|
|
||||||
context.set_wait_for_ready(true);
|
|
||||||
grpc::Status status = stub_->Allgather(&context, request, &reply);
|
|
||||||
|
|
||||||
if (status.ok()) {
|
|
||||||
return reply.receive_buffer();
|
|
||||||
} else {
|
|
||||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
|
||||||
throw std::runtime_error("Allgather RPC failed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||||
ReduceOperation reduce_operation) {
|
ReduceOperation reduce_operation) {
|
||||||
AllreduceRequest request;
|
AllreduceRequest request;
|
||||||
|
|||||||
@ -12,40 +12,6 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace collective {
|
namespace collective {
|
||||||
|
|
||||||
/** @brief Get the size of the data type. */
|
|
||||||
inline std::size_t GetTypeSize(DataType data_type) {
|
|
||||||
std::size_t size{0};
|
|
||||||
switch (data_type) {
|
|
||||||
case DataType::kInt8:
|
|
||||||
size = sizeof(std::int8_t);
|
|
||||||
break;
|
|
||||||
case DataType::kUInt8:
|
|
||||||
size = sizeof(std::uint8_t);
|
|
||||||
break;
|
|
||||||
case DataType::kInt32:
|
|
||||||
size = sizeof(std::int32_t);
|
|
||||||
break;
|
|
||||||
case DataType::kUInt32:
|
|
||||||
size = sizeof(std::uint32_t);
|
|
||||||
break;
|
|
||||||
case DataType::kInt64:
|
|
||||||
size = sizeof(std::int64_t);
|
|
||||||
break;
|
|
||||||
case DataType::kUInt64:
|
|
||||||
size = sizeof(std::uint64_t);
|
|
||||||
break;
|
|
||||||
case DataType::kFloat:
|
|
||||||
size = sizeof(float);
|
|
||||||
break;
|
|
||||||
case DataType::kDouble:
|
|
||||||
size = sizeof(double);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
LOG(FATAL) << "Unknown data type.";
|
|
||||||
}
|
|
||||||
return size;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief A Federated Learning communicator class that handles collective communication.
|
* @brief A Federated Learning communicator class that handles collective communication.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@ -7,7 +7,6 @@
|
|||||||
#include <grpcpp/server_builder.h>
|
#include <grpcpp/server_builder.h>
|
||||||
#include <xgboost/logging.h>
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
#include <fstream>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "../../src/common/io.h"
|
#include "../../src/common/io.h"
|
||||||
@ -15,184 +14,20 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace federated {
|
namespace federated {
|
||||||
|
|
||||||
class AllgatherFunctor {
|
|
||||||
public:
|
|
||||||
std::string const name{"Allgather"};
|
|
||||||
|
|
||||||
explicit AllgatherFunctor(int const world_size) : world_size_{world_size} {}
|
|
||||||
|
|
||||||
void operator()(AllgatherRequest const* request, std::string& buffer) const {
|
|
||||||
auto const rank = request->rank();
|
|
||||||
auto const& send_buffer = request->send_buffer();
|
|
||||||
auto const send_size = send_buffer.size();
|
|
||||||
// Resize the buffer if this is the first request.
|
|
||||||
if (buffer.size() != send_size * world_size_) {
|
|
||||||
buffer.resize(send_size * world_size_);
|
|
||||||
}
|
|
||||||
// Splice the send_buffer into the common buffer.
|
|
||||||
buffer.replace(rank * send_size, send_size, send_buffer);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int const world_size_;
|
|
||||||
};
|
|
||||||
|
|
||||||
class AllreduceFunctor {
|
|
||||||
public:
|
|
||||||
std::string const name{"Allreduce"};
|
|
||||||
|
|
||||||
void operator()(AllreduceRequest const* request, std::string& buffer) const {
|
|
||||||
if (buffer.empty()) {
|
|
||||||
// Copy the send_buffer if this is the first request.
|
|
||||||
buffer = request->send_buffer();
|
|
||||||
} else {
|
|
||||||
// Apply the reduce_operation to the send_buffer and the common buffer.
|
|
||||||
Accumulate(buffer, request->send_buffer(), request->data_type(), request->reduce_operation());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
template <class T>
|
|
||||||
void Accumulate(T* buffer, T const* input, std::size_t n,
|
|
||||||
ReduceOperation reduce_operation) const {
|
|
||||||
switch (reduce_operation) {
|
|
||||||
case ReduceOperation::MAX:
|
|
||||||
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::max(a, b); });
|
|
||||||
break;
|
|
||||||
case ReduceOperation::MIN:
|
|
||||||
std::transform(buffer, buffer + n, input, buffer, [](T a, T b) { return std::min(a, b); });
|
|
||||||
break;
|
|
||||||
case ReduceOperation::SUM:
|
|
||||||
std::transform(buffer, buffer + n, input, buffer, std::plus<T>());
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::invalid_argument("Invalid reduce operation");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Accumulate(std::string& buffer, std::string const& input, DataType data_type,
|
|
||||||
ReduceOperation reduce_operation) const {
|
|
||||||
switch (data_type) {
|
|
||||||
case DataType::INT8:
|
|
||||||
Accumulate(reinterpret_cast<std::int8_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::int8_t const*>(input.data()), buffer.size(),
|
|
||||||
reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::UINT8:
|
|
||||||
Accumulate(reinterpret_cast<std::uint8_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::uint8_t const*>(input.data()), buffer.size(),
|
|
||||||
reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::INT32:
|
|
||||||
Accumulate(reinterpret_cast<std::int32_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::int32_t const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::UINT32:
|
|
||||||
Accumulate(reinterpret_cast<std::uint32_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::uint32_t const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(std::uint32_t), reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::INT64:
|
|
||||||
Accumulate(reinterpret_cast<std::int64_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::int64_t const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(std::int64_t), reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::UINT64:
|
|
||||||
Accumulate(reinterpret_cast<std::uint64_t*>(&buffer[0]),
|
|
||||||
reinterpret_cast<std::uint64_t const*>(input.data()),
|
|
||||||
buffer.size() / sizeof(std::uint64_t), reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::FLOAT:
|
|
||||||
Accumulate(reinterpret_cast<float*>(&buffer[0]),
|
|
||||||
reinterpret_cast<float const*>(input.data()), buffer.size() / sizeof(float),
|
|
||||||
reduce_operation);
|
|
||||||
break;
|
|
||||||
case DataType::DOUBLE:
|
|
||||||
Accumulate(reinterpret_cast<double*>(&buffer[0]),
|
|
||||||
reinterpret_cast<double const*>(input.data()), buffer.size() / sizeof(double),
|
|
||||||
reduce_operation);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
throw std::invalid_argument("Invalid data type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class BroadcastFunctor {
|
|
||||||
public:
|
|
||||||
std::string const name{"Broadcast"};
|
|
||||||
|
|
||||||
void operator()(BroadcastRequest const* request, std::string& buffer) const {
|
|
||||||
if (request->rank() == request->root()) {
|
|
||||||
// Copy the send_buffer if this is the root.
|
|
||||||
buffer = request->send_buffer();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
grpc::Status FederatedService::Allgather(grpc::ServerContext* context,
|
|
||||||
AllgatherRequest const* request, AllgatherReply* reply) {
|
|
||||||
return Handle(request, reply, AllgatherFunctor{world_size_});
|
|
||||||
}
|
|
||||||
|
|
||||||
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
|
grpc::Status FederatedService::Allreduce(grpc::ServerContext* context,
|
||||||
AllreduceRequest const* request, AllreduceReply* reply) {
|
AllreduceRequest const* request, AllreduceReply* reply) {
|
||||||
return Handle(request, reply, AllreduceFunctor{});
|
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||||
|
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||||
|
static_cast<xgboost::collective::DataType>(request->data_type()),
|
||||||
|
static_cast<xgboost::collective::Operation>(request->reduce_operation()));
|
||||||
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
grpc::Status FederatedService::Broadcast(grpc::ServerContext* context,
|
grpc::Status FederatedService::Broadcast(grpc::ServerContext* context,
|
||||||
BroadcastRequest const* request, BroadcastReply* reply) {
|
BroadcastRequest const* request, BroadcastReply* reply) {
|
||||||
return Handle(request, reply, BroadcastFunctor{});
|
handler_.Broadcast(request->send_buffer().data(), request->send_buffer().size(),
|
||||||
}
|
reply->mutable_receive_buffer(), request->sequence_number(), request->rank(),
|
||||||
|
request->root());
|
||||||
template <class Request, class Reply, class RequestFunctor>
|
|
||||||
grpc::Status FederatedService::Handle(Request const* request, Reply* reply,
|
|
||||||
RequestFunctor const& functor) {
|
|
||||||
// Pass through if there is only 1 client.
|
|
||||||
if (world_size_ == 1) {
|
|
||||||
reply->set_receive_buffer(request->send_buffer());
|
|
||||||
return grpc::Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_lock<std::mutex> lock(mutex_);
|
|
||||||
|
|
||||||
auto const sequence_number = request->sequence_number();
|
|
||||||
auto const rank = request->rank();
|
|
||||||
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
|
|
||||||
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
|
||||||
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
|
|
||||||
functor(request, buffer_);
|
|
||||||
received_++;
|
|
||||||
|
|
||||||
if (received_ == world_size_) {
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
|
|
||||||
reply->set_receive_buffer(buffer_);
|
|
||||||
sent_++;
|
|
||||||
lock.unlock();
|
|
||||||
cv_.notify_all();
|
|
||||||
return grpc::Status::OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
|
|
||||||
cv_.wait(lock, [this] { return received_ == world_size_; });
|
|
||||||
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
|
|
||||||
reply->set_receive_buffer(buffer_);
|
|
||||||
sent_++;
|
|
||||||
|
|
||||||
if (sent_ == world_size_) {
|
|
||||||
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
|
|
||||||
sent_ = 0;
|
|
||||||
received_ = 0;
|
|
||||||
buffer_.clear();
|
|
||||||
sequence_number_++;
|
|
||||||
lock.unlock();
|
|
||||||
cv_.notify_all();
|
|
||||||
}
|
|
||||||
|
|
||||||
return grpc::Status::OK;
|
return grpc::Status::OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,18 +5,14 @@
|
|||||||
|
|
||||||
#include <federated.grpc.pb.h>
|
#include <federated.grpc.pb.h>
|
||||||
|
|
||||||
#include <condition_variable>
|
#include "../../src/collective/in_memory_handler.h"
|
||||||
#include <mutex>
|
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace federated {
|
namespace federated {
|
||||||
|
|
||||||
class FederatedService final : public Federated::Service {
|
class FederatedService final : public Federated::Service {
|
||||||
public:
|
public:
|
||||||
explicit FederatedService(int const world_size) : world_size_{world_size} {}
|
explicit FederatedService(int const world_size) : handler_{world_size} {}
|
||||||
|
|
||||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
|
||||||
AllgatherReply* reply) override;
|
|
||||||
|
|
||||||
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||||
AllreduceReply* reply) override;
|
AllreduceReply* reply) override;
|
||||||
@ -25,16 +21,7 @@ class FederatedService final : public Federated::Service {
|
|||||||
BroadcastReply* reply) override;
|
BroadcastReply* reply) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <class Request, class Reply, class RequestFunctor>
|
xgboost::collective::InMemoryHandler handler_;
|
||||||
grpc::Status Handle(Request const* request, Reply* reply, RequestFunctor const& functor);
|
|
||||||
|
|
||||||
int const world_size_;
|
|
||||||
int received_{};
|
|
||||||
int sent_{};
|
|
||||||
std::string buffer_{};
|
|
||||||
uint64_t sequence_number_{};
|
|
||||||
mutable std::mutex mutex_;
|
|
||||||
mutable std::condition_variable cv_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
*/
|
*/
|
||||||
#include "communicator.h"
|
#include "communicator.h"
|
||||||
|
|
||||||
|
#include "in_memory_communicator.h"
|
||||||
#include "noop_communicator.h"
|
#include "noop_communicator.h"
|
||||||
#include "rabit_communicator.h"
|
#include "rabit_communicator.h"
|
||||||
|
|
||||||
@ -40,6 +41,10 @@ void Communicator::Init(Json const& config) {
|
|||||||
#endif
|
#endif
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case CommunicatorType::kInMemory: {
|
||||||
|
communicator_.reset(InMemoryCommunicator::Create(config));
|
||||||
|
break;
|
||||||
|
}
|
||||||
case CommunicatorType::kUnknown:
|
case CommunicatorType::kUnknown:
|
||||||
LOG(FATAL) << "Unknown communicator type.";
|
LOG(FATAL) << "Unknown communicator type.";
|
||||||
}
|
}
|
||||||
|
|||||||
@ -23,12 +23,46 @@ enum class DataType {
|
|||||||
kDouble = 7
|
kDouble = 7
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** @brief Get the size of the data type. */
|
||||||
|
inline std::size_t GetTypeSize(DataType data_type) {
|
||||||
|
std::size_t size{0};
|
||||||
|
switch (data_type) {
|
||||||
|
case DataType::kInt8:
|
||||||
|
size = sizeof(std::int8_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt8:
|
||||||
|
size = sizeof(std::uint8_t);
|
||||||
|
break;
|
||||||
|
case DataType::kInt32:
|
||||||
|
size = sizeof(std::int32_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt32:
|
||||||
|
size = sizeof(std::uint32_t);
|
||||||
|
break;
|
||||||
|
case DataType::kInt64:
|
||||||
|
size = sizeof(std::int64_t);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt64:
|
||||||
|
size = sizeof(std::uint64_t);
|
||||||
|
break;
|
||||||
|
case DataType::kFloat:
|
||||||
|
size = sizeof(float);
|
||||||
|
break;
|
||||||
|
case DataType::kDouble:
|
||||||
|
size = sizeof(double);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
LOG(FATAL) << "Unknown data type.";
|
||||||
|
}
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
/** @brief Defines the reduction operation. */
|
/** @brief Defines the reduction operation. */
|
||||||
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
|
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
|
||||||
|
|
||||||
class DeviceCommunicator;
|
class DeviceCommunicator;
|
||||||
|
|
||||||
enum class CommunicatorType { kUnknown, kRabit, kFederated };
|
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory };
|
||||||
|
|
||||||
/** \brief Case-insensitive string comparison. */
|
/** \brief Case-insensitive string comparison. */
|
||||||
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
|
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
|
||||||
|
|||||||
12
src/collective/in_memory_communicator.cc
Normal file
12
src/collective/in_memory_communicator.cc
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "in_memory_communicator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
// Out-of-class definition of the static handler member declared in InMemoryCommunicator.
InMemoryHandler InMemoryCommunicator::handler_{};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
91
src/collective/in_memory_communicator.h
Normal file
91
src/collective/in_memory_communicator.h
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <xgboost/json.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "../c_api/c_api_utils.h"
|
||||||
|
#include "in_memory_handler.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* An in-memory communicator, useful for testing.
|
||||||
|
*/
|
||||||
|
class InMemoryCommunicator : public Communicator {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Create a new communicator based on JSON configuration.
|
||||||
|
* @param config JSON configuration.
|
||||||
|
* @return Communicator as specified by the JSON configuration.
|
||||||
|
*/
|
||||||
|
static Communicator* Create(Json const& config) {
|
||||||
|
int world_size{0};
|
||||||
|
int rank{-1};
|
||||||
|
|
||||||
|
// Parse environment variables first.
|
||||||
|
auto* value = getenv("IN_MEMORY_WORLD_SIZE");
|
||||||
|
if (value != nullptr) {
|
||||||
|
world_size = std::stoi(value);
|
||||||
|
}
|
||||||
|
value = getenv("IN_MEMORY_RANK");
|
||||||
|
if (value != nullptr) {
|
||||||
|
rank = std::stoi(value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runtime configuration overrides, optional as users can specify them as env vars.
|
||||||
|
world_size = static_cast<int>(OptionalArg<Integer>(config, "in_memory_world_size",
|
||||||
|
static_cast<Integer::Int>(world_size)));
|
||||||
|
rank = static_cast<int>(
|
||||||
|
OptionalArg<Integer>(config, "in_memory_rank", static_cast<Integer::Int>(rank)));
|
||||||
|
|
||||||
|
if (world_size == 0) {
|
||||||
|
LOG(FATAL) << "Federated world size must be set.";
|
||||||
|
}
|
||||||
|
if (rank == -1) {
|
||||||
|
LOG(FATAL) << "Federated rank must be set.";
|
||||||
|
}
|
||||||
|
return new InMemoryCommunicator(world_size, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
InMemoryCommunicator(int world_size, int rank) : Communicator(world_size, rank) {
|
||||||
|
handler_.Init(world_size, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
~InMemoryCommunicator() override { handler_.Shutdown(sequence_number_++, GetRank()); }
|
||||||
|
|
||||||
|
bool IsDistributed() const override { return true; }
|
||||||
|
bool IsFederated() const override { return false; }
|
||||||
|
|
||||||
|
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
|
||||||
|
auto const bytes = size * GetTypeSize(data_type);
|
||||||
|
std::string output;
|
||||||
|
handler_.Allreduce(static_cast<const char*>(in_out), bytes, &output, sequence_number_++,
|
||||||
|
GetRank(), data_type, operation);
|
||||||
|
output.copy(static_cast<char*>(in_out), bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
void Broadcast(void* in_out, std::size_t size, int root) override {
|
||||||
|
std::string output;
|
||||||
|
handler_.Broadcast(static_cast<const char*>(in_out), size, &output, sequence_number_++,
|
||||||
|
GetRank(), root);
|
||||||
|
output.copy(static_cast<char*>(in_out), size);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
|
||||||
|
|
||||||
|
void Print(const std::string& message) override { LOG(CONSOLE) << message; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void Shutdown() override {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
static InMemoryHandler handler_;
|
||||||
|
uint64_t sequence_number_{};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
200
src/collective/in_memory_handler.cc
Normal file
200
src/collective/in_memory_handler.cc
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include "in_memory_handler.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Functor for allreduce.
|
||||||
|
*/
|
||||||
|
class AllreduceFunctor {
|
||||||
|
public:
|
||||||
|
std::string const name{"Allreduce"};
|
||||||
|
|
||||||
|
AllreduceFunctor(DataType dataType, Operation operation)
|
||||||
|
: data_type_(dataType), operation_(operation) {}
|
||||||
|
|
||||||
|
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||||
|
if (buffer->empty()) {
|
||||||
|
// Copy the input if this is the first request.
|
||||||
|
buffer->assign(input, bytes);
|
||||||
|
} else {
|
||||||
|
// Apply the reduce_operation to the input and the buffer.
|
||||||
|
Accumulate(input, bytes / GetTypeSize(data_type_), &buffer->front());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class T>
|
||||||
|
void Accumulate(T* buffer, T const* input, std::size_t size, Operation reduce_operation) const {
|
||||||
|
switch (reduce_operation) {
|
||||||
|
case Operation::kMax:
|
||||||
|
std::transform(buffer, buffer + size, input, buffer,
|
||||||
|
[](T a, T b) { return std::max(a, b); });
|
||||||
|
break;
|
||||||
|
case Operation::kMin:
|
||||||
|
std::transform(buffer, buffer + size, input, buffer,
|
||||||
|
[](T a, T b) { return std::min(a, b); });
|
||||||
|
break;
|
||||||
|
case Operation::kSum:
|
||||||
|
std::transform(buffer, buffer + size, input, buffer, std::plus<T>());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid reduce operation");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Accumulate(char const* input, std::size_t size, char* buffer) const {
|
||||||
|
switch (data_type_) {
|
||||||
|
case DataType::kInt8:
|
||||||
|
Accumulate(reinterpret_cast<std::int8_t*>(buffer),
|
||||||
|
reinterpret_cast<std::int8_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt8:
|
||||||
|
Accumulate(reinterpret_cast<std::uint8_t*>(buffer),
|
||||||
|
reinterpret_cast<std::uint8_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kInt32:
|
||||||
|
Accumulate(reinterpret_cast<std::int32_t*>(buffer),
|
||||||
|
reinterpret_cast<std::int32_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt32:
|
||||||
|
Accumulate(reinterpret_cast<std::uint32_t*>(buffer),
|
||||||
|
reinterpret_cast<std::uint32_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kInt64:
|
||||||
|
Accumulate(reinterpret_cast<std::int64_t*>(buffer),
|
||||||
|
reinterpret_cast<std::int64_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kUInt64:
|
||||||
|
Accumulate(reinterpret_cast<std::uint64_t*>(buffer),
|
||||||
|
reinterpret_cast<std::uint64_t const*>(input), size, operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kFloat:
|
||||||
|
Accumulate(reinterpret_cast<float*>(buffer), reinterpret_cast<float const*>(input), size,
|
||||||
|
operation_);
|
||||||
|
break;
|
||||||
|
case DataType::kDouble:
|
||||||
|
Accumulate(reinterpret_cast<double*>(buffer), reinterpret_cast<double const*>(input), size,
|
||||||
|
operation_);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid data type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DataType data_type_;
|
||||||
|
Operation operation_;
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
 * @brief Functor for broadcast.
 *
 * Only the call made on behalf of the root rank writes to the shared buffer; calls for every
 * other rank leave it untouched.
 */
class BroadcastFunctor {
 public:
  std::string const name{"Broadcast"};

  /**
   * @param rank Rank of the worker issuing the call.
   * @param root Rank whose data is broadcast.
   */
  BroadcastFunctor(int rank, int root) : rank_(rank), root_(root) {}

  /**
   * @brief Copy @p input into @p buffer when this worker is the root.
   * @param input Raw bytes from this worker.
   * @param bytes Number of bytes in @p input.
   * @param buffer Shared buffer that receives the root's data.
   */
  void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
    if (rank_ == root_) {
      // Copy the input if this is the root.
      buffer->assign(input, bytes);
    }
  }

 private:
  int rank_;
  int root_;
};
|
||||||
|
|
||||||
|
void InMemoryHandler::Init(int world_size, int rank) {
|
||||||
|
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
world_size_++;
|
||||||
|
cv_.wait(lock, [this, world_size] { return world_size_ == world_size; });
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InMemoryHandler::Shutdown(uint64_t sequence_number, int rank) {
|
||||||
|
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
||||||
|
received_++;
|
||||||
|
cv_.wait(lock, [this] { return received_ == world_size_; });
|
||||||
|
|
||||||
|
received_ = 0;
|
||||||
|
world_size_ = 0;
|
||||||
|
sequence_number_ = 0;
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||||
|
std::size_t sequence_number, int rank, DataType data_type,
|
||||||
|
Operation op) {
|
||||||
|
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Perform broadcast by running Handle with a BroadcastFunctor.
 * @param input The input buffer.
 * @param bytes Number of bytes in the input buffer.
 * @param output The output buffer.
 * @param sequence_number Call sequence number.
 * @param rank Index of the worker.
 * @param root Rank whose input is copied into the shared buffer.
 */
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
                                std::size_t sequence_number, int rank, int root) {
  Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
}
|
||||||
|
|
||||||
|
template <class HandlerFunctor>
|
||||||
|
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||||
|
std::size_t sequence_number, int rank, HandlerFunctor const& functor) {
|
||||||
|
// Pass through if there is only 1 client.
|
||||||
|
if (world_size_ == 1) {
|
||||||
|
if (input != output->data()) {
|
||||||
|
output->assign(input, bytes);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": waiting for current sequence number";
|
||||||
|
cv_.wait(lock, [this, sequence_number] { return sequence_number_ == sequence_number; });
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": handling request";
|
||||||
|
functor(input, bytes, &buffer_);
|
||||||
|
received_++;
|
||||||
|
|
||||||
|
if (received_ == world_size_) {
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": all requests received";
|
||||||
|
output->assign(buffer_);
|
||||||
|
sent_++;
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": waiting for all clients";
|
||||||
|
cv_.wait(lock, [this] { return received_ == world_size_; });
|
||||||
|
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": sending reply";
|
||||||
|
output->assign(buffer_);
|
||||||
|
sent_++;
|
||||||
|
|
||||||
|
if (sent_ == world_size_) {
|
||||||
|
LOG(INFO) << functor.name << " rank " << rank << ": all replies sent";
|
||||||
|
sent_ = 0;
|
||||||
|
received_ = 0;
|
||||||
|
buffer_.clear();
|
||||||
|
sequence_number_++;
|
||||||
|
lock.unlock();
|
||||||
|
cv_.notify_all();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
106
src/collective/in_memory_handler.h
Normal file
106
src/collective/in_memory_handler.h
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#pragma once
|
||||||
|
#include <condition_variable>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "communicator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Handles collective communication primitives in memory.
|
||||||
|
*
|
||||||
|
* This class is thread safe.
|
||||||
|
*/
|
||||||
|
class InMemoryHandler {
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief Default constructor.
|
||||||
|
*
|
||||||
|
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||||
|
* initialize it collectively.
|
||||||
|
*/
|
||||||
|
InMemoryHandler() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a handler with the given world size.
|
||||||
|
* @param world_size Number of workers.
|
||||||
|
*
|
||||||
|
* This is used when the handler only needs to be initialized once with a known world size.
|
||||||
|
*/
|
||||||
|
explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialize the handler with the world size and rank.
|
||||||
|
* @param world_size Number of workers.
|
||||||
|
* @param rank Index of the worker.
|
||||||
|
*
|
||||||
|
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||||
|
* initialize it collectively.
|
||||||
|
*/
|
||||||
|
void Init(int world_size, int rank);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Shut down the handler.
|
||||||
|
* @param sequence_number Call sequence number.
|
||||||
|
* @param rank Index of the worker.
|
||||||
|
*
|
||||||
|
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||||
|
* shut it down collectively.
|
||||||
|
*/
|
||||||
|
void Shutdown(uint64_t sequence_number, int rank);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform allreduce.
|
||||||
|
* @param input The input buffer.
|
||||||
|
* @param bytes Number of bytes in the input buffer.
|
||||||
|
* @param output The output buffer.
|
||||||
|
* @param sequence_number Call sequence number.
|
||||||
|
* @param rank Index of the worker.
|
||||||
|
* @param data_type Type of the data.
|
||||||
|
* @param op The reduce operation.
|
||||||
|
*/
|
||||||
|
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||||
|
std::size_t sequence_number, int rank, DataType data_type, Operation op);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Perform broadcast.
|
||||||
|
* @param input The input buffer.
|
||||||
|
* @param bytes Number of bytes in the input buffer.
|
||||||
|
* @param output The output buffer.
|
||||||
|
* @param sequence_number Call sequence number.
|
||||||
|
* @param rank Index of the worker.
|
||||||
|
* @param root Index of the worker to broadcast from.
|
||||||
|
*/
|
||||||
|
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||||
|
std::size_t sequence_number, int rank, int root);
|
||||||
|
|
||||||
|
private:
|
||||||
|
/**
|
||||||
|
* @brief Handle a collective communication primitive.
|
||||||
|
* @tparam HandlerFunctor The functor used to perform the specific primitive.
|
||||||
|
* @param input The input buffer.
|
||||||
|
* @param size Size of the input in terms of the data type.
|
||||||
|
* @param output The output buffer.
|
||||||
|
* @param sequence_number Call sequence number.
|
||||||
|
* @param rank Index of the worker.
|
||||||
|
* @param functor The functor instance used to perform the specific primitive.
|
||||||
|
*/
|
||||||
|
template <class HandlerFunctor>
|
||||||
|
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||||
|
int rank, HandlerFunctor const& functor);
|
||||||
|
|
||||||
|
int world_size_{}; /// Number of workers.
|
||||||
|
int received_{}; /// Number of calls received with the current sequence.
|
||||||
|
int sent_{}; /// Number of calls completed with the current sequence.
|
||||||
|
std::string buffer_{}; /// A shared common buffer.
|
||||||
|
uint64_t sequence_number_{}; /// Call sequence number.
|
||||||
|
mutable std::mutex mutex_; /// Lock.
|
||||||
|
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
112
tests/cpp/collective/test_in_memory_communicator.cc
Normal file
112
tests/cpp/collective/test_in_memory_communicator.cc
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2022 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#include <dmlc/parameter.h>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
|
#include "../../../src/collective/in_memory_communicator.h"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
class InMemoryCommunicatorTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
static void VerifyAllreduce(int rank) {
|
||||||
|
InMemoryCommunicator comm{kWorldSize, rank};
|
||||||
|
int buffer[] = {1, 2, 3, 4, 5};
|
||||||
|
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||||
|
int expected[] = {3, 6, 9, 12, 15};
|
||||||
|
for (auto i = 0; i < 5; i++) {
|
||||||
|
EXPECT_EQ(buffer[i], expected[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void VerifyBroadcast(int rank) {
|
||||||
|
InMemoryCommunicator comm{kWorldSize, rank};
|
||||||
|
if (rank == 0) {
|
||||||
|
std::string buffer{"hello"};
|
||||||
|
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||||
|
EXPECT_EQ(buffer, "hello");
|
||||||
|
} else {
|
||||||
|
std::string buffer{" "};
|
||||||
|
comm.Broadcast(&buffer[0], buffer.size(), 0);
|
||||||
|
EXPECT_EQ(buffer, "hello");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
static int const kWorldSize{3};
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeTooSmall) {
  // Expect construction with a world size of zero to raise dmlc::Error. The lambda keeps the
  // brace-initializer commas out of the macro argument list.
  auto const make_with_empty_world = [] {
    InMemoryCommunicator comm{0, 0};
  };
  EXPECT_THROW(make_with_empty_world(), dmlc::Error);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooSmall) {
  // Expect construction with a negative rank to raise dmlc::Error.
  auto const make_with_negative_rank = [] {
    InMemoryCommunicator comm{1, -1};
  };
  EXPECT_THROW(make_with_negative_rank(), dmlc::Error);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankTooBig) {
  // Expect construction with a rank equal to the world size to raise dmlc::Error.
  auto const make_with_rank_out_of_range = [] {
    InMemoryCommunicator comm{1, 1};
  };
  EXPECT_THROW(make_with_rank_out_of_range(), dmlc::Error);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
  // Expect Create() to raise dmlc::Error when the world size is given as a string while the
  // rank is a proper integer.
  auto const create_from_bad_config = [] {
    Json cfg{JsonObject()};
    cfg["in_memory_world_size"] = std::string("1");
    cfg["in_memory_rank"] = Integer(0);
    // Only reached if Create() unexpectedly succeeds; release the communicator it returned.
    auto *created = InMemoryCommunicator::Create(cfg);
    delete created;
  };
  EXPECT_THROW(create_from_bad_config(), dmlc::Error);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, ThrowOnRankNotInteger) {
  // Expect Create() to raise dmlc::Error when the rank is given as a string while the world
  // size is a proper integer.
  auto const create_from_bad_config = [] {
    Json cfg{JsonObject()};
    cfg["in_memory_world_size"] = 1;
    cfg["in_memory_rank"] = std::string("0");
    // Only reached if Create() unexpectedly succeeds; release the communicator it returned.
    auto *created = InMemoryCommunicator::Create(cfg);
    delete created;
  };
  EXPECT_THROW(create_from_bad_config(), dmlc::Error);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, GetWorldSizeAndRank) {
  // A communicator built as rank 0 of a one-worker world should report those same values.
  InMemoryCommunicator single_worker{1, 0};
  EXPECT_EQ(single_worker.GetWorldSize(), 1);
  EXPECT_EQ(single_worker.GetRank(), 0);
}
|
||||||
|
|
||||||
|
TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
  // Even with a single worker the in-memory communicator is expected to report itself as
  // distributed.
  InMemoryCommunicator single_worker{1, 0};
  EXPECT_TRUE(single_worker.IsDistributed());
}
|
||||||
|
|
||||||
|
TEST_F(InMemoryCommunicatorTest, Allreduce) {
  // Run VerifyAllreduce once per rank, each on its own thread, and wait for all of them.
  std::vector<std::thread> threads;
  threads.reserve(kWorldSize);
  for (auto rank = 0; rank < kWorldSize; rank++) {
    // Construct the thread in place instead of moving a temporary std::thread into the vector.
    threads.emplace_back(&InMemoryCommunicatorTest::VerifyAllreduce, rank);
  }
  for (auto &thread : threads) {
    thread.join();
  }
}
|
||||||
|
|
||||||
|
TEST_F(InMemoryCommunicatorTest, Broadcast) {
  // Run VerifyBroadcast once per rank, each on its own thread, and wait for all of them.
  std::vector<std::thread> threads;
  threads.reserve(kWorldSize);
  for (auto rank = 0; rank < kWorldSize; rank++) {
    // Construct the thread in place instead of moving a temporary std::thread into the vector.
    threads.emplace_back(&InMemoryCommunicatorTest::VerifyBroadcast, rank);
  }
  for (auto &thread : threads) {
    thread.join();
  }
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
@ -26,11 +26,6 @@ namespace xgboost {
|
|||||||
|
|
||||||
class FederatedServerTest : public ::testing::Test {
|
class FederatedServerTest : public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
static void VerifyAllgather(int rank, const std::string& server_address) {
|
|
||||||
federated::FederatedClient client{server_address, rank};
|
|
||||||
CheckAllgather(client, rank);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void VerifyAllreduce(int rank, const std::string& server_address) {
|
static void VerifyAllreduce(int rank, const std::string& server_address) {
|
||||||
federated::FederatedClient client{server_address, rank};
|
federated::FederatedClient client{server_address, rank};
|
||||||
CheckAllreduce(client);
|
CheckAllreduce(client);
|
||||||
@ -44,7 +39,6 @@ class FederatedServerTest : public ::testing::Test {
|
|||||||
static void VerifyMixture(int rank, const std::string& server_address) {
|
static void VerifyMixture(int rank, const std::string& server_address) {
|
||||||
federated::FederatedClient client{server_address, rank};
|
federated::FederatedClient client{server_address, rank};
|
||||||
for (auto i = 0; i < 10; i++) {
|
for (auto i = 0; i < 10; i++) {
|
||||||
CheckAllgather(client, rank);
|
|
||||||
CheckAllreduce(client);
|
CheckAllreduce(client);
|
||||||
CheckBroadcast(client, rank);
|
CheckBroadcast(client, rank);
|
||||||
}
|
}
|
||||||
@ -68,11 +62,6 @@ class FederatedServerTest : public ::testing::Test {
|
|||||||
server_thread_->join();
|
server_thread_->join();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void CheckAllgather(federated::FederatedClient& client, int rank) {
|
|
||||||
auto reply = client.Allgather("hello " + std::to_string(rank) + " ");
|
|
||||||
EXPECT_EQ(reply, "hello 0 hello 1 hello 2 ");
|
|
||||||
}
|
|
||||||
|
|
||||||
static void CheckAllreduce(federated::FederatedClient& client) {
|
static void CheckAllreduce(federated::FederatedClient& client) {
|
||||||
int data[] = {1, 2, 3, 4, 5};
|
int data[] = {1, 2, 3, 4, 5};
|
||||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||||
@ -90,7 +79,7 @@ class FederatedServerTest : public ::testing::Test {
|
|||||||
send_buffer = "hello broadcast";
|
send_buffer = "hello broadcast";
|
||||||
}
|
}
|
||||||
auto reply = client.Broadcast(send_buffer, 0);
|
auto reply = client.Broadcast(send_buffer, 0);
|
||||||
EXPECT_EQ(reply, "hello broadcast");
|
EXPECT_EQ(reply, "hello broadcast") << "rank " << rank;
|
||||||
}
|
}
|
||||||
|
|
||||||
static int const kWorldSize{3};
|
static int const kWorldSize{3};
|
||||||
@ -99,16 +88,6 @@ class FederatedServerTest : public ::testing::Test {
|
|||||||
std::unique_ptr<grpc::Server> server_;
|
std::unique_ptr<grpc::Server> server_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(FederatedServerTest, Allgather) {
|
|
||||||
std::vector<std::thread> threads;
|
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
|
||||||
threads.emplace_back(std::thread(&FederatedServerTest::VerifyAllgather, rank, server_address_));
|
|
||||||
}
|
|
||||||
for (auto& thread : threads) {
|
|
||||||
thread.join();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(FederatedServerTest, Allreduce) {
|
TEST_F(FederatedServerTest, Allreduce) {
|
||||||
std::vector<std::thread> threads;
|
std::vector<std::thread> threads;
|
||||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user