Support bitwise allreduce in NCCL communicator (#9300)

This commit is contained in:
Rong Ou 2023-06-16 10:56:50 -07:00 committed by GitHub
parent 2718ff530c
commit d8beb517ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 302 additions and 164 deletions

View File

@ -0,0 +1,228 @@
/*!
* Copyright 2023 XGBoost contributors
*/
#if defined(XGBOOST_USE_NCCL)
#include "nccl_device_communicator.cuh"
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (communicator_->GetWorldSize() == 1) {
return;
}
if (cuda_stream_) {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
namespace {
ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result;
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
bool IsBitwiseOp(Operation const &op) {
return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
op == Operation::kBitwiseXOR;
}
ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
default:
LOG(FATAL) << "Unsupported reduce operation.";
}
return result;
}
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size, cudaStream_t stream) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
out_buffer[idx] = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]);
}
});
}
} // anonymous namespace
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const world_size = communicator_->GetWorldSize();
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
nccl_comm_, cuda_stream_));
// Then reduce locally.
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
cuda_stream_);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
cuda_stream_);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
cuda_stream_);
break;
default:
LOG(FATAL) << "Not a bitwise reduce operation.";
}
}
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
if (IsBitwiseOp(op)) {
BitwiseAllReduce(send_receive_buffer, count, data_type, op);
} else {
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
cuda_stream_));
}
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
}
void NcclDeviceCommunicator::Synchronize() {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
} // namespace collective
} // namespace xgboost
#endif

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright 2022 XGBoost contributors * Copyright 2022-2023 XGBoost contributors
*/ */
#pragma once #pragma once
@ -12,116 +12,13 @@ namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator { class NcclDeviceCommunicator : public DeviceCommunicator {
public: public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator) NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
: device_ordinal_{device_ordinal}, communicator_{communicator} { ~NcclDeviceCommunicator() override;
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
j++;
}
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
~NcclDeviceCommunicator() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
if (cuda_stream_) {
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
}
if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
}
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
LOG(CONSOLE) << "======== NCCL Statistics========";
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
}
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override { Operation op) override;
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
cuda_stream_));
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments, void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override { dh::caching_device_vector<char> *receive_buffer) override;
if (communicator_->GetWorldSize() == 1) { void Synchronize() override;
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
offset += as_bytes;
}
dh::safe_nccl(ncclGroupEnd());
}
void Synchronize() override {
if (communicator_->GetWorldSize() == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
private: private:
static constexpr std::size_t kUuidLength = static constexpr std::size_t kUuidLength =
@ -160,60 +57,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id; return id;
} }
static ncclDataType_t GetNcclDataType(DataType const &data_type) { void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
ncclDataType_t result; Operation op);
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
LOG(FATAL) << "Not implemented yet.";
default:
LOG(FATAL) << "Unknown reduce operation.";
}
return result;
}
int const device_ordinal_; int const device_ordinal_;
Communicator *communicator_; Communicator *communicator_;

View File

@ -5,10 +5,12 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <bitset>
#include <string> // for string #include <string> // for string
#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../../../src/collective/communicator-inl.cuh" #include "../../../src/collective/communicator-inl.cuh"
#include "../../../src/collective/nccl_device_communicator.cuh"
#include "../helpers.h"
namespace xgboost { namespace xgboost {
namespace collective { namespace collective {
@ -31,6 +33,69 @@ TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
ASSERT_TRUE(str.find("environment variables") != std::string::npos); ASSERT_TRUE(str.find("environment variables") != std::string::npos);
} }
} }
namespace {
void VerifyAllReduceBitwiseAND() {
auto const rank = collective::GetRank();
std::bitset<64> original{};
original[rank] = true;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
collective::AllReduce<collective::Operation::kBitwiseAND>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], 0ULL);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus;
}
RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseAND);
}
namespace {
void VerifyAllReduceBitwiseOR() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::bitset<64> original{};
original[rank] = true;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
collective::AllReduce<collective::Operation::kBitwiseOR>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR test with # GPUs = " << n_gpus;
}
RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseOR);
}
namespace {
void VerifyAllReduceBitwiseXOR() {
auto const world_size = collective::GetWorldSize();
auto const rank = collective::GetRank();
std::bitset<64> original{~0ULL};
original[rank] = false;
HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
collective::AllReduce<collective::Operation::kBitwiseXOR>(rank, buffer.DevicePointer(), 1);
collective::Synchronize(rank);
EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
}
} // anonymous namespace
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) {
auto const n_gpus = common::AllVisibleGPUs();
if (n_gpus <= 1) {
GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus;
}
RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseXOR);
}
} // namespace collective } // namespace collective
} // namespace xgboost } // namespace xgboost