Common interface for collective communication (#8057)
* implement broadcast for federated communicator * implement allreduce * add communicator factory * add device adapter * add device communicator to factory * add rabit communicator * add rabit communicator to the factory * add nccl device communicator * add synchronize to device communicator * add back print and getprocessorname * add python wrapper and c api * clean up types * fix non-gpu build * try to fix ci * fix std::size_t * portable string compare ignore case * c style size_t * fix lint errors * cross platform setenv * fix memory leak * fix lint errors * address review feedback * add python test for rabit communicator * fix failing gtest * use json to configure communicators * fix lint error * get rid of factories * fix cpu build * fix include * fix python import * don't export collective.py yet * skip collective communicator pytest on windows * add review feedback * update documentation * remove mpi communicator type * fix tests * shutdown the communicator separately Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
59
src/collective/communicator.cc
Normal file
59
src/collective/communicator.cc
Normal file
@@ -0,0 +1,59 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
#include "../../plugin/federated/federated_communicator.h"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
|
||||
void Communicator::Init(Json const& config) {
|
||||
if (communicator_) {
|
||||
LOG(FATAL) << "Communicator can only be initialized once.";
|
||||
}
|
||||
|
||||
auto type = GetTypeFromEnv();
|
||||
auto const arg = GetTypeFromConfig(config);
|
||||
if (arg != CommunicatorType::kUnknown) {
|
||||
type = arg;
|
||||
}
|
||||
if (type == CommunicatorType::kUnknown) {
|
||||
// Default to Rabit if unspecified.
|
||||
type = CommunicatorType::kRabit;
|
||||
}
|
||||
type_ = type;
|
||||
switch (type) {
|
||||
case CommunicatorType::kRabit: {
|
||||
communicator_.reset(RabitCommunicator::Create(config));
|
||||
break;
|
||||
}
|
||||
case CommunicatorType::kFederated: {
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
communicator_.reset(FederatedCommunicator::Create(config));
|
||||
#else
|
||||
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case CommunicatorType::kUnknown:
|
||||
LOG(FATAL) << "Unknown communicator type.";
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
41
src/collective/communicator.cu
Normal file
41
src/collective/communicator.cu
Normal file
@@ -0,0 +1,41 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local int Communicator::device_ordinal_{-1};
|
||||
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
device_ordinal_ = -1;
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
if (!device_communicator_ || device_ordinal_ != device_ordinal) {
|
||||
device_ordinal_ = device_ordinal;
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (type_ != CommunicatorType::kFederated) {
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
||||
} else {
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
#endif
|
||||
}
|
||||
return device_communicator_.get();
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
218
src/collective/communicator.h
Normal file
218
src/collective/communicator.h
Normal file
@@ -0,0 +1,218 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/json.h>
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/** @brief Defines the integral and floating data types. */
|
||||
enum class DataType {
|
||||
kInt8 = 0,
|
||||
kUInt8 = 1,
|
||||
kInt32 = 2,
|
||||
kUInt32 = 3,
|
||||
kInt64 = 4,
|
||||
kUInt64 = 5,
|
||||
kFloat = 6,
|
||||
kDouble = 7
|
||||
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/** @brief Defines the reduction operation. */
|
||||
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
enum class CommunicatorType { kUnknown, kRabit, kFederated };
|
||||
|
||||
/** \brief Case-insensitive string comparison. */
|
||||
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
|
||||
#ifdef _MSC_VER
|
||||
return _stricmp(s1, s2);
|
||||
#else // _MSC_VER
|
||||
return strcasecmp(s1, s2);
|
||||
#endif // _MSC_VER
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief A communicator class that handles collective communication.
|
||||
*/
|
||||
class Communicator {
|
||||
public:
|
||||
/**
|
||||
* @brief Initialize the communicator. This can only be done once.
|
||||
*
|
||||
* @param config JSON configuration for the communicator.
|
||||
*/
|
||||
static void Init(Json const &config);
|
||||
|
||||
/** @brief Finalize the communicator. */
|
||||
static void Finalize();
|
||||
|
||||
/** @brief Get the communicator instance. */
|
||||
static Communicator *Get() { return communicator_.get(); }
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
/**
|
||||
* @brief Get the device communicator.
|
||||
*
|
||||
* @param device_ordinal ID of the device.
|
||||
* @return An instance of device communicator.
|
||||
*/
|
||||
static DeviceCommunicator *GetDevice(int device_ordinal);
|
||||
#endif
|
||||
|
||||
virtual ~Communicator() = default;
|
||||
|
||||
/** @brief Get the total number of processes. */
|
||||
int GetWorldSize() const { return world_size_; }
|
||||
|
||||
/** @brief Get the rank of the current processes. */
|
||||
int GetRank() const { return rank_; }
|
||||
|
||||
/** @brief Whether the communicator is running in distributed mode. */
|
||||
virtual bool IsDistributed() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Combines values from all processes and distributes the result back to all processes.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
* @param data_type Data type stored in the buffer.
|
||||
* @param op The operation to perform.
|
||||
*/
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
|
||||
* group.
|
||||
*
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param size Size of the data in bytes.
|
||||
* @param root Rank of broadcast root.
|
||||
*/
|
||||
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gets the name of the processor.
|
||||
*/
|
||||
virtual std::string GetProcessorName() = 0;
|
||||
|
||||
/**
|
||||
* @brief Prints the message.
|
||||
*/
|
||||
virtual void Print(std::string const &message) = 0;
|
||||
|
||||
/** @brief Get the communicator type from environment variables. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromEnv() {
|
||||
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
|
||||
if (env != nullptr) {
|
||||
return StringToType(env);
|
||||
} else {
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
|
||||
static CommunicatorType GetTypeFromConfig(Json const &config) {
|
||||
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
|
||||
if (IsA<String const>(j_upper)) {
|
||||
return StringToType(get<String const>(j_upper).c_str());
|
||||
}
|
||||
auto const &j_lower = config["xgboost_communicator"];
|
||||
if (IsA<String const>(j_lower)) {
|
||||
return StringToType(get<String const>(j_lower).c_str());
|
||||
}
|
||||
return CommunicatorType::kUnknown;
|
||||
}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Construct a new communicator.
|
||||
*
|
||||
* @param world_size Total number of processes.
|
||||
* @param rank Rank of the current process.
|
||||
*/
|
||||
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
|
||||
if (world_size < 1) {
|
||||
LOG(FATAL) << "World size " << world_size << " is less than 1.";
|
||||
}
|
||||
if (rank < 0) {
|
||||
LOG(FATAL) << "Rank " << rank << " is less than 0.";
|
||||
}
|
||||
if (rank >= world_size) {
|
||||
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuts down the communicator.
|
||||
*/
|
||||
virtual void Shutdown() = 0;
|
||||
|
||||
private:
|
||||
static CommunicatorType StringToType(char const *str) {
|
||||
CommunicatorType result = CommunicatorType::kUnknown;
|
||||
if (!CompareStringsCaseInsensitive("rabit", str)) {
|
||||
result = CommunicatorType::kRabit;
|
||||
} else if (!CompareStringsCaseInsensitive("federated", str)) {
|
||||
result = CommunicatorType::kFederated;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator type " << str;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static thread_local std::unique_ptr<Communicator> communicator_;
|
||||
static thread_local CommunicatorType type_;
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
static thread_local int device_ordinal_;
|
||||
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
|
||||
#endif
|
||||
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
42
src/collective/device_communicator.cuh
Normal file
42
src/collective/device_communicator.cuh
Normal file
@@ -0,0 +1,42 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <vector>
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Collective communicator for device buffers.
|
||||
*/
|
||||
class DeviceCommunicator {
|
||||
public:
|
||||
virtual ~DeviceCommunicator() = default;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
* @param length_bytes Length in bytes of the input data.
|
||||
* @param segments Size of each segment.
|
||||
* @param receive_buffer Buffer storing the output data.
|
||||
*/
|
||||
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) = 0;
|
||||
|
||||
/** @brief Synchronize device operations. */
|
||||
virtual void Synchronize() = 0;
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
76
src/collective/device_communicator_adapter.cuh
Normal file
76
src/collective/device_communicator_adapter.cuh
Normal file
@@ -0,0 +1,76 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(double);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||
Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.reserve(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
// Noop.
|
||||
}
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
149
src/collective/nccl_device_communicator.cuh
Normal file
149
src/collective/nccl_device_communicator.cuh
Normal file
@@ -0,0 +1,149 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
|
||||
int32_t const rank = communicator_->GetRank();
|
||||
int32_t const world = communicator_->GetWorldSize();
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||
}
|
||||
|
||||
~NcclDeviceCommunicator() override {
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
ncclCommDestroy(nccl_comm_);
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
|
||||
ncclSum, nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||
Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, cuda_stream_));
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
|
||||
std::stringstream ss;
|
||||
for (auto v : uuid) {
|
||||
ss << std::hex << v;
|
||||
}
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn ncclUniqueId GetUniqueId()
|
||||
*
|
||||
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
|
||||
* communication
|
||||
*
|
||||
* \return the Unique ID
|
||||
*/
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (communicator_->GetRank() == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
|
||||
static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
cudaStream_t cuda_stream_{};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
120
src/collective/rabit_communicator.h
Normal file
120
src/collective/rabit_communicator.h
Normal file
@@ -0,0 +1,120 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
class RabitCommunicator : public Communicator {
|
||||
public:
|
||||
static Communicator *Create(Json const &config) {
|
||||
std::vector<std::string> args_str;
|
||||
for (auto &items : get<Object const>(config)) {
|
||||
switch (items.second.GetValue().Type()) {
|
||||
case xgboost::Value::ValueKind::kString: {
|
||||
args_str.push_back(items.first + "=" + get<String const>(items.second));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kInteger: {
|
||||
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
|
||||
break;
|
||||
}
|
||||
case xgboost::Value::ValueKind::kBoolean: {
|
||||
if (get<Boolean const>(items.second)) {
|
||||
args_str.push_back(items.first + "=1");
|
||||
} else {
|
||||
args_str.push_back(items.first + "=0");
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
std::vector<char *> args;
|
||||
for (auto &key_value : args_str) {
|
||||
args.push_back(&key_value[0]);
|
||||
}
|
||||
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
|
||||
LOG(FATAL) << "Failed to initialize Rabit";
|
||||
}
|
||||
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
|
||||
}
|
||||
|
||||
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
|
||||
|
||||
bool IsDistributed() const override { return rabit::IsDistributed(); }
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
DoAllReduce<char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
DoAllReduce<float>(send_receive_buffer, count, op);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
DoAllReduce<double>(send_receive_buffer, count, op);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type";
|
||||
}
|
||||
}
|
||||
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
|
||||
rabit::Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
|
||||
|
||||
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {
|
||||
rabit::Finalize();
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename DType>
|
||||
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kMin:
|
||||
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
case Operation::kSum:
|
||||
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown allreduce operation";
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user