Common interface for collective communication (#8057)

* implement broadcast for federated communicator

* implement allreduce

* add communicator factory

* add device adapter

* add device communicator to factory

* add rabit communicator

* add rabit communicator to the factory

* add nccl device communicator

* add synchronize to device communicator

* add back print and getprocessorname

* add python wrapper and c api

* clean up types

* fix non-gpu build

* try to fix ci

* fix std::size_t

* portable string compare ignore case

* c style size_t

* fix lint errors

* cross platform setenv

* fix memory leak

* fix lint errors

* address review feedback

* add python test for rabit communicator

* fix failing gtest

* use json to configure communicators

* fix lint error

* get rid of factories

* fix cpu build

* fix include

* fix python import

* don't export collective.py yet

* skip collective communicator pytest on windows

* add review feedback

* update documentation

* remove mpi communicator type

* fix tests

* shutdown the communicator separately

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Rong Ou
2022-09-12 15:21:12 -07:00
committed by GitHub
parent bc818316f2
commit a2686543a9
25 changed files with 1771 additions and 95 deletions

View File

@@ -22,6 +22,7 @@
#include "c_api_error.h"
#include "c_api_utils.h"
#include "../collective/communicator.h"
#include "../common/io.h"
#include "../common/charconv.h"
#include "../data/adapter.h"
@@ -1370,6 +1371,62 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, char const *json_config,
API_END();
}
using xgboost::collective::Communicator;
XGB_DLL int XGCommunicatorInit(char const* json_config) {
API_BEGIN();
Json config { Json::Load(StringView{json_config}) };
Communicator::Init(config);
API_END();
}
XGB_DLL int XGCommunicatorFinalize(void) {
API_BEGIN();
Communicator::Finalize();
API_END();
}
XGB_DLL int XGCommunicatorGetRank(void) {
return Communicator::Get()->GetRank();
}
XGB_DLL int XGCommunicatorGetWorldSize(void) {
return Communicator::Get()->GetWorldSize();
}
XGB_DLL int XGCommunicatorIsDistributed(void) {
return Communicator::Get()->IsDistributed();
}
XGB_DLL int XGCommunicatorPrint(char const *message) {
API_BEGIN();
Communicator::Get()->Print(message);
API_END();
}
XGB_DLL int XGCommunicatorGetProcessorName(char const **name_str) {
API_BEGIN();
auto& local = *GlobalConfigAPIThreadLocalStore::Get();
local.ret_str = Communicator::Get()->GetProcessorName();
*name_str = local.ret_str.c_str();
API_END();
}
XGB_DLL int XGCommunicatorBroadcast(void *send_receive_buffer, size_t size, int root) {
API_BEGIN();
Communicator::Get()->Broadcast(send_receive_buffer, size, root);
API_END();
}
XGB_DLL int XGCommunicatorAllreduce(void *send_receive_buffer, size_t count, int enum_dtype,
int enum_op) {
API_BEGIN();
Communicator::Get()->AllReduce(
send_receive_buffer, count, static_cast<xgboost::collective::DataType>(enum_dtype),
static_cast<xgboost::collective::Operation>(enum_op));
API_END();
}
#if defined(XGBOOST_USE_FEDERATED)
XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_key_path,
char const *server_cert_path, char const *client_cert_path) {

View File

@@ -0,0 +1,59 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "rabit_communicator.h"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_communicator.h"
#endif
namespace xgboost {
namespace collective {
thread_local std::unique_ptr<Communicator> Communicator::communicator_{};
thread_local CommunicatorType Communicator::type_{};
void Communicator::Init(Json const& config) {
if (communicator_) {
LOG(FATAL) << "Communicator can only be initialized once.";
}
auto type = GetTypeFromEnv();
auto const arg = GetTypeFromConfig(config);
if (arg != CommunicatorType::kUnknown) {
type = arg;
}
if (type == CommunicatorType::kUnknown) {
// Default to Rabit if unspecified.
type = CommunicatorType::kRabit;
}
type_ = type;
switch (type) {
case CommunicatorType::kRabit: {
communicator_.reset(RabitCommunicator::Create(config));
break;
}
case CommunicatorType::kFederated: {
#if defined(XGBOOST_USE_FEDERATED)
communicator_.reset(FederatedCommunicator::Create(config));
#else
LOG(FATAL) << "XGBoost is not compiled with Federated Learning support.";
#endif
break;
}
case CommunicatorType::kUnknown:
LOG(FATAL) << "Unknown communicator type.";
}
}
#ifndef XGBOOST_USE_CUDA
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
}
#endif
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,41 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "communicator.h"
#include "device_communicator.cuh"
#include "device_communicator_adapter.cuh"
#ifdef XGBOOST_USE_NCCL
#include "nccl_device_communicator.cuh"
#endif
namespace xgboost {
namespace collective {
thread_local int Communicator::device_ordinal_{-1};
thread_local std::unique_ptr<DeviceCommunicator> Communicator::device_communicator_{};
void Communicator::Finalize() {
communicator_->Shutdown();
communicator_.reset(nullptr);
device_ordinal_ = -1;
device_communicator_.reset(nullptr);
}
DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
if (!device_communicator_ || device_ordinal_ != device_ordinal) {
device_ordinal_ = device_ordinal;
#ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
#endif
}
return device_communicator_.get();
}
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,218 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <xgboost/json.h>
#include <xgboost/logging.h>
#include <memory>
#include <string>
namespace xgboost {
namespace collective {
/** @brief Defines the integral and floating data types. */
enum class DataType {
kInt8 = 0,
kUInt8 = 1,
kInt32 = 2,
kUInt32 = 3,
kInt64 = 4,
kUInt64 = 5,
kFloat = 6,
kDouble = 7
};
/** @brief Get the size of the data type. */
inline std::size_t GetTypeSize(DataType data_type) {
std::size_t size{0};
switch (data_type) {
case DataType::kInt8:
size = sizeof(std::int8_t);
break;
case DataType::kUInt8:
size = sizeof(std::uint8_t);
break;
case DataType::kInt32:
size = sizeof(std::int32_t);
break;
case DataType::kUInt32:
size = sizeof(std::uint32_t);
break;
case DataType::kInt64:
size = sizeof(std::int64_t);
break;
case DataType::kUInt64:
size = sizeof(std::uint64_t);
break;
case DataType::kFloat:
size = sizeof(float);
break;
case DataType::kDouble:
size = sizeof(double);
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return size;
}
/** @brief Defines the reduction operation. */
enum class Operation { kMax = 0, kMin = 1, kSum = 2 };
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated };
/** \brief Case-insensitive string comparison. */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
#ifdef _MSC_VER
return _stricmp(s1, s2);
#else // _MSC_VER
return strcasecmp(s1, s2);
#endif // _MSC_VER
}
/**
* @brief A communicator class that handles collective communication.
*/
class Communicator {
public:
/**
* @brief Initialize the communicator. This can only be done once.
*
* @param config JSON configuration for the communicator.
*/
static void Init(Json const &config);
/** @brief Finalize the communicator. */
static void Finalize();
/** @brief Get the communicator instance. */
static Communicator *Get() { return communicator_.get(); }
#if defined(XGBOOST_USE_CUDA)
/**
* @brief Get the device communicator.
*
* @param device_ordinal ID of the device.
* @return An instance of device communicator.
*/
static DeviceCommunicator *GetDevice(int device_ordinal);
#endif
virtual ~Communicator() = default;
/** @brief Get the total number of processes. */
int GetWorldSize() const { return world_size_; }
/** @brief Get the rank of the current processes. */
int GetRank() const { return rank_; }
/** @brief Whether the communicator is running in distributed mode. */
virtual bool IsDistributed() const = 0;
/**
* @brief Combines values from all processes and distributes the result back to all processes.
*
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
* @param data_type Data type stored in the buffer.
* @param op The operation to perform.
*/
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;
/**
* @brief Broadcasts a message from the process with rank `root` to all other processes of the
* group.
*
* @param send_receive_buffer Buffer storing the data.
* @param size Size of the data in bytes.
* @param root Rank of broadcast root.
*/
virtual void Broadcast(void *send_receive_buffer, std::size_t size, int root) = 0;
/**
* @brief Gets the name of the processor.
*/
virtual std::string GetProcessorName() = 0;
/**
* @brief Prints the message.
*/
virtual void Print(std::string const &message) = 0;
/** @brief Get the communicator type from environment variables. Visible for testing. */
static CommunicatorType GetTypeFromEnv() {
auto *env = std::getenv("XGBOOST_COMMUNICATOR");
if (env != nullptr) {
return StringToType(env);
} else {
return CommunicatorType::kUnknown;
}
}
/** @brief Get the communicator type from runtime configuration. Visible for testing. */
static CommunicatorType GetTypeFromConfig(Json const &config) {
auto const &j_upper = config["XGBOOST_COMMUNICATOR"];
if (IsA<String const>(j_upper)) {
return StringToType(get<String const>(j_upper).c_str());
}
auto const &j_lower = config["xgboost_communicator"];
if (IsA<String const>(j_lower)) {
return StringToType(get<String const>(j_lower).c_str());
}
return CommunicatorType::kUnknown;
}
protected:
/**
* @brief Construct a new communicator.
*
* @param world_size Total number of processes.
* @param rank Rank of the current process.
*/
Communicator(int world_size, int rank) : world_size_(world_size), rank_(rank) {
if (world_size < 1) {
LOG(FATAL) << "World size " << world_size << " is less than 1.";
}
if (rank < 0) {
LOG(FATAL) << "Rank " << rank << " is less than 0.";
}
if (rank >= world_size) {
LOG(FATAL) << "Rank " << rank << " is greater than world_size - 1: " << world_size - 1 << ".";
}
}
/**
* @brief Shuts down the communicator.
*/
virtual void Shutdown() = 0;
private:
static CommunicatorType StringToType(char const *str) {
CommunicatorType result = CommunicatorType::kUnknown;
if (!CompareStringsCaseInsensitive("rabit", str)) {
result = CommunicatorType::kRabit;
} else if (!CompareStringsCaseInsensitive("federated", str)) {
result = CommunicatorType::kFederated;
} else {
LOG(FATAL) << "Unknown communicator type " << str;
}
return result;
}
static thread_local std::unique_ptr<Communicator> communicator_;
static thread_local CommunicatorType type_;
#if defined(XGBOOST_USE_CUDA)
static thread_local int device_ordinal_;
static thread_local std::unique_ptr<DeviceCommunicator> device_communicator_;
#endif
int const world_size_;
int const rank_;
};
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,42 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <vector>
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Collective communicator for device buffers.
*/
class DeviceCommunicator {
public:
virtual ~DeviceCommunicator() = default;
/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, int count) = 0;
/**
* @brief Gather variable-length values from all processes.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
virtual void AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) = 0;
/** @brief Synchronize device operations. */
virtual void Synchronize() = 0;
};
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,76 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
}
~DeviceCommunicatorAdapter() override = default;
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(double);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, DataType::kDouble, Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
host_buffer_.reserve(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
cudaMemcpyDefault));
}
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
cudaMemcpyDefault));
}
void Synchronize() override {
// Noop.
}
private:
int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,149 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include "../common/device_helpers.cuh"
#include "communicator.h"
#include "device_communicator.cuh"
namespace xgboost {
namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
~NcclDeviceCommunicator() override {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
ncclCommDestroy(nccl_comm_);
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
void AllReduceSum(double *send_receive_buffer, int count) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, ncclDouble,
ncclSum, nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(double);
allreduce_calls_ += 1;
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
}
void Synchronize() override {
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
private:
static constexpr std::size_t kUuidLength =
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
}
static std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) {
std::stringstream ss;
for (auto v : uuid) {
ss << std::hex << v;
}
return ss.str();
}
/**
* \fn ncclUniqueId GetUniqueId()
*
* \brief Gets the Unique ID from NCCL to be used in setting up interprocess
* communication
*
* \return the Unique ID
*/
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (communicator_->GetRank() == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
static_cast<int>(kRootRank));
return id;
}
int const device_ordinal_;
Communicator *communicator_;
ncclComm_t nccl_comm_{};
cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
};
} // namespace collective
} // namespace xgboost

