xgboost/src/collective/in_memory_communicator.h
Jiaming Yuan a5a58102e5
Revamp the rabit implementation. (#10112)
This PR replaces the original RABIT implementation with a new one, which has already been partially merged into XGBoost. The new one features:
- Federated learning for both CPU and GPU.
- NCCL.
- More data types.
- A unified interface for all the underlying implementations.
- Improved timeout handling for both tracker and workers.
- Exhaustive tests with metrics (fixed a couple of bugs along the way).
- A reusable tracker for Python and JVM packages.
2024-05-20 11:56:23 +08:00

104 lines
3.1 KiB
C++

/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>

#include <cstdint>
#include <cstdlib>
#include <stdexcept>
#include <string>
#include <string_view>

#include "../c_api/c_api_utils.h"
#include "in_memory_handler.h"
namespace xgboost {
namespace collective {
/**
* An in-memory communicator, useful for testing.
*/
class InMemoryCommunicator {
public:
/**
* @brief Create a new communicator based on JSON configuration.
* @param config JSON configuration.
* @return Communicator as specified by the JSON configuration.
*/
static InMemoryCommunicator* Create(Json const& config) {
int world_size{0};
int rank{-1};
// Parse environment variables first.
auto* value = getenv("IN_MEMORY_WORLD_SIZE");
if (value != nullptr) {
world_size = std::stoi(value);
}
value = getenv("IN_MEMORY_RANK");
if (value != nullptr) {
rank = std::stoi(value);
}
// Runtime configuration overrides, optional as users can specify them as env vars.
world_size = static_cast<int>(OptionalArg<Integer>(config, "in_memory_world_size",
static_cast<Integer::Int>(world_size)));
rank = static_cast<int>(
OptionalArg<Integer>(config, "in_memory_rank", static_cast<Integer::Int>(rank)));
if (world_size == 0) {
LOG(FATAL) << "Federated world size must be set.";
}
if (rank == -1) {
LOG(FATAL) << "Federated rank must be set.";
}
return new InMemoryCommunicator(world_size, rank);
}
InMemoryCommunicator(int world_size, int rank) {
handler_.Init(world_size, rank);
}
~InMemoryCommunicator() override { handler_.Shutdown(sequence_number_++, GetRank()); }
bool IsDistributed() const override { return true; }
bool IsFederated() const override { return false; }
std::string AllGather(std::string_view input) override {
std::string output;
handler_.Allgather(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}
std::string AllGatherV(std::string_view input) override {
std::string output;
handler_.AllgatherV(input.data(), input.size(), &output, sequence_number_++, GetRank());
return output;
}
void AllReduce(void* in_out, std::size_t size, DataType data_type, Operation operation) override {
auto const bytes = size * GetTypeSize(data_type);
std::string output;
handler_.Allreduce(static_cast<const char*>(in_out), bytes, &output, sequence_number_++,
GetRank(), data_type, operation);
output.copy(static_cast<char*>(in_out), bytes);
}
void Broadcast(void* in_out, std::size_t size, int root) override {
std::string output;
handler_.Broadcast(static_cast<const char*>(in_out), size, &output, sequence_number_++,
GetRank(), root);
output.copy(static_cast<char*>(in_out), size);
}
std::string GetProcessorName() override { return "rank" + std::to_string(GetRank()); }
void Print(const std::string& message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
private:
static InMemoryHandler handler_;
uint64_t sequence_number_{};
};
} // namespace collective
} // namespace xgboost