Improve allgather functions (#9649)
This commit is contained in:
parent
d1dee4ad99
commit
e164d51c43
@ -7,6 +7,7 @@ package xgboost.federated;
|
||||
|
||||
service Federated {
|
||||
rpc Allgather(AllgatherRequest) returns (AllgatherReply) {}
|
||||
rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {}
|
||||
rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {}
|
||||
rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {}
|
||||
}
|
||||
@ -42,6 +43,17 @@ message AllgatherReply {
|
||||
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
// Request for a variable-length allgather; carries one worker's contribution.
message AllgatherVRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
// Rank of the worker sending this request.
int32 rank = 2;
|
||||
// Raw bytes contributed by this worker.
bytes send_buffer = 3;
|
||||
}
|
||||
|
||||
// Reply to AllgatherVRequest.
message AllgatherVReply {
|
||||
// Bytes produced by the server-side handler for this round.
bytes receive_buffer = 1;
|
||||
}
|
||||
|
||||
message AllreduceRequest {
|
||||
// An incrementing counter that is unique to each round of operations.
|
||||
uint64 sequence_number = 1;
|
||||
|
||||
@ -44,11 +44,11 @@ class FederatedClient {
|
||||
}()},
|
||||
rank_{rank} {}
|
||||
|
||||
std::string Allgather(std::string const &send_buffer) {
|
||||
std::string Allgather(std::string_view send_buffer) {
|
||||
AllgatherRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherReply reply;
|
||||
grpc::ClientContext context;
|
||||
@ -63,6 +63,25 @@ class FederatedClient {
|
||||
}
|
||||
}
|
||||
|
||||
std::string AllgatherV(std::string_view send_buffer) {
|
||||
AllgatherVRequest request;
|
||||
request.set_sequence_number(sequence_number_++);
|
||||
request.set_rank(rank_);
|
||||
request.set_send_buffer(send_buffer.data(), send_buffer.size());
|
||||
|
||||
AllgatherVReply reply;
|
||||
grpc::ClientContext context;
|
||||
context.set_wait_for_ready(true);
|
||||
grpc::Status status = stub_->AllgatherV(&context, request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
return reply.receive_buffer();
|
||||
} else {
|
||||
std::cout << status.error_code() << ": " << status.error_message() << '\n';
|
||||
throw std::runtime_error("AllgatherV RPC failed");
|
||||
}
|
||||
}
|
||||
|
||||
std::string Allreduce(std::string const &send_buffer, DataType data_type,
|
||||
ReduceOperation reduce_operation) {
|
||||
AllreduceRequest request;
|
||||
|
||||
@ -125,14 +125,19 @@ class FederatedCommunicator : public Communicator {
|
||||
[[nodiscard]] bool IsFederated() const override { return true; }
|
||||
|
||||
/**
|
||||
* \brief Perform in-place allgather.
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param size Number of bytes to be gathered.
|
||||
* \brief Perform allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
std::string const send_buffer(reinterpret_cast<char const *>(send_receive_buffer), size);
|
||||
auto const received = client_->Allgather(send_buffer);
|
||||
received.copy(reinterpret_cast<char *>(send_receive_buffer), size);
|
||||
std::string AllGather(std::string_view input) override {
|
||||
return client_->Allgather(input);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Perform variable-length allgather.
|
||||
* \param input Buffer for sending data.
|
||||
*/
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
return client_->AllgatherV(input);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -19,6 +19,13 @@ grpc::Status FederatedService::Allgather(grpc::ServerContext*, AllgatherRequest
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
// Server-side AllgatherV endpoint: forwards the request payload, sequence number and rank
// to the in-memory handler, which writes into the reply's receive buffer.
grpc::Status FederatedService::AllgatherV(grpc::ServerContext*, AllgatherVRequest const* request,
                                          AllgatherVReply* reply) {
  auto const& payload = request->send_buffer();
  auto* const out = reply->mutable_receive_buffer();
  handler_.AllgatherV(payload.data(), payload.size(), out, request->sequence_number(),
                      request->rank());
  return grpc::Status::OK;
}
|
||||
|
||||
grpc::Status FederatedService::Allreduce(grpc::ServerContext*, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) {
|
||||
handler_.Allreduce(request->send_buffer().data(), request->send_buffer().size(),
|
||||
@ -36,8 +43,8 @@ grpc::Status FederatedService::Broadcast(grpc::ServerContext*, BroadcastRequest
|
||||
return grpc::Status::OK;
|
||||
}
|
||||
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file) {
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
|
||||
@ -59,7 +66,7 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
|
||||
server->Wait();
|
||||
}
|
||||
|
||||
void RunInsecureServer(int port, int world_size) {
|
||||
void RunInsecureServer(int port, std::size_t world_size) {
|
||||
std::string const server_address = "0.0.0.0:" + std::to_string(port);
|
||||
FederatedService service{world_size};
|
||||
|
||||
|
||||
@ -12,11 +12,14 @@ namespace federated {
|
||||
|
||||
class FederatedService final : public Federated::Service {
|
||||
public:
|
||||
explicit FederatedService(int const world_size) : handler_{world_size} {}
|
||||
explicit FederatedService(std::size_t const world_size) : handler_{world_size} {}
|
||||
|
||||
grpc::Status Allgather(grpc::ServerContext* context, AllgatherRequest const* request,
|
||||
AllgatherReply* reply) override;
|
||||
|
||||
grpc::Status AllgatherV(grpc::ServerContext* context, AllgatherVRequest const* request,
|
||||
AllgatherVReply* reply) override;
|
||||
|
||||
grpc::Status Allreduce(grpc::ServerContext* context, AllreduceRequest const* request,
|
||||
AllreduceReply* reply) override;
|
||||
|
||||
@ -27,10 +30,10 @@ class FederatedService final : public Federated::Service {
|
||||
xgboost::collective::InMemoryHandler handler_;
|
||||
};
|
||||
|
||||
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
|
||||
char const* client_cert_file);
|
||||
void RunServer(int port, std::size_t world_size, char const* server_key_file,
|
||||
char const* server_cert_file, char const* client_cert_file);
|
||||
|
||||
void RunInsecureServer(int port, int world_size);
|
||||
void RunInsecureServer(int port, std::size_t world_size);
|
||||
|
||||
} // namespace federated
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1724,7 +1724,7 @@ XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
|
||||
XGB_DLL int XGBRunFederatedServer(int port, std::size_t world_size, char const *server_key_path,
|
||||
char const *server_cert_path, char const *client_cert_path) {
|
||||
API_BEGIN();
|
||||
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
|
||||
@ -1732,7 +1732,7 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k
|
||||
}
|
||||
|
||||
// Run a server without SSL for local testing.
|
||||
XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) {
|
||||
XGB_DLL int XGBRunInsecureFederatedServer(int port, std::size_t world_size) {
|
||||
API_BEGIN();
|
||||
federated::RunInsecureServer(port, world_size);
|
||||
API_END();
|
||||
|
||||
@ -57,9 +57,7 @@ namespace collective {
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const& config) {
|
||||
Communicator::Init(config);
|
||||
}
|
||||
inline void Init(Json const &config) { Communicator::Init(config); }
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
@ -141,17 +139,89 @@ inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers a single value all processes and distributes the result to all processes.
|
||||
*
|
||||
* @param input The single value.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(T const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(&input), sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size, and input data has been sliced into the
|
||||
* corresponding position.
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
inline void Allgather(void *send_receive_buffer, std::size_t size) {
|
||||
Communicator::Get()->AllGather(send_receive_buffer, size);
|
||||
template <typename T>
|
||||
inline std::vector<T> Allgather(std::vector<T> const &input) {
|
||||
if (input.empty()) {
|
||||
return input;
|
||||
}
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGather(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
template <typename T>
|
||||
inline std::vector<T> AllgatherV(std::vector<T> const &input) {
|
||||
std::string_view str_input{reinterpret_cast<char const *>(input.data()),
|
||||
input.size() * sizeof(T)};
|
||||
auto const output = Communicator::Get()->AllGatherV(str_input);
|
||||
CHECK_EQ(output.size() % sizeof(T), 0);
|
||||
std::vector<T> result(output.size() / sizeof(T));
|
||||
if (!output.empty()) {
|
||||
std::memcpy(reinterpret_cast<void *>(result.data()), output.data(), output.size());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length strings from all processes and distributes them to all processes.
|
||||
* @param input Variable-length list of variable-length strings.
|
||||
*/
|
||||
inline std::vector<std::string> AllgatherStrings(std::vector<std::string> const &input) {
|
||||
std::size_t total_size{0};
|
||||
for (auto const &s : input) {
|
||||
total_size += s.length() + 1; // +1 for null-terminators
|
||||
}
|
||||
std::string flat_string;
|
||||
flat_string.reserve(total_size);
|
||||
for (auto const &s : input) {
|
||||
flat_string.append(s);
|
||||
flat_string.push_back('\0'); // Append a null-terminator after each string
|
||||
}
|
||||
|
||||
auto const output = Communicator::Get()->AllGatherV(flat_string);
|
||||
|
||||
std::vector<std::string> result;
|
||||
std::size_t start_index = 0;
|
||||
// Iterate through the output, find each null-terminated substring.
|
||||
for (std::size_t i = 0; i < output.size(); i++) {
|
||||
if (output[i] == '\0') {
|
||||
// Construct a std::string from the char* substring
|
||||
result.emplace_back(&output[start_index]);
|
||||
// Move to the next substring
|
||||
start_index = i + 1;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*!
|
||||
@ -226,7 +296,7 @@ inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct AllgatherVResult {
|
||||
struct SpecialAllgatherVResult {
|
||||
std::vector<std::size_t> offsets;
|
||||
std::vector<std::size_t> sizes;
|
||||
std::vector<T> result;
|
||||
@ -241,14 +311,10 @@ struct AllgatherVResult {
|
||||
* @param sizes Sizes of each input.
|
||||
*/
|
||||
template <typename T>
|
||||
inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
|
||||
std::vector<std::size_t> const &sizes) {
|
||||
auto num_inputs = sizes.size();
|
||||
|
||||
inline SpecialAllgatherVResult<T> SpecialAllgatherV(std::vector<T> const &inputs,
|
||||
std::vector<std::size_t> const &sizes) {
|
||||
// Gather the sizes across all workers.
|
||||
std::vector<std::size_t> all_sizes(num_inputs * GetWorldSize());
|
||||
std::copy_n(sizes.cbegin(), sizes.size(), all_sizes.begin() + num_inputs * GetRank());
|
||||
collective::Allgather(all_sizes.data(), all_sizes.size() * sizeof(std::size_t));
|
||||
auto const all_sizes = Allgather(sizes);
|
||||
|
||||
// Calculate input offsets (std::exclusive_scan).
|
||||
std::vector<std::size_t> offsets(all_sizes.size());
|
||||
@ -257,11 +323,7 @@ inline AllgatherVResult<T> AllgatherV(std::vector<T> const &inputs,
|
||||
}
|
||||
|
||||
// Gather all the inputs.
|
||||
auto total_input_size = offsets.back() + all_sizes.back();
|
||||
std::vector<T> all_inputs(total_input_size);
|
||||
std::copy_n(inputs.cbegin(), inputs.size(), all_inputs.begin() + offsets[num_inputs * GetRank()]);
|
||||
// We cannot use allgather here, since each worker might have a different size.
|
||||
Allreduce<Operation::kMax>(all_inputs.data(), all_inputs.size());
|
||||
auto const all_inputs = AllgatherV(inputs);
|
||||
|
||||
return {offsets, all_sizes, all_inputs};
|
||||
}
|
||||
|
||||
@ -125,13 +125,17 @@ class Communicator {
|
||||
/**
|
||||
* @brief Gathers data from all processes and distributes it to all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size, and input data has been sliced into the
|
||||
* corresponding position.
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual void AllGather(void *send_receive_buffer, std::size_t size) = 0;
|
||||
virtual std::string AllGather(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gathers variable-length data from all processes and distributes it to all processes.
|
||||
* @param input Buffer storing the data.
|
||||
*/
|
||||
virtual std::string AllGatherV(std::string_view input) = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
|
||||
@ -40,12 +40,10 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
cudaMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
|
||||
host_buffer_.resize(send_size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_buffer, send_size, cudaMemcpyDefault));
|
||||
auto const output = Allgather(host_buffer_);
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer, output.data(), output.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
|
||||
@ -60,11 +60,16 @@ class InMemoryCommunicator : public Communicator {
|
||||
bool IsDistributed() const override { return true; }
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void* in_out, std::size_t size) override {
|
||||
std::string AllGather(std::string_view input) override {
|
||||
std::string output;
|
||||
handler_.Allgather(static_cast<const char*>(in_out), size, &output, sequence_number_++,
|
||||
GetRank());
|
||||
output.copy(static_cast<char*>(in_out), size);
|
||||
handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank());
|
||||
return output;
|
||||
}
|
||||
|
||||
// Variable-length allgather through the shared in-memory handler; consumes one sequence number.
std::string AllGatherV(std::string_view input) override {
  std::string gathered;
  auto const seq = sequence_number_++;
  handler_.AllgatherV(input.data(), input.size(), &gathered, seq, GetRank());
  return gathered;
}
|
||||
|
||||
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
|
||||
|
||||
@ -16,23 +16,49 @@ class AllgatherFunctor {
|
||||
public:
|
||||
std::string const name{"Allgather"};
|
||||
|
||||
AllgatherFunctor(int world_size, int rank) : world_size_{world_size}, rank_{rank} {}
|
||||
AllgatherFunctor(std::size_t world_size, std::size_t rank)
|
||||
: world_size_{world_size}, rank_{rank} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (buffer->empty()) {
|
||||
// Copy the input if this is the first request.
|
||||
buffer->assign(input, bytes);
|
||||
} else {
|
||||
// Splice the input into the common buffer.
|
||||
auto const per_rank = bytes / world_size_;
|
||||
auto const index = rank_ * per_rank;
|
||||
buffer->replace(index, per_rank, input + index, per_rank);
|
||||
// Resize the buffer if this is the first request.
|
||||
buffer->resize(bytes * world_size_);
|
||||
}
|
||||
|
||||
// Splice the input into the common buffer.
|
||||
buffer->replace(rank_ * bytes, bytes, input, bytes);
|
||||
}
|
||||
|
||||
private:
|
||||
std::size_t world_size_;
|
||||
std::size_t rank_;
|
||||
};
|
||||
|
||||
/**
 * @brief Functor for variable-length allgather.
 *
 * Each call records a view of the caller's input under its rank in a shared map. Once the map
 * holds one entry per worker, the entries are appended to the output buffer in ascending rank
 * order and the map is cleared for the next round.
 *
 * The map stores non-owning views, so every input must stay alive until the last rank has
 * called in.
 */
class AllgatherVFunctor {
 public:
  std::string const name{"AllgatherV"};

  /**
   * @param world_size Number of workers expected to contribute.
   * @param rank Rank of the worker making this call.
   * @param data Shared map of per-rank views; not owned.
   */
  AllgatherVFunctor(std::size_t world_size, std::size_t rank,
                    std::map<std::size_t, std::string_view>* data)
      : world_size_{world_size}, rank_{rank}, data_{data} {}

  /**
   * @param input Pointer to this rank's bytes.
   * @param bytes Number of bytes at input.
   * @param buffer Output buffer; appended to only when the last rank arrives.
   */
  void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
    data_->emplace(rank_, std::string_view{input, bytes});
    if (data_->size() != world_size_) {
      return;  // Still waiting for other ranks.
    }

    // Grow the buffer once instead of on each append.
    std::size_t total{0};
    for (auto const& kv : *data_) {
      total += kv.second.size();
    }
    buffer->reserve(buffer->size() + total);

    // std::map iterates in key order, i.e. ascending rank.
    for (auto const& kv : *data_) {
      buffer->append(kv.second);
    }
    data_->clear();
  }

 private:
  std::size_t world_size_;
  std::size_t rank_;
  std::map<std::size_t, std::string_view>* data_;
};
|
||||
|
||||
/**
|
||||
@ -154,7 +180,7 @@ class BroadcastFunctor {
|
||||
public:
|
||||
std::string const name{"Broadcast"};
|
||||
|
||||
BroadcastFunctor(int rank, int root) : rank_{rank}, root_{root} {}
|
||||
BroadcastFunctor(std::size_t rank, std::size_t root) : rank_{rank}, root_{root} {}
|
||||
|
||||
void operator()(char const* input, std::size_t bytes, std::string* buffer) const {
|
||||
if (rank_ == root_) {
|
||||
@ -164,11 +190,11 @@ class BroadcastFunctor {
|
||||
}
|
||||
|
||||
private:
|
||||
int rank_;
|
||||
int root_;
|
||||
std::size_t rank_;
|
||||
std::size_t root_;
|
||||
};
|
||||
|
||||
void InMemoryHandler::Init(int world_size, int) {
|
||||
void InMemoryHandler::Init(std::size_t world_size, std::size_t) {
|
||||
CHECK(world_size_ < world_size) << "In memory handler already initialized.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@ -178,7 +204,7 @@ void InMemoryHandler::Init(int world_size, int) {
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
|
||||
void InMemoryHandler::Shutdown(uint64_t sequence_number, std::size_t) {
|
||||
CHECK(world_size_ > 0) << "In memory handler already shutdown.";
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
@ -194,24 +220,30 @@ void InMemoryHandler::Shutdown(uint64_t sequence_number, int) {
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank) {
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherFunctor{world_size_, rank});
|
||||
}
|
||||
|
||||
void InMemoryHandler::AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllgatherVFunctor{world_size_, rank, &aux_});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type,
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type,
|
||||
Operation op) {
|
||||
Handle(input, bytes, output, sequence_number, rank, AllreduceFunctor{data_type, op});
|
||||
}
|
||||
|
||||
void InMemoryHandler::Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root) {
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root) {
|
||||
Handle(input, bytes, output, sequence_number, rank, BroadcastFunctor{rank, root});
|
||||
}
|
||||
|
||||
template <class HandlerFunctor>
|
||||
void InMemoryHandler::Handle(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, HandlerFunctor const& functor) {
|
||||
std::size_t sequence_number, std::size_t rank,
|
||||
HandlerFunctor const& functor) {
|
||||
// Pass through if there is only 1 client.
|
||||
if (world_size_ == 1) {
|
||||
if (input != output->data()) {
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
#include <condition_variable>
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
@ -31,7 +32,7 @@ class InMemoryHandler {
|
||||
*
|
||||
* This is used when the handler only needs to be initialized once with a known world size.
|
||||
*/
|
||||
explicit InMemoryHandler(int worldSize) : world_size_{worldSize} {}
|
||||
explicit InMemoryHandler(std::size_t worldSize) : world_size_{worldSize} {}
|
||||
|
||||
/**
|
||||
* @brief Initialize the handler with the world size and rank.
|
||||
@ -41,7 +42,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* initialize it collectively.
|
||||
*/
|
||||
void Init(int world_size, int rank);
|
||||
void Init(std::size_t world_size, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Shut down the handler.
|
||||
@ -51,7 +52,7 @@ class InMemoryHandler {
|
||||
* This is used when multiple objects/threads are accessing the same handler and need to
|
||||
* shut it down collectively.
|
||||
*/
|
||||
void Shutdown(uint64_t sequence_number, int rank);
|
||||
void Shutdown(uint64_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allgather.
|
||||
@ -62,7 +63,18 @@ class InMemoryHandler {
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void Allgather(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank);
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform variable-length allgather.
|
||||
* @param input The input buffer.
|
||||
* @param bytes Number of bytes in the input buffer.
|
||||
* @param output The output buffer.
|
||||
* @param sequence_number Call sequence number.
|
||||
* @param rank Index of the worker.
|
||||
*/
|
||||
void AllgatherV(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, std::size_t rank);
|
||||
|
||||
/**
|
||||
* @brief Perform allreduce.
|
||||
@ -75,7 +87,7 @@ class InMemoryHandler {
|
||||
* @param op The reduce operation.
|
||||
*/
|
||||
void Allreduce(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, DataType data_type, Operation op);
|
||||
std::size_t sequence_number, std::size_t rank, DataType data_type, Operation op);
|
||||
|
||||
/**
|
||||
* @brief Perform broadcast.
|
||||
@ -87,7 +99,7 @@ class InMemoryHandler {
|
||||
* @param root Index of the worker to broadcast from.
|
||||
*/
|
||||
void Broadcast(char const* input, std::size_t bytes, std::string* output,
|
||||
std::size_t sequence_number, int rank, int root);
|
||||
std::size_t sequence_number, std::size_t rank, std::size_t root);
|
||||
|
||||
private:
|
||||
/**
|
||||
@ -102,15 +114,16 @@ class InMemoryHandler {
|
||||
*/
|
||||
template <class HandlerFunctor>
|
||||
void Handle(char const* input, std::size_t size, std::string* output, std::size_t sequence_number,
|
||||
int rank, HandlerFunctor const& functor);
|
||||
std::size_t rank, HandlerFunctor const& functor);
|
||||
|
||||
int world_size_{}; /// Number of workers.
|
||||
int received_{}; /// Number of calls received with the current sequence.
|
||||
int sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
std::size_t world_size_{}; /// Number of workers.
|
||||
std::size_t received_{}; /// Number of calls received with the current sequence.
|
||||
std::size_t sent_{}; /// Number of calls completed with the current sequence.
|
||||
std::string buffer_{}; /// A shared common buffer.
|
||||
std::map<std::size_t, std::string_view> aux_{}; /// A shared auxiliary map.
|
||||
uint64_t sequence_number_{}; /// Call sequence number.
|
||||
mutable std::mutex mutex_; /// Lock.
|
||||
mutable std::condition_variable cv_; /// Conditional variable to wait on.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
|
||||
@ -17,10 +17,11 @@ class NoOpCommunicator : public Communicator {
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
bool IsFederated() const override { return false; }
|
||||
void AllGather(void *, std::size_t) override {}
|
||||
std::string AllGather(std::string_view) override { return {}; }
|
||||
std::string AllGatherV(std::string_view) override { return {}; }
|
||||
void AllReduce(void *, std::size_t, DataType, Operation) override {}
|
||||
void Broadcast(void *, std::size_t, int) override {}
|
||||
std::string GetProcessorName() override { return ""; }
|
||||
std::string GetProcessorName() override { return {}; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.h"
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
@ -55,10 +56,27 @@ class RabitCommunicator : public Communicator {
|
||||
|
||||
bool IsFederated() const override { return false; }
|
||||
|
||||
void AllGather(void *send_receive_buffer, std::size_t size) override {
|
||||
auto const per_rank = size / GetWorldSize();
|
||||
std::string AllGather(std::string_view input) override {
|
||||
auto const per_rank = input.size();
|
||||
auto const total_size = per_rank * GetWorldSize();
|
||||
auto const index = per_rank * GetRank();
|
||||
rabit::Allgather(static_cast<char *>(send_receive_buffer), size, index, per_rank, per_rank);
|
||||
std::string result(total_size, '\0');
|
||||
rabit::Allgather(result.data(), total_size, index, per_rank, per_rank);
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string AllGatherV(std::string_view input) override {
|
||||
auto const size_node_slice = input.size();
|
||||
auto const all_sizes = collective::Allgather(size_node_slice);
|
||||
auto const total_size = std::accumulate(all_sizes.cbegin(), all_sizes.cend(), 0ul);
|
||||
auto const begin_index =
|
||||
std::accumulate(all_sizes.cbegin(), all_sizes.cbegin() + GetRank(), 0ul);
|
||||
auto const size_prev_slice = GetRank() == 0 ? 0 : all_sizes[GetRank() - 1];
|
||||
|
||||
std::string result(total_size, '\0');
|
||||
result.replace(begin_index, size_node_slice, input);
|
||||
rabit::Allgather(result.data(), total_size, begin_index, size_node_slice, size_prev_slice);
|
||||
return result;
|
||||
}
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
|
||||
@ -76,10 +76,8 @@ DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) {
|
||||
|
||||
void SimpleDMatrix::ReindexFeatures(Context const* ctx) {
|
||||
if (info_.IsColumnSplit()) {
|
||||
std::vector<uint64_t> buffer(collective::GetWorldSize());
|
||||
buffer[collective::GetRank()] = info_.num_col_;
|
||||
collective::Allgather(buffer.data(), buffer.size() * sizeof(uint64_t));
|
||||
auto offset = std::accumulate(buffer.cbegin(), buffer.cbegin() + collective::GetRank(), 0ul);
|
||||
auto const cols = collective::Allgather(info_.num_col_);
|
||||
auto const offset = std::accumulate(cols.cbegin(), cols.cbegin() + collective::GetRank(), 0ul);
|
||||
if (offset == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -292,20 +292,19 @@ class HistEvaluator {
|
||||
*/
|
||||
std::vector<CPUExpandEntry> Allgather(std::vector<CPUExpandEntry> const &entries) {
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const num_entries = entries.size();
|
||||
|
||||
// First, gather all the primitive fields.
|
||||
std::vector<CPUExpandEntry> all_entries(num_entries * world);
|
||||
std::vector<CPUExpandEntry> local_entries(num_entries);
|
||||
std::vector<uint32_t> cat_bits;
|
||||
std::vector<std::size_t> cat_bits_sizes;
|
||||
for (std::size_t i = 0; i < num_entries; i++) {
|
||||
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
|
||||
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes);
|
||||
}
|
||||
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(CPUExpandEntry));
|
||||
auto all_entries = collective::Allgather(local_entries);
|
||||
|
||||
// Gather all the cat_bits.
|
||||
auto gathered = collective::AllgatherV(cat_bits, cat_bits_sizes);
|
||||
auto gathered = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
|
||||
|
||||
common::ParallelFor(num_entries * world, ctx_->Threads(), [&] (auto i) {
|
||||
// Copy the cat_bits back into all expand entries.
|
||||
@ -579,28 +578,24 @@ class HistMultiEvaluator {
|
||||
*/
|
||||
std::vector<MultiExpandEntry> Allgather(std::vector<MultiExpandEntry> const &entries) {
|
||||
auto const world = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
auto const num_entries = entries.size();
|
||||
|
||||
// First, gather all the primitive fields.
|
||||
std::vector<MultiExpandEntry> all_entries(num_entries * world);
|
||||
std::vector<MultiExpandEntry> local_entries(num_entries);
|
||||
std::vector<uint32_t> cat_bits;
|
||||
std::vector<std::size_t> cat_bits_sizes;
|
||||
std::vector<GradientPairPrecise> gradients;
|
||||
for (std::size_t i = 0; i < num_entries; i++) {
|
||||
all_entries[num_entries * rank + i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes,
|
||||
&gradients);
|
||||
local_entries[i].CopyAndCollect(entries[i], &cat_bits, &cat_bits_sizes, &gradients);
|
||||
}
|
||||
collective::Allgather(all_entries.data(), all_entries.size() * sizeof(MultiExpandEntry));
|
||||
auto all_entries = collective::Allgather(local_entries);
|
||||
|
||||
// Gather all the cat_bits.
|
||||
auto gathered_cat_bits = collective::AllgatherV(cat_bits, cat_bits_sizes);
|
||||
auto gathered_cat_bits = collective::SpecialAllgatherV(cat_bits, cat_bits_sizes);
|
||||
|
||||
// Gather all the gradients.
|
||||
auto const num_gradients = gradients.size();
|
||||
std::vector<GradientPairPrecise> all_gradients(num_gradients * world);
|
||||
std::copy_n(gradients.cbegin(), num_gradients, all_gradients.begin() + num_gradients * rank);
|
||||
collective::Allgather(all_gradients.data(), all_gradients.size() * sizeof(GradientPairPrecise));
|
||||
auto const all_gradients = collective::Allgather(gradients);
|
||||
|
||||
auto const total_entries = num_entries * world;
|
||||
auto const gradients_per_entry = num_gradients / num_entries;
|
||||
|
||||
@ -29,6 +29,11 @@ class InMemoryCommunicatorTest : public ::testing::Test {
|
||||
VerifyAllgather(comm, rank);
|
||||
}
|
||||
|
||||
static void AllgatherV(int rank) {
|
||||
InMemoryCommunicator comm{kWorldSize, rank};
|
||||
VerifyAllgatherV(comm, rank);
|
||||
}
|
||||
|
||||
static void AllreduceMax(int rank) {
|
||||
InMemoryCommunicator comm{kWorldSize, rank};
|
||||
VerifyAllreduceMax(comm, rank);
|
||||
@ -80,14 +85,19 @@ class InMemoryCommunicatorTest : public ::testing::Test {
|
||||
|
||||
protected:
|
||||
// Checks AllGather on the in-memory communicator through the string-based interface.
//
// Each rank contributes the single character '0' + rank. The gathered output is then
// expected to hold '0' + i at position i for every i in [0, kWorldSize).
//
// comm: communicator under test.
// rank: rank of the calling worker; selects the character this worker sends.
static void VerifyAllgather(InMemoryCommunicator &comm, int rank) {
  std::string input{static_cast<char>('0' + rank)};
  auto output = comm.AllGather(input);
  for (auto i = 0; i < kWorldSize; i++) {
    // Compare as char so both sides of the expectation have the same type.
    EXPECT_EQ(output[i], static_cast<char>('0' + i));
  }
}
|
||||
|
||||
// Checks AllGatherV on the in-memory communicator with inputs of different lengths.
//
// The worker sends the entry of {"a", "bb", "ccc"} selected by its rank; the test
// expects the result "abbccc", i.e. the pieces joined in rank order.
static void VerifyAllgatherV(InMemoryCommunicator &comm, int rank) {
  std::vector<std::string_view> pieces{"a", "bb", "ccc"};
  auto gathered = comm.AllGatherV(pieces[rank]);
  EXPECT_EQ(gathered, "abbccc");
}
|
||||
|
||||
static void VerifyAllreduceMax(InMemoryCommunicator &comm, int rank) {
|
||||
int buffer[] = {1 + rank, 2 + rank, 3 + rank, 4 + rank, 5 + rank};
|
||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kMax);
|
||||
@ -205,6 +215,8 @@ TEST(InMemoryCommunicatorSimpleTest, IsDistributed) {
|
||||
|
||||
// Runs the Allgather per-rank helper through Verify() (defined outside this view).
TEST_F(InMemoryCommunicatorTest, Allgather) {
  Verify(&Allgather);
}
|
||||
|
||||
// Runs the AllgatherV per-rank helper through Verify() (defined outside this view).
TEST_F(InMemoryCommunicatorTest, AllgatherV) {
  Verify(&AllgatherV);
}
|
||||
|
||||
// Runs the AllreduceMax per-rank helper through Verify() (defined outside this view).
TEST_F(InMemoryCommunicatorTest, AllreduceMax) {
  Verify(&AllreduceMax);
}
|
||||
|
||||
// Runs the AllreduceMin per-rank helper through Verify() (defined outside this view).
TEST_F(InMemoryCommunicatorTest, AllreduceMin) {
  Verify(&AllreduceMin);
}
|
||||
|
||||
@ -23,7 +23,7 @@ class ServerForTest {
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
|
||||
public:
|
||||
explicit ServerForTest(std::int32_t world_size) {
|
||||
explicit ServerForTest(std::size_t world_size) {
|
||||
server_thread_.reset(new std::thread([this, world_size] {
|
||||
grpc::ServerBuilder builder;
|
||||
xgboost::federated::FederatedService service{world_size};
|
||||
|
||||
@ -19,6 +19,11 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
|
||||
CheckAllgather(comm, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllgatherV(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckAllgatherV(comm, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllreduce(int rank, const std::string &server_address) {
|
||||
FederatedCommunicator comm{kWorldSize, rank, server_address};
|
||||
CheckAllreduce(comm);
|
||||
@ -31,14 +36,19 @@ class FederatedCommunicatorTest : public BaseFederatedTest {
|
||||
|
||||
protected:
|
||||
// Checks AllGather on the federated communicator through the string-based interface.
//
// Each rank contributes the single character '0' + rank. The gathered output is then
// expected to hold '0' + i at position i for every i in [0, kWorldSize).
//
// comm: communicator under test.
// rank: rank of the calling worker; selects the character this worker sends.
static void CheckAllgather(FederatedCommunicator &comm, int rank) {
  std::string input{static_cast<char>('0' + rank)};
  auto output = comm.AllGather(input);
  for (auto i = 0; i < kWorldSize; i++) {
    // Compare as char so both sides of the expectation have the same type.
    EXPECT_EQ(output[i], static_cast<char>('0' + i));
  }
}
|
||||
|
||||
// Checks AllGatherV on the federated communicator with inputs of different lengths.
//
// The worker sends the entry of {"Federated", " Learning!!!"} selected by its rank;
// the test expects the pieces joined in rank order.
static void CheckAllgatherV(FederatedCommunicator &comm, int rank) {
  std::vector<std::string_view> pieces{"Federated", " Learning!!!"};
  auto gathered = comm.AllGatherV(pieces[rank]);
  EXPECT_EQ(gathered, "Federated Learning!!!");
}
|
||||
|
||||
static void CheckAllreduce(FederatedCommunicator &comm) {
|
||||
int buffer[] = {1, 2, 3, 4, 5};
|
||||
comm.AllReduce(buffer, sizeof(buffer) / sizeof(buffer[0]), DataType::kInt32, Operation::kSum);
|
||||
@ -119,6 +129,16 @@ TEST_F(FederatedCommunicatorTest, Allgather) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the AllgatherV check with one thread per rank against the fixture's server.
TEST_F(FederatedCommunicatorTest, AllgatherV) {
  std::vector<std::thread> workers;
  // Start one worker per rank; each receives its rank and the server address.
  for (auto r = 0; r < kWorldSize; r++) {
    workers.emplace_back(&FederatedCommunicatorTest::VerifyAllgatherV, r, server_->Address());
  }
  // Block until every rank has finished.
  for (auto &worker : workers) {
    worker.join();
  }
}
|
||||
|
||||
TEST_F(FederatedCommunicatorTest, Allreduce) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
|
||||
@ -18,6 +18,11 @@ class FederatedServerTest : public BaseFederatedTest {
|
||||
CheckAllgather(client, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllgatherV(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckAllgatherV(client, rank);
|
||||
}
|
||||
|
||||
static void VerifyAllreduce(int rank, const std::string& server_address) {
|
||||
federated::FederatedClient client{server_address, rank};
|
||||
CheckAllreduce(client);
|
||||
@ -39,8 +44,7 @@ class FederatedServerTest : public BaseFederatedTest {
|
||||
|
||||
protected:
|
||||
static void CheckAllgather(federated::FederatedClient& client, int rank) {
|
||||
int data[kWorldSize] = {0, 0};
|
||||
data[rank] = rank;
|
||||
int data[] = {rank};
|
||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||
auto reply = client.Allgather(send_buffer);
|
||||
auto const* result = reinterpret_cast<int const*>(reply.data());
|
||||
@ -49,6 +53,12 @@ class FederatedServerTest : public BaseFederatedTest {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks FederatedClient::AllgatherV with inputs of different lengths.
//
// The client sends the entry of {"Hello,", " World!"} selected by its rank; the
// test expects the pieces joined in rank order.
static void CheckAllgatherV(federated::FederatedClient& client, int rank) {
  std::vector<std::string_view> pieces{"Hello,", " World!"};
  auto gathered = client.AllgatherV(pieces[rank]);
  EXPECT_EQ(gathered, "Hello, World!");
}
|
||||
|
||||
static void CheckAllreduce(federated::FederatedClient& client) {
|
||||
int data[] = {1, 2, 3, 4, 5};
|
||||
std::string send_buffer(reinterpret_cast<char const*>(data), sizeof(data));
|
||||
@ -80,6 +90,16 @@ TEST_F(FederatedServerTest, Allgather) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs the AllgatherV check with one client thread per rank against the fixture's server.
TEST_F(FederatedServerTest, AllgatherV) {
  std::vector<std::thread> workers;
  // Start one worker per rank; each receives its rank and the server address.
  for (auto r = 0; r < kWorldSize; r++) {
    workers.emplace_back(&FederatedServerTest::VerifyAllgatherV, r, server_->Address());
  }
  // Block until every rank has finished.
  for (auto& worker : workers) {
    worker.join();
  }
}
|
||||
|
||||
TEST_F(FederatedServerTest, Allreduce) {
|
||||
std::vector<std::thread> threads;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user