View File

@@ -0,0 +1,120 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#pragma once
#include <rabit/rabit.h>
#include <string>
#include <vector>
#include "communicator.h"
#include "xgboost/json.h"
namespace xgboost {
namespace collective {
class RabitCommunicator : public Communicator {
public:
static Communicator *Create(Json const &config) {
std::vector<std::string> args_str;
for (auto &items : get<Object const>(config)) {
switch (items.second.GetValue().Type()) {
case xgboost::Value::ValueKind::kString: {
args_str.push_back(items.first + "=" + get<String const>(items.second));
break;
}
case xgboost::Value::ValueKind::kInteger: {
args_str.push_back(items.first + "=" + std::to_string(get<Integer const>(items.second)));
break;
}
case xgboost::Value::ValueKind::kBoolean: {
if (get<Boolean const>(items.second)) {
args_str.push_back(items.first + "=1");
} else {
args_str.push_back(items.first + "=0");
}
break;
}
default:
break;
}
}
std::vector<char *> args;
for (auto &key_value : args_str) {
args.push_back(&key_value[0]);
}
if (!rabit::Init(static_cast<int>(args.size()), &args[0])) {
LOG(FATAL) << "Failed to initialize Rabit";
}
return new RabitCommunicator(rabit::GetWorldSize(), rabit::GetRank());
}
RabitCommunicator(int world_size, int rank) : Communicator(world_size, rank) {}
bool IsDistributed() const override { return rabit::IsDistributed(); }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
switch (data_type) {
case DataType::kInt8:
DoAllReduce<char>(send_receive_buffer, count, op);
break;
case DataType::kUInt8:
DoAllReduce<unsigned char>(send_receive_buffer, count, op);
break;
case DataType::kInt32:
DoAllReduce<std::int32_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt32:
DoAllReduce<std::uint32_t>(send_receive_buffer, count, op);
break;
case DataType::kInt64:
DoAllReduce<std::int64_t>(send_receive_buffer, count, op);
break;
case DataType::kUInt64:
DoAllReduce<std::uint64_t>(send_receive_buffer, count, op);
break;
case DataType::kFloat:
DoAllReduce<float>(send_receive_buffer, count, op);
break;
case DataType::kDouble:
DoAllReduce<double>(send_receive_buffer, count, op);
break;
default:
LOG(FATAL) << "Unknown data type";
}
}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {
rabit::Broadcast(send_receive_buffer, size, root);
}
std::string GetProcessorName() override { return rabit::GetProcessorName(); }
void Print(const std::string &message) override { rabit::TrackerPrint(message); }
protected:
void Shutdown() override {
rabit::Finalize();
}
private:
template <typename DType>
void DoAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
switch (op) {
case Operation::kMax:
rabit::Allreduce<rabit::op::Max, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kMin:
rabit::Allreduce<rabit::op::Min, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
case Operation::kSum:
rabit::Allreduce<rabit::op::Sum, DType>(static_cast<DType *>(send_receive_buffer), count);
break;
default:
LOG(FATAL) << "Unknown allreduce operation";
}
}
};
} // namespace collective
} // namespace xgboost

View File

@@ -12,6 +12,7 @@
#include <rabit/rabit.h>
#include <string>
#include <cstring>
#include <fstream>
#include "common.h"
@@ -111,6 +112,22 @@ inline std::string ReadAll(dmlc::Stream* fi, PeekableInStream* fp) {
}
return buffer;
}
/**
* \brief Read the whole file content into a string.
*/
inline std::string ReadAll(std::string const &path) {
std::ifstream stream(path);
if (!stream.is_open()) {
LOG(FATAL) << "Could not open file " << path;
}
std::string content{std::istreambuf_iterator<char>(stream), std::istreambuf_iterator<char>()};
if (content.empty()) {
LOG(FATAL) << "Empty file " << path;
}
return content;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_