diff --git a/src/collective/nccl_device_communicator.cu b/src/collective/nccl_device_communicator.cu
new file mode 100644
index 000000000..6599d4b5a
--- /dev/null
+++ b/src/collective/nccl_device_communicator.cu
@@ -0,0 +1,228 @@
+/*!
+ * Copyright 2023 XGBoost contributors
+ */
+#if defined(XGBOOST_USE_NCCL)
+#include "nccl_device_communicator.cuh"
+
+namespace xgboost {
+namespace collective {
+
+NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
+    : device_ordinal_{device_ordinal}, communicator_{communicator} {
+  if (device_ordinal_ < 0) {
+    LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
+  }
+  if (communicator_ == nullptr) {
+    LOG(FATAL) << "Communicator cannot be null.";
+  }
+
+  int32_t const rank = communicator_->GetRank();
+  int32_t const world = communicator_->GetWorldSize();
+
+  if (world == 1) {
+    return;
+  }
+
+  std::vector<uint64_t> uuids(world * kUuidLength, 0);
+  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
+  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
+  GetCudaUUID(s_this_uuid);
+
+  // TODO(rongou): replace this with allgather.
+  communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
+
+  std::vector<xgboost::common::Span<uint64_t>> converted(world);
+  size_t j = 0;
+  for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
+    converted[j] = xgboost::common::Span<uint64_t>{uuids.data() + i, kUuidLength};
+    j++;
+  }
+
+  auto iter = std::unique(converted.begin(), converted.end());
+  auto n_uniques = std::distance(converted.begin(), iter);
+
+  CHECK_EQ(n_uniques, world)
+      << "Multiple processes within communication group running on same CUDA "
+      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
+
+  nccl_unique_id_ = GetUniqueId();
+  dh::safe_cuda(cudaSetDevice(device_ordinal_));
+  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
+  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
+}
+
+NcclDeviceCommunicator::~NcclDeviceCommunicator() {
+  if (communicator_->GetWorldSize() == 1) {
+    return;
+  }
+  if (cuda_stream_) {
+    dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
+  }
+  if (nccl_comm_) {
+    dh::safe_nccl(ncclCommDestroy(nccl_comm_));
+  }
+  if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
+    LOG(CONSOLE) << "======== NCCL Statistics========";
+    LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
+    LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
+  }
+}
+
+namespace {
+ncclDataType_t GetNcclDataType(DataType const &data_type) {
+  ncclDataType_t result;
+  switch (data_type) {
+    case DataType::kInt8:
+      result = ncclInt8;
+      break;
+    case DataType::kUInt8:
+      result = ncclUint8;
+      break;
+    case DataType::kInt32:
+      result = ncclInt32;
+      break;
+    case DataType::kUInt32:
+      result = ncclUint32;
+      break;
+    case DataType::kInt64:
+      result = ncclInt64;
+      break;
+    case DataType::kUInt64:
+      result = ncclUint64;
+      break;
+    case DataType::kFloat:
+      result = ncclFloat;
+      break;
+    case DataType::kDouble:
+      result = ncclDouble;
+      break;
+    default:
+      LOG(FATAL) << "Unknown data type.";
+  }
+  return result;
+}
+
+bool IsBitwiseOp(Operation const &op) {
+  return op == Operation::kBitwiseAND || op == Operation::kBitwiseOR ||
+         op == Operation::kBitwiseXOR;
+}
+
+ncclRedOp_t GetNcclRedOp(Operation const &op) {
+  ncclRedOp_t result;
+  switch (op) {
+    case Operation::kMax:
+      result = ncclMax;
+      break;
+    case Operation::kMin:
+      result = ncclMin;
+      break;
+    case Operation::kSum:
+      result = ncclSum;
+      break;
+    default:
+      LOG(FATAL) << "Unsupported reduce operation.";
+  }
+  return result;
+}
+
+template <typename Func>
+void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
+                         std::size_t size, cudaStream_t stream) {
+  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
+    out_buffer[idx] = device_buffer[idx];
+    for (auto rank = 1; rank < world_size; rank++) {
+      out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]);
+    }
+  });
+}
+}  // anonymous namespace
+
+void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
+                                              DataType data_type, Operation op) {
+  auto const world_size = communicator_->GetWorldSize();
+  auto const size = count * GetTypeSize(data_type);
+  dh::caching_device_vector<char> buffer(size * world_size);
+  auto *device_buffer = buffer.data().get();
+
+  // First gather data from all the workers.
+  dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count,
+                              GetNcclDataType(data_type), nccl_comm_, cuda_stream_));
+
+  // Then reduce locally.
+  auto *out_buffer = static_cast<char *>(send_receive_buffer);
+  switch (op) {
+    case Operation::kBitwiseAND:
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
+                          cuda_stream_);
+      break;
+    case Operation::kBitwiseOR:
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
+                          cuda_stream_);
+      break;
+    case Operation::kBitwiseXOR:
+      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
+                          cuda_stream_);
+      break;
+    default:
+      LOG(FATAL) << "Not a bitwise reduce operation.";
+  }
+}
+
+void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
+                                       DataType data_type, Operation op) {
+  if (communicator_->GetWorldSize() == 1) {
+    return;
+  }
+
+  dh::safe_cuda(cudaSetDevice(device_ordinal_));
+  if (IsBitwiseOp(op)) {
+    BitwiseAllReduce(send_receive_buffer, count, data_type, op);
+  } else {
+    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
+                                GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
+                                cuda_stream_));
+  }
+  allreduce_bytes_ += count * GetTypeSize(data_type);
+  allreduce_calls_ += 1;
+}
+
+void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
+                                        std::vector<std::size_t> *segments,
+                                        dh::caching_device_vector<char> *receive_buffer) {
+  if (communicator_->GetWorldSize() == 1) {
+    return;
+  }
+
+  dh::safe_cuda(cudaSetDevice(device_ordinal_));
+  int const world_size = communicator_->GetWorldSize();
+  int const rank = communicator_->GetRank();
+
+  segments->clear();
+  segments->resize(world_size, 0);
+  segments->at(rank) = length_bytes;
+  communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
+  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
+  receive_buffer->resize(total_bytes);
+
+  size_t offset = 0;
+  dh::safe_nccl(ncclGroupStart());
+  for (int32_t i = 0; i < world_size; ++i) {
+    size_t as_bytes = segments->at(i);
+    dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
+                                ncclChar, i, nccl_comm_, cuda_stream_));
+    offset += as_bytes;
+  }
+  dh::safe_nccl(ncclGroupEnd());
+}
+
+void NcclDeviceCommunicator::Synchronize() {
+  if (communicator_->GetWorldSize() == 1) {
+    return;
+  }
+  dh::safe_cuda(cudaSetDevice(device_ordinal_));
+  dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
+}
+
+}  // namespace collective
+}  // namespace xgboost
+#endif
diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh
index 4e58fc5ba..e5f76119d 100644
--- a/src/collective/nccl_device_communicator.cuh
+++ b/src/collective/nccl_device_communicator.cuh
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2022 XGBoost contributors
+ * Copyright 2022-2023 XGBoost contributors
  */
 
 #pragma once
@@ -12,116 +12,13 @@ namespace collective {
 
 class NcclDeviceCommunicator : public DeviceCommunicator {
  public:
-  NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
-      : device_ordinal_{device_ordinal}, communicator_{communicator} {
-    if (device_ordinal_ < 0) {
-      LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
-    }
-    if (communicator_ == nullptr) {
-      LOG(FATAL) << "Communicator cannot be null.";
-    }
-
-    int32_t const rank = communicator_->GetRank();
-    int32_t const world = communicator_->GetWorldSize();
-
-    if (world == 1) {
-      return;
-    }
-
-    std::vector<uint64_t> uuids(world * kUuidLength, 0);
-    auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
-    auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
-    GetCudaUUID(s_this_uuid);
-
-    // TODO(rongou): replace this with allgather.
-    communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
-
-    std::vector<xgboost::common::Span<uint64_t>> converted(world);
-    size_t j = 0;
-    for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
-      converted[j] = xgboost::common::Span<uint64_t>{uuids.data() + i, kUuidLength};
-      j++;
-    }
-
-    auto iter = std::unique(converted.begin(), converted.end());
-    auto n_uniques = std::distance(converted.begin(), iter);
-
-    CHECK_EQ(n_uniques, world)
-        << "Multiple processes within communication group running on same CUDA "
-        << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
-
-    nccl_unique_id_ = GetUniqueId();
-    dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
-    dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
-  }
-
-  ~NcclDeviceCommunicator() override {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-    if (cuda_stream_) {
-      dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
-    }
-    if (nccl_comm_) {
-      dh::safe_nccl(ncclCommDestroy(nccl_comm_));
-    }
-    if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
-      LOG(CONSOLE) << "======== NCCL Statistics========";
-      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
-      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
-    }
-  }
-
+  NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
+  ~NcclDeviceCommunicator() override;
   void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
-                 Operation op) override {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
-                                GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
-                                cuda_stream_));
-    allreduce_bytes_ += count * GetTypeSize(data_type);
-    allreduce_calls_ += 1;
-  }
-
+                 Operation op) override;
   void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
-                  dh::caching_device_vector<char> *receive_buffer) override {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    int const world_size = communicator_->GetWorldSize();
-    int const rank = communicator_->GetRank();
-
-    segments->clear();
-    segments->resize(world_size, 0);
-    segments->at(rank) = length_bytes;
-    communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
-                             Operation::kMax);
-    auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
-    receive_buffer->resize(total_bytes);
-
-    size_t offset = 0;
-    dh::safe_nccl(ncclGroupStart());
-    for (int32_t i = 0; i < world_size; ++i) {
-      size_t as_bytes = segments->at(i);
-      dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
-                                  ncclChar, i, nccl_comm_, cuda_stream_));
-      offset += as_bytes;
-    }
-    dh::safe_nccl(ncclGroupEnd());
-  }
-
-  void Synchronize() override {
-    if (communicator_->GetWorldSize() == 1) {
-      return;
-    }
-    dh::safe_cuda(cudaSetDevice(device_ordinal_));
-    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
-  }
+                  dh::caching_device_vector<char> *receive_buffer) override;
+  void Synchronize() override;
 
  private:
   static constexpr std::size_t kUuidLength =
@@ -160,60 +57,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
     return id;
   }
 
-  static ncclDataType_t GetNcclDataType(DataType const &data_type) {
-    ncclDataType_t result;
-    switch (data_type) {
-      case DataType::kInt8:
-        result = ncclInt8;
-        break;
-      case DataType::kUInt8:
-        result = ncclUint8;
-        break;
-      case DataType::kInt32:
-        result = ncclInt32;
-        break;
-      case DataType::kUInt32:
-        result = ncclUint32;
-        break;
-      case DataType::kInt64:
-        result = ncclInt64;
-        break;
-      case DataType::kUInt64:
-        result = ncclUint64;
-        break;
-      case DataType::kFloat:
-        result = ncclFloat;
-        break;
-      case DataType::kDouble:
-        result = ncclDouble;
-        break;
-      default:
-        LOG(FATAL) << "Unknown data type.";
-    }
-    return result;
-  }
-
-  static ncclRedOp_t GetNcclRedOp(Operation const &op) {
-    ncclRedOp_t result;
-    switch (op) {
-      case Operation::kMax:
-        result = ncclMax;
-        break;
-      case Operation::kMin:
-        result = ncclMin;
-        break;
-      case Operation::kSum:
-        result = ncclSum;
-        break;
-      case Operation::kBitwiseAND:
-      case Operation::kBitwiseOR:
-      case Operation::kBitwiseXOR:
-        LOG(FATAL) << "Not implemented yet.";
-      default:
-        LOG(FATAL) << "Unknown reduce operation.";
-    }
-    return result;
-  }
+  void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
+                        Operation op);
 
   int const device_ordinal_;
   Communicator *communicator_;
diff --git a/tests/cpp/collective/test_nccl_device_communicator.cu b/tests/cpp/collective/test_nccl_device_communicator.cu
index 6d3203522..6ac861a55 100644
--- a/tests/cpp/collective/test_nccl_device_communicator.cu
+++ b/tests/cpp/collective/test_nccl_device_communicator.cu
@@ -5,10 +5,12 @@
 
 #include <gtest/gtest.h>
 
+#include <bitset>
 #include <string>  // for string
 
-#include "../../../src/collective/nccl_device_communicator.cuh"
 #include "../../../src/collective/communicator-inl.cuh"
+#include "../../../src/collective/nccl_device_communicator.cuh"
+#include "../helpers.h"
 
 namespace xgboost {
 namespace collective {
@@ -31,6 +33,69 @@ TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
     ASSERT_TRUE(str.find("environment variables") != std::string::npos);
   }
 }
+
+namespace {
+void VerifyAllReduceBitwiseAND() {
+  auto const rank = collective::GetRank();
+  std::bitset<64> original{};
+  original[rank] = true;
+  HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
+  collective::AllReduce<collective::Operation::kBitwiseAND>(rank, buffer.DevicePointer(), 1);
+  collective::Synchronize(rank);
+  EXPECT_EQ(buffer.HostVector()[0], 0ULL);
+}
+}  // anonymous namespace
+
+TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus;
+  }
+  RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseAND);
+}
+
+namespace {
+void VerifyAllReduceBitwiseOR() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::bitset<64> original{};
+  original[rank] = true;
+  HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
+  collective::AllReduce<collective::Operation::kBitwiseOR>(rank, buffer.DevicePointer(), 1);
+  collective::Synchronize(rank);
+  EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
+}
+}  // anonymous namespace
+
+TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR test with # GPUs = " << n_gpus;
+  }
+  RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseOR);
+}
+
+namespace {
+void VerifyAllReduceBitwiseXOR() {
+  auto const world_size = collective::GetWorldSize();
+  auto const rank = collective::GetRank();
+  std::bitset<64> original{~0ULL};
+  original[rank] = false;
+  HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
+  collective::AllReduce<collective::Operation::kBitwiseXOR>(rank, buffer.DevicePointer(), 1);
+  collective::Synchronize(rank);
+  EXPECT_EQ(buffer.HostVector()[0], (1ULL << world_size) - 1);
+}
+}  // anonymous namespace
+
+TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) {
+  auto const n_gpus = common::AllVisibleGPUs();
+  if (n_gpus <= 1) {
+    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus;
+  }
+  RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseXOR);
+}
+
 }  // namespace collective
 }  // namespace xgboost
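
Note on the bitwise all-reduce added above: NCCL has no native bitwise reduction operators, so BitwiseAllReduce first calls ncclAllGather to stage every rank's raw bytes and then reduces the world_size copies element-wise on the device via RunBitwiseAllreduce. Below is a minimal host-only sketch of that gather-then-reduce pattern for reviewers who want to trace it without NCCL or a GPU; the helpers SimulatedAllGather and LocalBitwiseReduce are illustrative names only and are not part of the XGBoost API.

// Host-only sketch of the strategy used by NcclDeviceCommunicator::BitwiseAllReduce.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Stand-in for ncclAllGather: concatenate every rank's bytes into one buffer.
std::vector<std::uint8_t> SimulatedAllGather(
    std::vector<std::vector<std::uint8_t>> const &per_rank_buffers) {
  std::vector<std::uint8_t> gathered;
  for (auto const &buf : per_rank_buffers) {
    gathered.insert(gathered.end(), buf.begin(), buf.end());
  }
  return gathered;
}

// Stand-in for the RunBitwiseAllreduce kernel: byte idx of the result is the functor
// folded across the same offset in every rank's slice of the gathered buffer.
template <typename Func>
std::vector<std::uint8_t> LocalBitwiseReduce(std::vector<std::uint8_t> const &gathered,
                                             int world_size, std::size_t size, Func func) {
  std::vector<std::uint8_t> out(gathered.begin(), gathered.begin() + size);
  for (int rank = 1; rank < world_size; ++rank) {
    for (std::size_t idx = 0; idx < size; ++idx) {
      out[idx] = func(out[idx], gathered[rank * size + idx]);
    }
  }
  return out;
}

int main() {
  // Two "ranks", one byte each: 0b01 and 0b11. Bitwise AND across ranks yields 0b01.
  std::vector<std::vector<std::uint8_t>> buffers{{0x1}, {0x3}};
  auto gathered = SimulatedAllGather(buffers);
  auto result = LocalBitwiseReduce(gathered, /*world_size=*/2, /*size=*/1,
                                   std::bit_and<std::uint8_t>());
  std::cout << "AND result: " << static_cast<int>(result[0]) << "\n";  // prints 1
  return 0;
}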