temp merge, disable 1 line, SetValid
This commit is contained in:
40
src/collective/aggregator.cuh
Normal file
40
src/collective/aggregator.cuh
Normal file
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2023 by XGBoost contributors
|
||||
*
|
||||
* Higher level functions built on top of the Communicator API, taking care of behavioral differences
|
||||
* between row-split vs column-split distributed training, and horizontal vs vertical federated
|
||||
* learning.
|
||||
*/
|
||||
#pragma once
|
||||
#include <xgboost/data.h>
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "communicator-inl.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
/**
|
||||
* @brief Find the global sum of the given values across all workers.
|
||||
*
|
||||
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||
* column-wise (vertically), the original values are returned.
|
||||
*
|
||||
* @tparam T The type of the values.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param device The device id.
|
||||
* @param values Pointer to the inputs to sum.
|
||||
* @param size Number of values to sum.
|
||||
*/
|
||||
template <typename T>
|
||||
void GlobalSum(MetaInfo const& info, int device, T* values, size_t size) {
|
||||
if (info.IsRowSplit()) {
|
||||
collective::AllReduce<collective::Operation::kSum>(device, values, size);
|
||||
}
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
@@ -26,7 +26,6 @@ namespace collective {
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @tparam Args Arguments to the function.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param buffer The buffer storing the results.
|
||||
* @param size The size of the buffer.
|
||||
@@ -57,6 +56,52 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Apply the given function where the labels are.
|
||||
*
|
||||
* Normally all the workers have access to the labels, so the function is just applied locally. In
|
||||
* vertical federated learning, we assume labels are only available on worker 0, so the function is
|
||||
* applied there, with the results broadcast to other workers.
|
||||
*
|
||||
* @tparam T Type of the HostDeviceVector storing the results.
|
||||
* @tparam Function The function used to calculate the results.
|
||||
* @param info MetaInfo about the DMatrix.
|
||||
* @param result The HostDeviceVector storing the results.
|
||||
* @param function The function used to calculate the results.
|
||||
*/
|
||||
template <typename T, typename Function>
|
||||
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
|
||||
if (info.IsVerticalFederated()) {
|
||||
// We assume labels are only available on worker 0, so the calculation is done there and result
|
||||
// broadcast to other workers.
|
||||
std::string message;
|
||||
if (collective::GetRank() == 0) {
|
||||
try {
|
||||
std::forward<Function>(function)();
|
||||
} catch (dmlc::Error& e) {
|
||||
message = e.what();
|
||||
}
|
||||
}
|
||||
|
||||
collective::Broadcast(&message, 0);
|
||||
if (!message.empty()) {
|
||||
LOG(FATAL) << &message[0];
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t size{};
|
||||
if (collective::GetRank() == 0) {
|
||||
size = result->Size();
|
||||
}
|
||||
collective::Broadcast(&size, sizeof(std::size_t), 0);
|
||||
|
||||
result->Resize(size);
|
||||
collective::Broadcast(result->HostPointer(), size * sizeof(T), 0);
|
||||
} else {
|
||||
std::forward<Function>(function)();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Find the global max of the given value across all workers.
|
||||
*
|
||||
|
||||
@@ -57,6 +57,20 @@ inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
|
||||
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Gather values from all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
inline void AllGather(int device, void const *send_buffer, void *receive_buffer,
                      std::size_t send_size) {
  // Look up the device communicator for `device` and forward the call to it unchanged.
  auto *device_comm = Communicator::GetDevice(device);
  device_comm->AllGather(send_buffer, receive_buffer, send_size);
}
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param device ID of the device.
|
||||
|
||||
@@ -41,7 +41,8 @@ void Communicator::Init(Json const& config) {
|
||||
#endif
|
||||
break;
|
||||
}
|
||||
case CommunicatorType::kInMemory: {
|
||||
case CommunicatorType::kInMemory:
|
||||
case CommunicatorType::kInMemoryNccl: {
|
||||
communicator_.reset(InMemoryCommunicator::Create(config));
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -29,13 +29,22 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
old_device_ordinal = device_ordinal;
|
||||
old_world_size = communicator_->GetWorldSize();
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
if (type_ != CommunicatorType::kFederated) {
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
||||
} else {
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
switch (type_) {
|
||||
case CommunicatorType::kRabit:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
break;
|
||||
case CommunicatorType::kFederated:
|
||||
case CommunicatorType::kInMemory:
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
break;
|
||||
case CommunicatorType::kInMemoryNccl:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, true));
|
||||
break;
|
||||
default:
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, false));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
#endif
|
||||
}
|
||||
return device_communicator_.get();
|
||||
|
||||
@@ -69,7 +69,7 @@ enum class Operation {
|
||||
|
||||
class DeviceCommunicator;
|
||||
|
||||
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory };
|
||||
enum class CommunicatorType { kUnknown, kRabit, kFederated, kInMemory, kInMemoryNccl };
|
||||
|
||||
/** \brief Case-insensitive string comparison. */
|
||||
inline int CompareStringsCaseInsensitive(const char *s1, const char *s2) {
|
||||
@@ -220,6 +220,8 @@ class Communicator {
|
||||
result = CommunicatorType::kFederated;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory", str)) {
|
||||
result = CommunicatorType::kInMemory;
|
||||
} else if (!CompareStringsCaseInsensitive("in-memory-nccl", str)) {
|
||||
result = CommunicatorType::kInMemoryNccl;
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown communicator type " << str;
|
||||
}
|
||||
|
||||
@@ -27,6 +27,17 @@ class DeviceCommunicator {
|
||||
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather values from all processes.
|
||||
*
|
||||
* This assumes all ranks have the same size.
|
||||
*
|
||||
* @param send_buffer Buffer storing the data to be sent.
|
||||
* @param receive_buffer Buffer storing the gathered data.
|
||||
* @param send_size Size of the sent data in bytes.
|
||||
*/
|
||||
virtual void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) = 0;
|
||||
|
||||
/**
|
||||
* @brief Gather variable-length values from all processes.
|
||||
* @param send_buffer Buffer storing the input data.
|
||||
|
||||
@@ -11,21 +11,18 @@ namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
explicit DeviceCommunicatorAdapter(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -35,62 +32,82 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#endif
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.reserve(size);
|
||||
host_buffer_.resize(size);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data(), send_receive_buffer, size, hipMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(hipMemcpy(send_receive_buffer, host_buffer_.data(), size, hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
cudaMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
cudaMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
host_buffer_.resize(send_size * world_size_);
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data() + rank_ * send_size, send_buffer, send_size,
|
||||
hipMemcpyDefault));
|
||||
Allgather(host_buffer_.data(), host_buffer_.size());
|
||||
dh::safe_cuda(
|
||||
hipMemcpy(receive_buffer, host_buffer_.data(), host_buffer_.size(), hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||
Operation::kMax);
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.reserve(total_bytes);
|
||||
host_buffer_.resize(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank) {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||
hipMemcpyDefault));
|
||||
#else
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||
if (i == rank_) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
hipMemcpyDefault));
|
||||
#endif
|
||||
}
|
||||
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
hipMemcpyDefault));
|
||||
#else
|
||||
#elif defined(XGBOOST_USE_CUDA)
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
cudaMemcpyDefault));
|
||||
#endif
|
||||
@@ -102,7 +119,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
229
src/collective/nccl_device_communicator.cu
Normal file
229
src/collective/nccl_device_communicator.cu
Normal file
@@ -0,0 +1,229 @@
|
||||
/*!
|
||||
* Copyright 2023 XGBoost contributors
|
||||
*/
|
||||
#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
|
||||
#include "nccl_device_communicator.cuh"
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, bool needs_sync)
|
||||
: device_ordinal_{device_ordinal},
|
||||
needs_sync_{needs_sync},
|
||||
world_size_{GetWorldSize()},
|
||||
rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world_size_)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
  // With a single worker there is nothing to tear down or report.
  if (world_size_ != 1) {
    if (nccl_comm_) {
      dh::safe_nccl(ncclCommDestroy(nccl_comm_));
    }
    // Print the collected AllReduce statistics only at debug verbosity.
    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }
}
|
||||
|
||||
namespace {
|
||||
// Translate a collective DataType into the matching NCCL data type.
ncclDataType_t GetNcclDataType(DataType const &data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return ncclInt8;
    case DataType::kUInt8:
      return ncclUint8;
    case DataType::kInt32:
      return ncclInt32;
    case DataType::kUInt32:
      return ncclUint32;
    case DataType::kInt64:
      return ncclInt64;
    case DataType::kUInt64:
      return ncclUint64;
    case DataType::kFloat:
      return ncclFloat;
    case DataType::kDouble:
      return ncclDouble;
    default:
      LOG(FATAL) << "Unknown data type.";
  }
  // Fallback value for the unknown-type path, after the fatal log above.
  return ncclInt8;
}
|
||||
|
||||
// True when `op` is one of the bitwise AND / OR / XOR reductions.
bool IsBitwiseOp(Operation const &op) {
  switch (op) {
    case Operation::kBitwiseAND:
    case Operation::kBitwiseOR:
    case Operation::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
// Translate a max/min/sum Operation into the matching NCCL reduction op; anything else is fatal.
ncclRedOp_t GetNcclRedOp(Operation const &op) {
  switch (op) {
    case Operation::kMax:
      return ncclMax;
    case Operation::kMin:
      return ncclMin;
    case Operation::kSum:
      return ncclSum;
    default:
      LOG(FATAL) << "Unsupported reduce operation.";
  }
  // Fallback value for the unsupported-op path, after the fatal log above.
  return ncclMax;
}
|
||||
|
||||
// Fold `world_size` chunks of `size` bytes each, stored back to back in `device_buffer`, into
// `out_buffer` byte by byte using `func`. One LaunchN index handles one output byte.
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
                         std::size_t size) {
  dh::LaunchN(size, [=] __device__(std::size_t idx) {
    // Start from the first chunk's byte, then fold in the same byte of every other chunk.
    char folded = device_buffer[idx];
    for (int chunk = 1; chunk < world_size; ++chunk) {
      folded = func(folded, device_buffer[chunk * size + idx]);
    }
    out_buffer[idx] = folded;
  });
}
|
||||
} // anonymous namespace
|
||||
|
||||
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
auto const size = count * GetTypeSize(data_type);
|
||||
dh::caching_device_vector<char> buffer(size * world_size_);
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
|
||||
nccl_comm_, dh::DefaultStream()));
|
||||
if (needs_sync_) {
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
// Then reduce locally.
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Not a bitwise reduce operation.";
|
||||
}
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
if (IsBitwiseOp(op)) {
|
||||
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
|
||||
} else {
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
}
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGather(void const *send_buffer, void *receive_buffer,
|
||||
std::size_t send_size) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllGather(send_buffer, receive_buffer, send_size, ncclInt8, nccl_comm_,
|
||||
dh::DefaultStream()));
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, dh::DefaultStream()));
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::DefaultStream().Sync();
|
||||
}
|
||||
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
#endif
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2022 XGBoost contributors
|
||||
* Copyright 2022-2023 XGBoost contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
@@ -12,136 +12,27 @@ namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
|
||||
int32_t const rank = communicator_->GetRank();
|
||||
int32_t const world = communicator_->GetWorldSize();
|
||||
|
||||
if (world == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
j++;
|
||||
}
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipStreamCreate(&cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
|
||||
~NcclDeviceCommunicator() override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipStreamDestroy(cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
}
|
||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
||||
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Construct a new NCCL communicator.
|
||||
* @param device_ordinal The GPU device id.
|
||||
* @param needs_sync Whether extra CUDA stream synchronization is needed.
|
||||
*
|
||||
* In multi-GPU tests when multiple NCCL communicators are created in the same process, sometimes
|
||||
* a deadlock happens because NCCL kernels are blocking. The extra CUDA stream synchronization
|
||||
* makes sure that the NCCL kernels are caught up, thus avoiding the deadlock.
|
||||
*
|
||||
* The Rabit communicator runs with one process per GPU, so the additional synchronization is not
|
||||
* needed. The in-memory communicator is used in tests with multiple threads, each thread
|
||||
* representing a rank/worker, so the additional synchronization is needed to avoid deadlocks.
|
||||
*/
|
||||
explicit NcclDeviceCommunicator(int device_ordinal, bool needs_sync);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
||||
cuda_stream_));
|
||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
||||
allreduce_calls_ += 1;
|
||||
}
|
||||
|
||||
Operation op) override;
|
||||
void AllGather(void const *send_buffer, void *receive_buffer, std::size_t send_size) override;
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||
Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, cuda_stream_));
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_nccl(ncclGroupEnd());
|
||||
}
|
||||
|
||||
void Synchronize() override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(hipStreamSynchronize(cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
dh::caching_device_vector<char> *receive_buffer) override;
|
||||
void Synchronize() override;
|
||||
|
||||
private:
|
||||
static constexpr std::size_t kUuidLength =
|
||||
@@ -182,79 +73,21 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (communicator_->GetRank() == kRootRank) {
|
||||
if (rank_ == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
|
||||
static_cast<int>(kRootRank));
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
|
||||
ncclDataType_t result;
|
||||
switch (data_type) {
|
||||
case DataType::kInt8:
|
||||
result = ncclInt8;
|
||||
break;
|
||||
case DataType::kUInt8:
|
||||
result = ncclUint8;
|
||||
break;
|
||||
case DataType::kInt32:
|
||||
result = ncclInt32;
|
||||
break;
|
||||
case DataType::kUInt32:
|
||||
result = ncclUint32;
|
||||
break;
|
||||
case DataType::kInt64:
|
||||
result = ncclInt64;
|
||||
break;
|
||||
case DataType::kUInt64:
|
||||
result = ncclUint64;
|
||||
break;
|
||||
case DataType::kFloat:
|
||||
result = ncclFloat;
|
||||
break;
|
||||
case DataType::kDouble:
|
||||
result = ncclDouble;
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown data type.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
||||
ncclRedOp_t result;
|
||||
switch (op) {
|
||||
case Operation::kMax:
|
||||
result = ncclMax;
|
||||
break;
|
||||
case Operation::kMin:
|
||||
result = ncclMin;
|
||||
break;
|
||||
case Operation::kSum:
|
||||
result = ncclSum;
|
||||
break;
|
||||
case Operation::kBitwiseAND:
|
||||
case Operation::kBitwiseOR:
|
||||
case Operation::kBitwiseXOR:
|
||||
LOG(FATAL) << "Not implemented yet.";
|
||||
default:
|
||||
LOG(FATAL) << "Unknown reduce operation.";
|
||||
}
|
||||
return result;
|
||||
}
|
||||
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
bool const needs_sync_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
hipStream_t cuda_stream_{};
|
||||
#else
|
||||
cudaStream_t cuda_stream_{};
|
||||
#endif
|
||||
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by XGBoost Contributors
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost Contributors
|
||||
*/
|
||||
#include "xgboost/collective/socket.h"
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstring> // std::memcpy, std::memset
|
||||
#include <filesystem> // for path
|
||||
#include <system_error> // std::error_code, std::system_category
|
||||
|
||||
#include "rabit/internal/socket.h" // for PollHelper
|
||||
#include "xgboost/collective/result.h" // for Result
|
||||
|
||||
#if defined(__unix__) || defined(__APPLE__)
|
||||
#include <netdb.h> // getaddrinfo, freeaddrinfo
|
||||
#endif // defined(__unix__) || defined(__APPLE__)
|
||||
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
namespace xgboost::collective {
|
||||
SockAddress MakeSockAddress(StringView host, in_port_t port) {
|
||||
struct addrinfo hints;
|
||||
std::memset(&hints, 0, sizeof(hints));
|
||||
@@ -71,7 +74,12 @@ std::size_t TCPSocket::Recv(std::string *p_str) {
|
||||
return bytes;
|
||||
}
|
||||
|
||||
std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
||||
[[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
|
||||
std::chrono::seconds timeout,
|
||||
xgboost::collective::TCPSocket *out_conn) {
|
||||
auto addr = MakeSockAddress(xgboost::StringView{host}, port);
|
||||
auto &conn = *out_conn;
|
||||
|
||||
sockaddr const *addr_handle{nullptr};
|
||||
socklen_t addr_len{0};
|
||||
if (addr.IsV4()) {
|
||||
@@ -81,14 +89,67 @@ std::error_code Connect(SockAddress const &addr, TCPSocket *out) {
|
||||
addr_handle = reinterpret_cast<const sockaddr *>(&addr.V6().Handle());
|
||||
addr_len = sizeof(addr.V6().Handle());
|
||||
}
|
||||
auto socket = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(socket.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
auto rc = connect(socket.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
return std::error_code{errno, std::system_category()};
|
||||
|
||||
conn = TCPSocket::Create(addr.Domain());
|
||||
CHECK_EQ(static_cast<std::int32_t>(conn.Domain()), static_cast<std::int32_t>(addr.Domain()));
|
||||
conn.SetNonBlock(true);
|
||||
|
||||
Result last_error;
|
||||
auto log_failure = [&host, &last_error](Result err, char const *file, std::int32_t line) {
|
||||
last_error = std::move(err);
|
||||
LOG(WARNING) << std::filesystem::path{file}.filename().string() << "(" << line
|
||||
<< "): Failed to connect to:" << host << " Error:" << last_error.Report();
|
||||
};
|
||||
|
||||
for (std::int32_t attempt = 0; attempt < std::max(retry, 1); ++attempt) {
|
||||
if (attempt > 0) {
|
||||
LOG(WARNING) << "Retrying connection to " << host << " for the " << attempt << " time.";
|
||||
#if defined(_MSC_VER) || defined(__MINGW32__)
|
||||
Sleep(attempt << 1);
|
||||
#else
|
||||
sleep(attempt << 1);
|
||||
#endif
|
||||
}
|
||||
|
||||
auto rc = connect(conn.Handle(), addr_handle, addr_len);
|
||||
if (rc != 0) {
|
||||
auto errcode = system::LastError();
|
||||
if (!system::ErrorWouldBlock(errcode)) {
|
||||
log_failure(Fail("connect failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
rabit::utils::PollHelper poll;
|
||||
poll.WatchWrite(conn);
|
||||
auto result = poll.Poll(timeout);
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
if (!poll.CheckWrite(conn)) {
|
||||
log_failure(Fail("poll failed.", std::error_code{errcode, std::system_category()}),
|
||||
__FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
result = conn.GetSockError();
|
||||
if (!result.OK()) {
|
||||
log_failure(std::move(result), __FILE__, __LINE__);
|
||||
continue;
|
||||
}
|
||||
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
|
||||
} else {
|
||||
conn.SetNonBlock(false);
|
||||
return Success();
|
||||
}
|
||||
}
|
||||
*out = std::move(socket);
|
||||
return std::make_error_code(std::errc{});
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "Failed to connect to " << host << ":" << port;
|
||||
conn.Close();
|
||||
return Fail(ss.str(), std::move(last_error));
|
||||
}
|
||||
} // namespace collective
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::collective
|
||||
|
||||
Reference in New Issue
Block a user