temp merge, disable 1 line, SetValid

Your Name
2023-10-12 16:16:44 -07:00
492 changed files with 15533 additions and 9376 deletions

View File

@@ -0,0 +1,40 @@
/**
* Copyright 2023 by XGBoost contributors
*
* Higher level functions built on top of the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
* learning.
*/
#pragma once
#include <xgboost/data.h>
#include <limits>
#include <string>
#include <utility>
#include <vector>
#include "communicator-inl.cuh"
namespace xgboost {
namespace collective {
/**
* @brief Find the global sum of the given values across all workers.
*
* This only applies when the data is split row-wise (horizontally). When data is split
* column-wise (vertically), the original values are returned.
*
* @tparam T The type of the values.
* @param info MetaInfo about the DMatrix.
* @param device The device id.
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
if (info.IsRowSplit()) {
collective::AllReduce<collective::Operation::kSum>(device, values, size);
}
}
} // namespace collective
} // namespace xgboost
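For illustration, a minimal usage sketch (not part of this diff); `num_bins`, `device`, and the surrounding `info` are assumed context:

// Sum per-worker partial histograms; a no-op under column-wise splits.
std::vector<double> totals(num_bins, 0.0);  // each worker's partial sums
collective::GlobalSum(info, device, totals.data(), totals.size());
// Under row-wise splits, every worker now holds the global sums.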

View File

@@ -26,7 +26,6 @@ namespace collective {
* applied there, with the results broadcast to other workers.
*
* @tparam Function The function used to calculate the results.
* @tparam Args Arguments to the function.
* @param info MetaInfo about the DMatrix.
* @param buffer The buffer storing the results.
* @param size The size of the buffer.
@@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
}
}
/**
* @brief Apply the given function where the labels are.
*
* Normally all the workers have access to the labels, so the function is just applied locally. In
* vertical federated learning, we assume labels are only available on worker 0, so the function is
* applied there, with the results broadcast to other workers.
*
* @tparam T Type of the HostDeviceVector storing the results.
* @tparam Function The function used to calculate the results.
* @param info MetaInfo about the DMatrix.
* @param result The HostDeviceVector storing the results.
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
}
collective::Broadcast(&message, 0);
if (!message.empty()) {
LOG(FATAL) << &message[0];
return;
}
std::size_t size{};
if (collective::GetRank() == 0) {
size = result->Size();
}
collective::Broadcast(&size, sizeof(std::size_t), 0);
result->Resize(size);
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
} else {
std::forward<Function>(function)();
}
}
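A minimal sketch of how the new overload might be called; the label-dependent helper is hypothetical. Under vertical federated learning only rank 0 runs the function, and the helper resizes and fills the result on every other worker:

HostDeviceVector<float> out;
collective::ApplyWithLabels(info, &out, [&] {
  // Runs only on rank 0 in vertical federated mode; errors are surfaced
  // everywhere via the broadcast message above.
  ComputeFromLabels(info, &out);  // hypothetical label-dependent computation
});
// On all ranks: `out` now has rank 0's size and contents.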
/**
* @brief Find the global max of the given value across all workers.
*

View File

@@ -57,6 +57,20 @@ inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}
/**
* @brief Gather values from all processes.
*
* This assumes the send buffer has the same size on every rank.
*
* @param device ID of the device.
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the data sent by each rank, in bytes.
*/
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
Communicator::GetDevice(device)->AllGather(send_buffer, receive_buffer, send_size);
}
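A usage sketch, assuming each worker contributes one locally computed double:

double local = ComputeLocalValue();  // hypothetical per-worker value
std::vector<double> gathered(collective::GetWorldSize());
// Every rank must pass the same send_size; here, sizeof(double) bytes.
collective::AllGather(device, &local, gathered.data(), sizeof(double));
// gathered[r] now holds rank r's value on every worker.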
/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.

View File

@@ -41,7 +41,8 @@ void Communicator::Init(Json const& config) {
#endif
break;
}
case CommunicatorType::kInMemory: {
case CommunicatorType::kInMemory:
case CommunicatorType::kInMemoryNccl: {
communicator_.reset(InMemoryCommunicator::Create(config));
break;
}
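A sketch of selecting the new type through the JSON config; the key names here are assumptions based on how the in-memory communicator is configured in tests, not confirmed by this diff:

Json config{Object{}};
config["xgboost_communicator"] = String("in-memory-nccl");  // assumed key
config["in_memory_world_size"] = Integer(2);                // assumed test-only knobs
config["in_memory_rank"] = Integer(0);
Communicator::Init(config);  // dispatches to InMemoryCommunicator::Create above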

View File

@@ -29,13 +29,22 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_device_ordinal = device_ordinal;
old_world_size = communicator_->GetWorldSize();
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
switch (type_) {
case CommunicatorType::kRabit:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
break;
case CommunicatorType::kFederated:
case CommunicatorType::kInMemory:
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
break;
case CommunicatorType::kInMemoryNccl:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
break;
default:
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();

View File

@@ -69,7 +69,7 @@ enum class Operation {
class DeviceCommunicator;
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory };
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
/** \brief Case-insensitive string comparison. */
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
@@ -220,6 +220,8 @@ class Communicator {
result = CommunicatorType::kFederated;
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
result = CommunicatorType::kInMemory;
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
result = CommunicatorType::kInMemoryNccl;
} else {
LOG(FATAL) << "Unknown communicator type " << str;
}
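Because `CompareStringsCaseInsensitive` follows `strcasecmp` semantics (zero means equal, so the negation reads as equality), the parser accepts any casing of the new name. A small sketch:

#include <cassert>
// All of these select CommunicatorType::kInMemoryNccl:
assert(CompareStringsCaseInsensitive("In-Memory-NCCL", "in-memory-nccl") == 0);
assert(CompareStringsCaseInsensitive("IN-MEMORY-NCCL", "in-memory-nccl") == 0);
assert(CompareStringsCaseInsensitive("rabit", "in-memory-nccl") != 0);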

View File

@@ -27,6 +27,17 @@ class DeviceCommunicator {
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;
/**
* @brief Gather values from all processes.
*
* This assumes the send buffer has the same size on every rank.
*
* @param send_buffer Buffer storing the data to be sent.
* @param receive_buffer Buffer storing the gathered data.
* @param send_size Size of the sent data in bytes.
*/
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
/**
* @brief Gather variable-length values from all processes.
* @param send_buffer Buffer storing the input data.

View File

@@ -11,21 +11,18 @@ namespace collective {
class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
explicit DeviceCommunicatorAdapter(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
}
~DeviceCommunicatorAdapter() override = default;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
@@ -35,62 +32,82 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
dh::safe_cuda(hipSetDevice(device_ordinal_));
#endif
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
host_buffer_.resize(size);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
AllReduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
#endif
}
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
if (world_size_ == 1) {
return;
}
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
cudaMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
host_buffer_.resize(send_size * world_size_);
dh::safe_cuda(hipMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
hipMemcpyDefault));
Allgather(host_buffer_.data(), host_buffer_.size());
dh::safe_cuda(
hipMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), hipMemcpyDefault));
#endif
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
#else
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device_ordinal_));
#endif
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
host_buffer_.reserve(total_bytes);
host_buffer_.resize(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
hipMemcpyDefault));
#else
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
if (i == rank_) {
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
cudaMemcpyDefault));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
hipMemcpyDefault));
#endif
}
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
hipMemcpyDefault));
#else
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
cudaMemcpyDefault));
#endif
@@ -102,7 +119,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
private:
int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};
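The length exchange at the top of `AllGatherV` is worth spelling out: each rank zero-fills the segment vector except for its own slot, so an `Operation::kMax` allreduce reconstructs every rank's length (each slot holds either the true length or zero). A three-rank sketch:

// rank 0 contributes {24, 0, 0}, rank 1 {0, 16, 0}, rank 2 {0, 0, 32};
// element-wise max gives {24, 16, 32} on every rank.
std::vector<std::size_t> segments(world_size_, 0);
segments[rank_] = length_bytes;  // only my slot is non-zero
Allreduce(segments.data(), segments.size(), DataType::kUInt64, Operation::kMax);
// Rank i's payload is then broadcast into the slice starting at
// offset = segments[0] + ... + segments[i-1].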

View File

@@ -0,0 +1,229 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
: device_ordinal_{device_ordinal},
needs_sync_{needs_sync},
world_size_{GetWorldSize()},
rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (world_size_ == 1) {
return;
}
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
}
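The `TODO` comment hints at the trick in use here, the same one the segment-length exchange relies on: each rank's UUID words are summed against zeros contributed by everyone else, so a `kSum` allreduce acts as an allgather. A two-rank sketch with a one-word UUID:

// rank 0 contributes {uuid0, 0}; rank 1 contributes {0, uuid1};
// element-wise sum yields {uuid0, uuid1} on both ranks.
std::vector<uint64_t> uuids(world_size * kUuidLength, 0);
uuids[rank * kUuidLength] = my_uuid_word;  // hypothetical single-word UUID
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);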
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (world_size_ == 1) {
return;
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result{ncclInt8};
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
bool IsBitwiseOp(Operation const &op) {
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
op == Operation::kBitwiseXOR;
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result{ncclMax};
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size) {
dh::LaunchN(size, [=] __device__(std::size_t idx) {
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
} // anonymous namespace
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, dh::DefaultStream()));
if (needs_sync_) {
dh::DefaultStream().Sync();
}
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
}
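NCCL has no built-in bitwise reductions, so the implementation gathers every rank's buffer and folds them locally. Per byte, the device kernel above performs the equivalent of this host-side loop (shown for `kBitwiseAND`):

// device_buffer layout after ncclAllGather: [rank0 bytes | rank1 bytes | ...]
for (std::size_t idx = 0; idx < size; ++idx) {
  char result = device_buffer[idx];
  for (int rank = 1; rank < world_size; ++rank) {
    result &= device_buffer[rank * size + idx];  // thrust::bit_and<char>
  }
  out_buffer[idx] = result;
}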
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
dh::DefaultStream()));
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
std::size_t send_size) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
dh::DefaultStream()));
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
segments->clear();
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, dh::DefaultStream()));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
}
void NcclDeviceCommunicator::Synchronize() {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::DefaultStream().Sync();
}
} // namespace collective
} // namespace xgboost
#endif

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2022 XGBoost contributors
* Copyright 2022-2023 XGBoost contributors
*/
#pragma once
@@ -12,136 +12,27 @@ namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamCreate(&cuda_stream_));
#else
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
#endif
}
~NcclDeviceCommunicator() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
if (cuda_stream_) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamDestroy(cuda_stream_));
#else
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
#endif
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
/**
* @brief Construct a new NCCL communicator.
* @param device_ordinal The GPU device id.
* @param needs_sync Whether extra CUDA stream synchronization is needed.
*
* In multi-GPU tests, when multiple NCCL communicators are created in the same process, a
* deadlock can occur because NCCL kernels are blocking. The extra CUDA stream synchronization
* makes sure that the NCCL kernels have caught up, thus avoiding the deadlock.
*
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
* needed. The in-memory communicator is used in tests with multiple threads, each thread
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
*/
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
cuda_stream_));
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
Operation op) override;
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
return;
}
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
#else
dh::safe_cuda(cudaSetDevice(device_ordinal_));
#endif
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
}
void Synchronize() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
dh::safe_cuda(hipStreamSynchronize(cuda_stream_));
#else
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
#endif
}
dh::caching_device_vector<char> *receive_buffer) override;
void Synchronize() override;
private:
static constexpr std::size_t kUuidLength =
@@ -182,79 +73,21 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (communicator_->GetRank() == kRootRank) {
if (rank_ == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
static_cast<int>(kRootRank));
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result;
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
LOG(FATAL) << "Not implemented yet.";
default:
LOG(FATAL) << "Unknown reduce operation.";
}
return result;
}
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op);
int const device_ordinal_;
Communicator *communicator_;
bool const needs_sync_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
#if defined(XGBOOST_USE_HIP)
hipStream_t cuda_stream_{};
#else
cudaStream_t cuda_stream_{};
#endif
ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.

View File

@@ -1,19 +1,22 @@
/*!
* Copyright (c) 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#include "xgboost/collective/socket.h"
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <cstring> // std::memcpy, std::memset
#include <filesystem> // for path
#include <system_error> // std::error_code, std::system_category
#include "rabit/internal/socket.h" // for PollHelper
#include "xgboost/collective/result.h" // for Result
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // getaddrinfo, freeaddrinfo
#endif // defined(__unix__) || defined(__APPLE__)
namespace xgboost {
namespace collective {
namespace xgboost::collective {
SockAddress MakeSockAddress(StringView host, in_port_t port) {
struct addrinfo hints;
std::memset(&hints, 0, sizeof(hints));
@@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
return bytes;
}
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
std::chrono::seconds timeout,
xgboost::collective::TCPSocket *out_conn) {
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
auto &conn = *out_conn;
sockaddr const *addr_handle{nullptr};
socklen_t addr_len{0};
if (addr.IsV4()) {
@@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
addr_len = sizeof(addr.V6().Handle());
}
auto socket = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
auto rc = connect(socket.Handle(), addr_handle, addr_len);
if (rc != 0) {
return std::error_code{errno, std::system_category()};
conn = TCPSocket::Create(addr.Domain());
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
conn.SetNonBlock(true);
Result last_error;
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
last_error = std::move(err);
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
};
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
if (attempt > 0) {
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
#if defined(_MSC_VER) || defined(__MINGW32__)
Sleep(attempt << 1);
#else
sleep(attempt << 1);
#endif
}
auto rc = connect(conn.Handle(), addr_handle, addr_len);
if (rc != 0) {
auto errcode = system::LastError();
if (!system::ErrorWouldBlock(errcode)) {
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
rabit::utils::PollHelper poll;
poll.WatchWrite(conn);
auto result = poll.Poll(timeout);
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
if (!poll.CheckWrite(conn)) {
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
__FILE__, __LINE__);
continue;
}
result = conn.GetSockError();
if (!result.OK()) {
log_failure(std::move(result), __FILE__, __LINE__);
continue;
}
conn.SetNonBlock(false);
return Success();
} else {
conn.SetNonBlock(false);
return Success();
}
}
*out = std::move(socket);
return std::make_error_code(std::errc{});
std::stringstream ss;
ss << "Failed to connect to " << host << ":" << port;
conn.Close();
return Fail(ss.str(), std::move(last_error));
}
} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective
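A usage sketch for the new retrying `Connect`; the host, port, and retry budget below are placeholders:

xgboost::collective::TCPSocket conn;
auto rc = xgboost::collective::Connect(xgboost::StringView{"127.0.0.1"}, 9091,
                                       /*retry=*/3, std::chrono::seconds{10}, &conn);
if (!rc.OK()) {
  LOG(FATAL) << rc.Report();  // chained errors accumulated by log_failure
}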