[Breaking] Switch from rabit to the collective communicator (#8257)

* Switch from rabit to the collective communicator

* fix size_t specialization

* really fix size_t

* try again

* add include

* more include

* fix lint errors

* remove rabit includes

* fix pylint error

* return dict from communicator context

* fix communicator shutdown

* fix dask test

* reset communicator mocklist

* fix distributed tests

* do not save device communicator

* fix jvm gpu tests

* add python test for federated communicator

* Update gputreeshap submodule

Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-10-05 15:39:01 -07:00
committed by GitHub
parent e47b3a3da3
commit 668b8a0ea4
79 changed files with 805 additions and 2212 deletions

View File

@@ -0,0 +1,208 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/*!
* \brief Initialize the collective communicator.
*
* Currently the communicator API is experimental, function signatures may change in the future
* without notice.
*
* Call this once before using anything.
*
* The additional configuration is not required. Usually the communicator will detect settings
* from environment variables.
*
* \param json_config JSON encoded configuration. Accepted JSON keys are:
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
* * rabit: Use Rabit. This is the default if the type is unspecified.
* * mpi: Use MPI.
* * federated: Use the gRPC interface for Federated Learning.
* Only applicable to the Rabit communicator (these are case-sensitive):
* - rabit_tracker_uri: Hostname of the tracker.
* - rabit_tracker_port: Port number of the tracker.
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
* - rabit_world_size: Total number of workers.
* - rabit_hadoop_mode: Enable Hadoop support.
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
* - rabit_reduce_buffer: Size of the reduce buffer.
* - rabit_bootstrap_cache: Size of the bootstrap cache.
* - rabit_debug: Enable debugging.
* - rabit_timeout: Enable timeout.
* - rabit_timeout_sec: Timeout in seconds.
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
* environment variables):
* - DMLC_TRACKER_URI: Hostname of the tracker.
* - DMLC_TRACKER_PORT: Port number of the tracker.
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
* - DMLC_ROLE: Role of the current task, "worker" or "server".
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
* Only applicable to the Federated communicator (use upper case for environment variables, use
* lower case for runtime configuration):
* - federated_server_address: Address of the federated server.
* - federated_world_size: Number of federated workers.
* - federated_rank: Rank of the current worker.
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
* - federated_client_key: Client key file path. Only needed for the SSL mode.
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
*/
inline void Init(Json const& config) {
Communicator::Init(config);
}
/*!
* \brief Finalize the collective communicator.
*
* Call this function after you finished all jobs.
*/
inline void Finalize() { Communicator::Finalize(); }
/*!
* \brief Get rank of current process.
*
* \return Rank of the worker.
*/
inline int GetRank() { return Communicator::Get()->GetRank(); }
/*!
* \brief Get total number of processes.
*
* \return Total world size.
*/
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
/*!
* \brief Get if the communicator is distributed.
*
* \return True if the communicator is distributed.
*/
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
/*!
* \brief Print the message to the communicator.
*
* This function can be used to communicate the information of the progress to the user who monitors
* the communicator.
*
* \param message The message to be printed.
*/
inline void Print(char const *message) { Communicator::Get()->Print(message); }
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
/*!
* \brief Get the name of the processor.
*
* \return Name of the processor.
*/
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
/*!
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
*
* Example:
* int a = 1;
* Broadcast(&a, sizeof(a), root);
*
* \param send_receive_buffer Pointer to the send or receive buffer.
* \param size Size of the data.
* \param root The process rank to broadcast from.
*/
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
}
inline void Broadcast(std::string *sendrecv_data, int root) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root);
if (sendrecv_data->length() != size) {
sendrecv_data->resize(size);
}
if (size != 0) {
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
}
}
/*!
* \brief Perform in-place allreduce. This function is NOT thread-safe.
*
* Example Usage: the following code gives sum of the result
* vector<int> data(10);
* ...
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
* ...
* \param send_receive_buffer Buffer for both sending and receiving data.
* \param count Number of elements to be reduced.
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
*/
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
static_cast<Operation>(op));
}
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
}
template <Operation op>
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}
template <Operation op>
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}
template <Operation op>
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}
template <Operation op>
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}
template <Operation op>
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}
template <Operation op>
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
// Specialization for size_t, which is implementation defined, so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <Operation op, typename T,
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
inline void Allreduce(T *send_receive_buffer, size_t count) {
static_assert(sizeof(T) == sizeof(uint64_t), "");
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}
template <Operation op>
inline void Allreduce(float *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}
template <Operation op>
inline void Allreduce(double *send_receive_buffer, size_t count) {
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
} // namespace collective
} // namespace xgboost

View File

@@ -3,6 +3,7 @@
*/
#include "communicator.h"
#include "noop_communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
@@ -12,14 +13,10 @@
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
thread_local CommunicatorType Communicator::type_{};
void Communicator::Init(Json const& config) {
if (communicator_) {
LOG(FATAL) << "Communicator can only be initialized once.";
}
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
@@ -51,7 +48,7 @@ void Communicator::Init(Json const& config) {
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
communicator_.reset(new NoOpCommunicator());
}
#endif

View File

@@ -4,6 +4,7 @@
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#include "noop_communicator.h"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
@@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
communicator_.reset(new NoOpCommunicator());
device_ordinal_ = -1;
device_communicator_.reset(nullptr);
}

View File

@@ -23,40 +23,6 @@ enum class DataType {
kDouble = 7
};
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };

View File

@@ -21,7 +21,28 @@ class DeviceCommunicator {
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
/**
* @brief Gather variable-length values from all processes.

View File

@@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
~DeviceCommunicatorAdapter() override = default;
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(double);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
}
void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
}
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
}
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}
private:
template <collective::DataType data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(T);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.

View File

@@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
@@ -52,8 +56,15 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
~NcclDeviceCommunicator() override {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
ncclCommDestroy(nccl_comm_);
if (communicator_->GetWorldSize() == 1) {
return;
}
if (cuda_stream_) {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
@@ -61,16 +72,28 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
ncclSum, nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
}
void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
}
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
}
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
@@ -95,6 +118,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
void Synchronize() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
@@ -136,6 +162,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}
template <ncclDataType_t data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
}
int const device_ordinal_;
Communicator *communicator_;
ncclComm_t nccl_comm_{};

View File

@@ -0,0 +1,30 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <string>
#include "communicator.h"
namespace xgboost {
namespace collective {
/**
* A no-op communicator, used for non-distributed training.
*/
class NoOpCommunicator : public Communicator {
public:
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
std::string GetProcessorName() override { return ""; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
protected:
void Shutdown() override {}
};
} // namespace collective
} // namespace xgboost