[Breaking] Switch from rabit to the collective communicator (#8257)
* Switch from rabit to the collective communicator * fix size_t specialization * really fix size_t * try again * add include * more include * fix lint errors * remove rabit includes * fix pylint error * return dict from communicator context * fix communicator shutdown * fix dask test * reset communicator mocklist * fix distributed tests * do not save device communicator * fix jvm gpu tests * add python test for federated communicator * Update gputreeshap submodule Co-authored-by: Hyunsu Philip Cho <chohyu01@cs.washington.edu>
This commit is contained in:
208
src/collective/communicator-inl.h
Normal file
208
src/collective/communicator-inl.h
Normal file
@@ -0,0 +1,208 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/*!
|
||||
* \brief Initialize the collective communicator.
|
||||
*
|
||||
* Currently the communicator API is experimental, function signatures may change in the future
|
||||
* without notice.
|
||||
*
|
||||
* Call this once before using anything.
|
||||
*
|
||||
* The additional configuration is not required. Usually the communicator will detect settings
|
||||
* from environment variables.
|
||||
*
|
||||
* \param json_config JSON encoded configuration. Accepted JSON keys are:
|
||||
* - xgboost_communicator: The type of the communicator. Can be set as an environment variable.
|
||||
* * rabit: Use Rabit. This is the default if the type is unspecified.
|
||||
* * mpi: Use MPI.
|
||||
* * federated: Use the gRPC interface for Federated Learning.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive):
|
||||
* - rabit_tracker_uri: Hostname of the tracker.
|
||||
* - rabit_tracker_port: Port number of the tracker.
|
||||
* - rabit_task_id: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - rabit_world_size: Total number of workers.
|
||||
* - rabit_hadoop_mode: Enable Hadoop support.
|
||||
* - rabit_tree_reduce_minsize: Minimal size for tree reduce.
|
||||
* - rabit_reduce_ring_mincount: Minimal count to perform ring reduce.
|
||||
* - rabit_reduce_buffer: Size of the reduce buffer.
|
||||
* - rabit_bootstrap_cache: Size of the bootstrap cache.
|
||||
* - rabit_debug: Enable debugging.
|
||||
* - rabit_timeout: Enable timeout.
|
||||
* - rabit_timeout_sec: Timeout in seconds.
|
||||
* - rabit_enable_tcp_no_delay: Enable TCP no delay on Unix platforms.
|
||||
* Only applicable to the Rabit communicator (these are case-sensitive, and can be set as
|
||||
* environment variables):
|
||||
* - DMLC_TRACKER_URI: Hostname of the tracker.
|
||||
* - DMLC_TRACKER_PORT: Port number of the tracker.
|
||||
* - DMLC_TASK_ID: ID of the current task, can be used to obtain deterministic rank assignment.
|
||||
* - DMLC_ROLE: Role of the current task, "worker" or "server".
|
||||
* - DMLC_NUM_ATTEMPT: Number of attempts after task failure.
|
||||
* - DMLC_WORKER_CONNECT_RETRY: Number of retries to connect to the tracker.
|
||||
* Only applicable to the Federated communicator (use upper case for environment variables, use
|
||||
* lower case for runtime configuration):
|
||||
* - federated_server_address: Address of the federated server.
|
||||
* - federated_world_size: Number of federated workers.
|
||||
* - federated_rank: Rank of the current worker.
|
||||
* - federated_server_cert: Server certificate file path. Only needed for the SSL mode.
|
||||
* - federated_client_key: Client key file path. Only needed for the SSL mode.
|
||||
* - federated_client_cert: Client certificate file path. Only needed for the SSL mode.
|
||||
*/
|
||||
inline void Init(Json const& config) {
|
||||
Communicator::Init(config);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Finalize the collective communicator.
|
||||
*
|
||||
* Call this function after you finished all jobs.
|
||||
*/
|
||||
inline void Finalize() { Communicator::Finalize(); }
|
||||
|
||||
/*!
|
||||
* \brief Get rank of current process.
|
||||
*
|
||||
* \return Rank of the worker.
|
||||
*/
|
||||
inline int GetRank() { return Communicator::Get()->GetRank(); }
|
||||
|
||||
/*!
|
||||
* \brief Get total number of processes.
|
||||
*
|
||||
* \return Total world size.
|
||||
*/
|
||||
inline int GetWorldSize() { return Communicator::Get()->GetWorldSize(); }
|
||||
|
||||
/*!
|
||||
* \brief Get if the communicator is distributed.
|
||||
*
|
||||
* \return True if the communicator is distributed.
|
||||
*/
|
||||
inline bool IsDistributed() { return Communicator::Get()->IsDistributed(); }
|
||||
|
||||
/*!
|
||||
* \brief Print the message to the communicator.
|
||||
*
|
||||
* This function can be used to communicate the information of the progress to the user who monitors
|
||||
* the communicator.
|
||||
*
|
||||
* \param message The message to be printed.
|
||||
*/
|
||||
inline void Print(char const *message) { Communicator::Get()->Print(message); }
|
||||
|
||||
inline void Print(std::string const &message) { Communicator::Get()->Print(message); }
|
||||
|
||||
/*!
|
||||
* \brief Get the name of the processor.
|
||||
*
|
||||
* \return Name of the processor.
|
||||
*/
|
||||
inline std::string GetProcessorName() { return Communicator::Get()->GetProcessorName(); }
|
||||
|
||||
/*!
|
||||
* \brief Broadcast a memory region to all others from root. This function is NOT thread-safe.
|
||||
*
|
||||
* Example:
|
||||
* int a = 1;
|
||||
* Broadcast(&a, sizeof(a), root);
|
||||
*
|
||||
* \param send_receive_buffer Pointer to the send or receive buffer.
|
||||
* \param size Size of the data.
|
||||
* \param root The process rank to broadcast from.
|
||||
*/
|
||||
inline void Broadcast(void *send_receive_buffer, size_t size, int root) {
|
||||
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
|
||||
}
|
||||
|
||||
inline void Broadcast(std::string *sendrecv_data, int root) {
|
||||
size_t size = sendrecv_data->length();
|
||||
Broadcast(&size, sizeof(size), root);
|
||||
if (sendrecv_data->length() != size) {
|
||||
sendrecv_data->resize(size);
|
||||
}
|
||||
if (size != 0) {
|
||||
Broadcast(&(*sendrecv_data)[0], size * sizeof(char), root);
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Perform in-place allreduce. This function is NOT thread-safe.
|
||||
*
|
||||
* Example Usage: the following code gives sum of the result
|
||||
* vector<int> data(10);
|
||||
* ...
|
||||
* Allreduce(&data[0], data.size(), DataType:kInt32, Op::kSum);
|
||||
* ...
|
||||
* \param send_receive_buffer Buffer for both sending and receiving data.
|
||||
* \param count Number of elements to be reduced.
|
||||
* \param data_type Enumeration of data type, see xgboost::collective::DataType in communicator.h.
|
||||
* \param op Enumeration of operation type, see xgboost::collective::Operation in communicator.h.
|
||||
*/
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, int data_type, int op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, static_cast<DataType>(data_type),
|
||||
static_cast<Operation>(op));
|
||||
}
|
||||
|
||||
inline void Allreduce(void *send_receive_buffer, size_t count, DataType data_type, Operation op) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, data_type, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint8_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint32_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(int64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(uint64_t *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
// Specialization for size_t, which is implementation defined, so it might or might not
|
||||
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
|
||||
template <Operation op, typename T,
|
||||
typename = std::enable_if_t<std::is_same<size_t, T>{} && !std::is_same<uint64_t, T>{}> >
|
||||
inline void Allreduce(T *send_receive_buffer, size_t count) {
|
||||
static_assert(sizeof(T) == sizeof(uint64_t), "");
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(float *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
|
||||
}
|
||||
|
||||
template <Operation op>
|
||||
inline void Allreduce(double *send_receive_buffer, size_t count) {
|
||||
Communicator::Get()->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -3,6 +3,7 @@
|
||||
*/
|
||||
#include "communicator.h"
|
||||
|
||||
#include "noop_communicator.h"
|
||||
#include "rabit_communicator.h"
|
||||
|
||||
#if defined(XGBOOST_USE_FEDERATED)
|
||||
@@ -12,14 +13,10 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
|
||||
thread_local std::unique_ptr<Communicator> Communicator::communicator_{new NoOpCommunicator()};
|
||||
thread_local CommunicatorType Communicator::type_{};
|
||||
|
||||
void Communicator::Init(Json const& config) {
|
||||
if (communicator_) {
|
||||
LOG(FATAL) << "Communicator can only be initialized once.";
|
||||
}
|
||||
|
||||
auto type = GetTypeFromEnv();
|
||||
auto const arg = GetTypeFromConfig(config);
|
||||
if (arg != CommunicatorType::kUnknown) {
|
||||
@@ -51,7 +48,7 @@ void Communicator::Init(Json const& config) {
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "communicator.h"
|
||||
#include "device_communicator.cuh"
|
||||
#include "device_communicator_adapter.cuh"
|
||||
#include "noop_communicator.h"
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
#include "nccl_device_communicator.cuh"
|
||||
#endif
|
||||
@@ -16,7 +17,7 @@ thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicat
|
||||
|
||||
void Communicator::Finalize() {
|
||||
communicator_->Shutdown();
|
||||
communicator_.reset(nullptr);
|
||||
communicator_.reset(new NoOpCommunicator());
|
||||
device_ordinal_ = -1;
|
||||
device_communicator_.reset(nullptr);
|
||||
}
|
||||
|
||||
@@ -23,40 +23,6 @@ enum class DataType {
|
||||
kDouble = 7
|
||||
};
|
||||
|
||||
/** @brief Get the size of the data type. */
|
||||
inline std::size_t GetTypeSize(DataType data_type) {
|
||||
std::size_t size{0};
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
size = sizeof(std::int8_t);
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
size = sizeof(std::uint8_t);
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
size = sizeof(std::int32_t);
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
size = sizeof(std::uint32_t);
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
size = sizeof(std::int64_t);
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
size = sizeof(std::uint64_t);
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
size = sizeof(float);
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
size = sizeof(double);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
/** @brief Defines the reduction operation. */
|
||||
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
|
||||
|
||||
|
||||
@@ -21,7 +21,28 @@ class DeviceCommunicator {
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
|
||||
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Sum values from all processes and distribute the result back to all processes.
|
||||
* @param send_receive_buffer Buffer storing the data.
|
||||
* @param count Number of elements in the buffer.
|
||||
*/
|
||||
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
|
||||
@@ -23,17 +23,28 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(double);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
@@ -66,6 +77,20 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
private:
|
||||
template <collective::DataType data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
auto size = count * sizeof(T);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
|
||||
@@ -24,6 +24,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int32_t const rank = communicator_->GetRank();
|
||||
int32_t const world = communicator_->GetWorldSize();
|
||||
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
@@ -52,8 +56,15 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
~NcclDeviceCommunicator() override {
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
ncclCommDestroy(nccl_comm_);
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
@@ -61,16 +72,28 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, int count) override {
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
|
||||
ncclSum, nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(double);
|
||||
allreduce_calls_ += 1;
|
||||
void AllReduceSum(float *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(double *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
|
||||
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
@@ -95,6 +118,9 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
}
|
||||
@@ -136,6 +162,19 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return id;
|
||||
}
|
||||
|
||||
template <ncclDataType_t data_type, typename T>
|
||||
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||
nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
|
||||
30
src/collective/noop_communicator.h
Normal file
30
src/collective/noop_communicator.h
Normal file
@@ -0,0 +1,30 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
#include <string>
|
||||
|
||||
#include "communicator.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* A no-op communicator, used for non-distributed training.
|
||||
*/
|
||||
class NoOpCommunicator : public Communicator {
|
||||
public:
|
||||
NoOpCommunicator() : Communicator(1, 0) {}
|
||||
bool IsDistributed() const override { return false; }
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {}
|
||||
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
|
||||
std::string GetProcessorName() override { return ""; }
|
||||
void Print(const std::string &message) override { LOG(CONSOLE) << message; }
|
||||
|
||||
protected:
|
||||
void Shutdown() override {}
|
||||
};
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user