Support bitwise allreduce in NCCL communicator (#9300)
This commit is contained in:
parent
2718ff530c
commit
d8beb517ed
228
src/collective/nccl_device_communicator.cu
Normal file
228
src/collective/nccl_device_communicator.cu
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
/*!
|
||||||
|
* Copyright 2023 XGBoost contributors
|
||||||
|
*/
|
||||||
|
#if defined(XGBOOST_USE_NCCL)
|
||||||
|
#include "nccl_device_communicator.cuh"
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace collective {
|
||||||
|
|
||||||
|
/**
 * @brief Set up the NCCL communicator and CUDA stream for one device.
 *
 * Fails via LOG(FATAL) when the device ordinal is negative or the host communicator is null.
 * When the world size is 1 the constructor returns early and neither an NCCL communicator nor
 * a CUDA stream is created.
 *
 * Otherwise every rank publishes its device UUID, and construction is aborted (CHECK_EQ) if any
 * two ranks report the same UUID, i.e. if two processes in the group share a CUDA device.
 *
 * @param device_ordinal CUDA device ordinal used by this process; must be non-negative.
 * @param communicator   Host-side communicator used for rank/world-size queries and for the
 *                       host AllReduce that exchanges the UUIDs; must not be null.
 */
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
    : device_ordinal_{device_ordinal}, communicator_{communicator} {
  if (device_ordinal_ < 0) {
    LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
  }
  if (communicator_ == nullptr) {
    LOG(FATAL) << "Communicator cannot be null.";
  }

  int32_t const rank = communicator_->GetRank();
  int32_t const world = communicator_->GetWorldSize();

  if (world == 1) {
    return;
  }

  // One slot of kUuidLength words per rank. Only this rank's slot is filled in; all other
  // slots stay zero, so the sum-allreduce below acts as an all-gather of the UUIDs.
  std::vector<uint64_t> uuids(world * kUuidLength, 0);
  auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
  auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
  GetCudaUUID(s_this_uuid);

  // TODO(rongou): replace this with allgather.
  communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);

  // Count distinct UUIDs by comparing every slot against all earlier slots. The slots are
  // ordered by rank, not by value, so an adjacent-only comparison (std::unique without a prior
  // sort) would miss two ranks that share a device but are not neighbours, e.g. ranks 0 and 2.
  // The world size is small, so the quadratic scan is negligible.
  int32_t n_uniques = 0;
  for (int32_t i = 0; i < world; ++i) {
    uint64_t const *candidate = uuids.data() + i * kUuidLength;
    bool seen_before = false;
    for (int32_t j = 0; j < i && !seen_before; ++j) {
      seen_before = std::equal(candidate, candidate + kUuidLength, uuids.data() + j * kUuidLength);
    }
    if (!seen_before) {
      ++n_uniques;
    }
  }

  CHECK_EQ(n_uniques, world)
      << "Multiple processes within communication group running on same CUDA "
      << "device is not supported. " << PrintUUID(s_this_uuid) << "\n";

  nccl_unique_id_ = GetUniqueId();
  // Make the target device current before creating the NCCL communicator and the stream.
  dh::safe_cuda(cudaSetDevice(device_ordinal_));
  dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
  dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
|
||||||
|
|
||||||
|
/**
 * @brief Release the CUDA stream and NCCL communicator, then optionally print usage statistics.
 *
 * Nothing is done when the world size is 1. Statistics (number of AllReduce calls and the
 * amount of data in MiB) are printed only when debug-level console logging is enabled.
 */
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
  if (communicator_->GetWorldSize() != 1) {
    if (cuda_stream_) {
      dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
    }
    if (nccl_comm_) {
      dh::safe_nccl(ncclCommDestroy(nccl_comm_));
    }

    bool const print_stats =
        xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug);
    if (print_stats) {
      LOG(CONSOLE) << "======== NCCL Statistics========";
      LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
      // Integer division: the byte counter is reported in whole MiB.
      LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
    }
  }
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
/**
 * @brief Translate a collective DataType into the matching NCCL data type.
 *
 * @param data_type Element type of the buffer being communicated.
 * @return The corresponding ncclDataType_t. Any value not listed below is reported through
 *         LOG(FATAL).
 */
ncclDataType_t GetNcclDataType(DataType const &data_type) {
  switch (data_type) {
    case DataType::kInt8:
      return ncclInt8;
    case DataType::kUInt8:
      return ncclUint8;
    case DataType::kInt32:
      return ncclInt32;
    case DataType::kUInt32:
      return ncclUint32;
    case DataType::kInt64:
      return ncclInt64;
    case DataType::kUInt64:
      return ncclUint64;
    case DataType::kFloat:
      return ncclFloat;
    case DataType::kDouble:
      return ncclDouble;
    default:
      LOG(FATAL) << "Unknown data type.";
      break;
  }
  // NOTE(review): presumably never reached because LOG(FATAL) does not return; the statement
  // only exists so that every control path yields a value.
  return ncclInt8;
}
|
||||||
|
|
||||||
|
/**
 * @brief Tell whether a reduce operation is one of the bitwise ones (AND, OR, XOR).
 *
 * @param op Reduce operation to classify.
 * @return true for kBitwiseAND, kBitwiseOR and kBitwiseXOR; false for everything else.
 */
bool IsBitwiseOp(Operation const &op) {
  switch (op) {
    case Operation::kBitwiseAND:
    case Operation::kBitwiseOR:
    case Operation::kBitwiseXOR:
      return true;
    default:
      return false;
  }
}
|
||||||
|
|
||||||
|
/**
 * @brief Translate a collective Operation into the matching NCCL reduction operator.
 *
 * Only max, min and sum are mapped here; every other operation (including the bitwise ones)
 * is reported through LOG(FATAL).
 *
 * @param op Reduce operation requested by the caller.
 * @return The corresponding ncclRedOp_t.
 */
ncclRedOp_t GetNcclRedOp(Operation const &op) {
  switch (op) {
    case Operation::kMax:
      return ncclMax;
    case Operation::kMin:
      return ncclMin;
    case Operation::kSum:
      return ncclSum;
    default:
      LOG(FATAL) << "Unsupported reduce operation.";
      break;
  }
  // NOTE(review): presumably never reached because LOG(FATAL) does not return; the statement
  // only exists so that every control path yields a value.
  return ncclSum;
}
|
||||||
|
|
||||||
|
/**
 * @brief Combine the gathered per-rank segments byte by byte with a binary functor.
 *
 * `device_buffer` is read as `world_size` consecutive segments of `size` bytes each (segment r
 * starts at offset `r * size`). For every byte index the bytes at that index in all segments
 * are folded with `func`, starting from segment 0, and the result is stored in `out_buffer`.
 *
 * @param out_buffer    Destination of `size` bytes.
 * @param device_buffer Source holding `world_size * size` bytes.
 * @param func          Binary functor applied to two `char` values.
 * @param world_size    Number of segments in `device_buffer`.
 * @param size          Number of bytes per segment.
 * @param stream        CUDA stream handed to dh::LaunchN.
 *
 * NOTE(review): presumably dh::LaunchN invokes the lambda once for every idx in [0, size) on
 * `stream` -- confirm against its definition.
 */
template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
                         std::size_t size, cudaStream_t stream) {
  dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
    // Fold into a local value and store once instead of updating the output on every step.
    char folded = device_buffer[idx];
    for (auto segment = 1; segment < world_size; ++segment) {
      folded = func(folded, device_buffer[segment * size + idx]);
    }
    out_buffer[idx] = folded;
  });
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
/**
 * @brief In-place bitwise all-reduce (AND / OR / XOR), implemented as all-gather + local fold.
 *
 * The contributions of all ranks are gathered into a temporary device buffer with
 * ncclAllGather, then reduced byte-wise on the device, and the result is written back into
 * `send_receive_buffer`.
 *
 * @param send_receive_buffer Device buffer holding this rank's input; overwritten with the
 *                            reduced result.
 * @param count               Number of elements of `data_type` in the buffer.
 * @param data_type           Element type; only its size and NCCL mapping are used, the
 *                            reduction itself works on raw bytes.
 * @param op                  Must be kBitwiseAND, kBitwiseOR or kBitwiseXOR; anything else is
 *                            reported through LOG(FATAL).
 */
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
                                              DataType data_type, Operation op) {
  auto const world_size = communicator_->GetWorldSize();
  auto const size = count * GetTypeSize(data_type);
  // Scratch space for one `size`-byte segment per rank.
  dh::caching_device_vector<char> buffer(size * world_size);
  auto *device_buffer = buffer.data().get();

  // First gather data from all the workers.
  dh::safe_nccl(ncclAllGather(send_receive_buffer, device_buffer, count, GetNcclDataType(data_type),
                              nccl_comm_, cuda_stream_));

  // Then reduce locally.
  auto *out_buffer = static_cast<char *>(send_receive_buffer);
  switch (op) {
    case Operation::kBitwiseAND:
      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
                          cuda_stream_);
      break;
    case Operation::kBitwiseOR:
      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
                          cuda_stream_);
      break;
    case Operation::kBitwiseXOR:
      RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
                          cuda_stream_);
      break;
    default:
      LOG(FATAL) << "Not a bitwise reduce operation.";
  }

  // The gather and the reduction above are only enqueued on cuda_stream_, while `buffer` is
  // released as soon as this function returns. Wait for the stream so that the scratch memory
  // cannot be handed out again while the enqueued work may still be using it.
  // NOTE(review): whether the caching allocator already guarantees this ordering is not
  // visible here; the explicit wait is the conservative choice.
  dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
}
|
||||||
|
|
||||||
|
/**
 * @brief In-place all-reduce of a device buffer across all ranks.
 *
 * Does nothing when the world size is 1. Bitwise operations are routed to BitwiseAllReduce;
 * all other operations are forwarded to ncclAllReduce on this communicator's stream. The call
 * and byte counters are updated after the operation has been issued.
 *
 * @param send_receive_buffer Device buffer used as both input and output.
 * @param count               Number of elements of `data_type` in the buffer.
 * @param data_type           Element type of the buffer.
 * @param op                  Reduction to apply.
 */
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
                                       DataType data_type, Operation op) {
  bool const single_process = communicator_->GetWorldSize() == 1;
  if (single_process) {
    return;
  }

  dh::safe_cuda(cudaSetDevice(device_ordinal_));

  if (!IsBitwiseOp(op)) {
    dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
                                GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
                                cuda_stream_));
  } else {
    BitwiseAllReduce(send_receive_buffer, count, data_type, op);
  }

  // Bookkeeping for the statistics printed by the destructor.
  auto const n_bytes = count * GetTypeSize(data_type);
  allreduce_bytes_ += n_bytes;
  allreduce_calls_ += 1;
}
|
||||||
|
|
||||||
|
/**
 * @brief Gather variable-length byte buffers from all ranks into one device buffer.
 *
 * Does nothing (outputs are left untouched) when the world size is 1. Otherwise the per-rank
 * lengths are exchanged through the host communicator, `receive_buffer` is resized to the sum
 * of all lengths, and each rank's data is broadcast into its slice; slices are laid out in
 * rank order. The broadcasts are issued on this communicator's stream inside one NCCL group.
 *
 * @param send_buffer    Device buffer with this rank's contribution.
 * @param length_bytes   Size of `send_buffer` in bytes.
 * @param segments       Output: resized to the world size; entry i is the byte length
 *                       contributed by rank i.
 * @param receive_buffer Output: resized to the total byte count and filled with all segments.
 */
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
                                        std::vector<std::size_t> *segments,
                                        dh::caching_device_vector<char> *receive_buffer) {
  if (communicator_->GetWorldSize() == 1) {
    return;
  }

  dh::safe_cuda(cudaSetDevice(device_ordinal_));
  int const world_size = communicator_->GetWorldSize();
  int const rank = communicator_->GetRank();

  // Every rank fills in only its own slot; the other slots are zero, so a max-allreduce leaves
  // each slot holding the length reported by its owner.
  segments->clear();
  segments->resize(world_size, 0);
  segments->at(rank) = length_bytes;
  communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);

  // Accumulate in std::size_t. An `unsigned long` seed (0UL) is only 32 bits wide on LLP64
  // platforms such as 64-bit Windows and would silently truncate totals above 4 GiB.
  auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), std::size_t{0});
  receive_buffer->resize(total_bytes);

  size_t offset = 0;
  dh::safe_nccl(ncclGroupStart());
  for (int32_t i = 0; i < world_size; ++i) {
    // Rank i is the root of the i-th broadcast; its data lands at the running offset.
    size_t as_bytes = segments->at(i);
    dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
                                ncclChar, i, nccl_comm_, cuda_stream_));
    offset += as_bytes;
  }
  dh::safe_nccl(ncclGroupEnd());
}
|
||||||
|
|
||||||
|
/**
 * @brief Block the calling thread until all work queued on this communicator's stream is done.
 *
 * Does nothing when the world size is 1.
 */
void NcclDeviceCommunicator::Synchronize() {
  if (communicator_->GetWorldSize() != 1) {
    dh::safe_cuda(cudaSetDevice(device_ordinal_));
    dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
  }
}
|
||||||
|
|
||||||
|
} // namespace collective
|
||||||
|
} // namespace xgboost
|
||||||
|
#endif
|
||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2022 XGBoost contributors
|
* Copyright 2022-2023 XGBoost contributors
|
||||||
*/
|
*/
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
@ -12,116 +12,13 @@ namespace collective {
|
|||||||
|
|
||||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||||
public:
|
public:
|
||||||
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
|
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
|
||||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
~NcclDeviceCommunicator() override;
|
||||||
if (device_ordinal_ < 0) {
|
|
||||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
|
||||||
}
|
|
||||||
if (communicator_ == nullptr) {
|
|
||||||
LOG(FATAL) << "Communicator cannot be null.";
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t const rank = communicator_->GetRank();
|
|
||||||
int32_t const world = communicator_->GetWorldSize();
|
|
||||||
|
|
||||||
if (world == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
|
||||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
|
||||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
|
||||||
GetCudaUUID(s_this_uuid);
|
|
||||||
|
|
||||||
// TODO(rongou): replace this with allgather.
|
|
||||||
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
|
||||||
|
|
||||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
|
|
||||||
size_t j = 0;
|
|
||||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
|
||||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
|
||||||
j++;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto iter = std::unique(converted.begin(), converted.end());
|
|
||||||
auto n_uniques = std::distance(converted.begin(), iter);
|
|
||||||
|
|
||||||
CHECK_EQ(n_uniques, world)
|
|
||||||
<< "Multiple processes within communication group running on same CUDA "
|
|
||||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
|
||||||
|
|
||||||
nccl_unique_id_ = GetUniqueId();
|
|
||||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
|
||||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
|
||||||
}
|
|
||||||
|
|
||||||
~NcclDeviceCommunicator() override {
|
|
||||||
if (communicator_->GetWorldSize() == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (cuda_stream_) {
|
|
||||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
|
||||||
}
|
|
||||||
if (nccl_comm_) {
|
|
||||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
|
||||||
}
|
|
||||||
if (xgboost::ConsoleLogger::ShouldLog(xgboost::ConsoleLogger::LV::kDebug)) {
|
|
||||||
LOG(CONSOLE) << "======== NCCL Statistics========";
|
|
||||||
LOG(CONSOLE) << "AllReduce calls: " << allreduce_calls_;
|
|
||||||
LOG(CONSOLE) << "AllReduce total MiB communicated: " << allreduce_bytes_ / 1048576;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
Operation op) override {
|
Operation op) override;
|
||||||
if (communicator_->GetWorldSize() == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
|
||||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
|
|
||||||
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
|
|
||||||
cuda_stream_));
|
|
||||||
allreduce_bytes_ += count * GetTypeSize(data_type);
|
|
||||||
allreduce_calls_ += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||||
dh::caching_device_vector<char> *receive_buffer) override {
|
dh::caching_device_vector<char> *receive_buffer) override;
|
||||||
if (communicator_->GetWorldSize() == 1) {
|
void Synchronize() override;
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
|
||||||
int const world_size = communicator_->GetWorldSize();
|
|
||||||
int const rank = communicator_->GetRank();
|
|
||||||
|
|
||||||
segments->clear();
|
|
||||||
segments->resize(world_size, 0);
|
|
||||||
segments->at(rank) = length_bytes;
|
|
||||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
|
||||||
Operation::kMax);
|
|
||||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
|
||||||
receive_buffer->resize(total_bytes);
|
|
||||||
|
|
||||||
size_t offset = 0;
|
|
||||||
dh::safe_nccl(ncclGroupStart());
|
|
||||||
for (int32_t i = 0; i < world_size; ++i) {
|
|
||||||
size_t as_bytes = segments->at(i);
|
|
||||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
|
||||||
ncclChar, i, nccl_comm_, cuda_stream_));
|
|
||||||
offset += as_bytes;
|
|
||||||
}
|
|
||||||
dh::safe_nccl(ncclGroupEnd());
|
|
||||||
}
|
|
||||||
|
|
||||||
void Synchronize() override {
|
|
||||||
if (communicator_->GetWorldSize() == 1) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
|
||||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr std::size_t kUuidLength =
|
static constexpr std::size_t kUuidLength =
|
||||||
@ -160,60 +57,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
|
void BitwiseAllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||||
ncclDataType_t result;
|
Operation op);
|
||||||
switch (data_type) {
|
|
||||||
case DataType::kInt8:
|
|
||||||
result = ncclInt8;
|
|
||||||
break;
|
|
||||||
case DataType::kUInt8:
|
|
||||||
result = ncclUint8;
|
|
||||||
break;
|
|
||||||
case DataType::kInt32:
|
|
||||||
result = ncclInt32;
|
|
||||||
break;
|
|
||||||
case DataType::kUInt32:
|
|
||||||
result = ncclUint32;
|
|
||||||
break;
|
|
||||||
case DataType::kInt64:
|
|
||||||
result = ncclInt64;
|
|
||||||
break;
|
|
||||||
case DataType::kUInt64:
|
|
||||||
result = ncclUint64;
|
|
||||||
break;
|
|
||||||
case DataType::kFloat:
|
|
||||||
result = ncclFloat;
|
|
||||||
break;
|
|
||||||
case DataType::kDouble:
|
|
||||||
result = ncclDouble;
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
LOG(FATAL) << "Unknown data type.";
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
|
|
||||||
ncclRedOp_t result;
|
|
||||||
switch (op) {
|
|
||||||
case Operation::kMax:
|
|
||||||
result = ncclMax;
|
|
||||||
break;
|
|
||||||
case Operation::kMin:
|
|
||||||
result = ncclMin;
|
|
||||||
break;
|
|
||||||
case Operation::kSum:
|
|
||||||
result = ncclSum;
|
|
||||||
break;
|
|
||||||
case Operation::kBitwiseAND:
|
|
||||||
case Operation::kBitwiseOR:
|
|
||||||
case Operation::kBitwiseXOR:
|
|
||||||
LOG(FATAL) << "Not implemented yet.";
|
|
||||||
default:
|
|
||||||
LOG(FATAL) << "Unknown reduce operation.";
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
int const device_ordinal_;
|
int const device_ordinal_;
|
||||||
Communicator *communicator_;
|
Communicator *communicator_;
|
||||||
|
|||||||
@ -5,10 +5,12 @@
|
|||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <bitset>
|
||||||
#include <string> // for string
|
#include <string> // for string
|
||||||
|
|
||||||
#include "../../../src/collective/nccl_device_communicator.cuh"
|
|
||||||
#include "../../../src/collective/communicator-inl.cuh"
|
#include "../../../src/collective/communicator-inl.cuh"
|
||||||
|
#include "../../../src/collective/nccl_device_communicator.cuh"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace collective {
|
namespace collective {
|
||||||
@ -31,6 +33,69 @@ TEST(NcclDeviceCommunicatorSimpleTest, SystemError) {
|
|||||||
ASSERT_TRUE(str.find("environment variables") != std::string::npos);
|
ASSERT_TRUE(str.find("environment variables") != std::string::npos);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Each rank contributes a 64-bit word with only bit `rank` set. No bit is set on every rank
// once at least two ranks take part, so the AND across ranks is expected to be zero.
void VerifyAllReduceBitwiseAND() {
  auto const rank = collective::GetRank();

  std::bitset<64> one_hot{};
  one_hot[rank] = true;
  auto const local_word = one_hot.to_ullong();

  HostDeviceVector<uint64_t> buffer({local_word}, rank);
  collective::AllReduce<collective::Operation::kBitwiseAND>(rank, buffer.DevicePointer(), 1);
  collective::Synchronize(rank);

  auto const &reduced = buffer.HostVector();
  EXPECT_EQ(reduced[0], 0ULL);
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// Runs the bitwise-AND check with one worker per visible GPU; skipped unless more than one GPU
// is visible.
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseAND) {
  auto const n_gpus = common::AllVisibleGPUs();
  if (n_gpus > 1) {
    RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseAND);
  } else {
    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseAND test with # GPUs = " << n_gpus;
  }
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Each rank contributes a 64-bit word with only bit `rank` set, so the OR across ranks is
// expected to have exactly the lowest `world_size` bits set.
void VerifyAllReduceBitwiseOR() {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();

  std::bitset<64> one_hot{};
  one_hot[rank] = true;
  auto const local_word = one_hot.to_ullong();

  HostDeviceVector<uint64_t> buffer({local_word}, rank);
  collective::AllReduce<collective::Operation::kBitwiseOR>(rank, buffer.DevicePointer(), 1);
  collective::Synchronize(rank);

  auto const &reduced = buffer.HostVector();
  EXPECT_EQ(reduced[0], (1ULL << world_size) - 1);
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// Runs the bitwise-OR check with one worker per visible GPU; skipped unless more than one GPU
// is visible.
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseOR) {
  auto const n_gpus = common::AllVisibleGPUs();
  if (n_gpus > 1) {
    RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseOR);
  } else {
    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseOR test with # GPUs = " << n_gpus;
  }
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Each rank contributes a 64-bit word with every bit set except bit `rank`.
//
// Across `world_size` ranks, a bit below `world_size` is therefore set on (world_size - 1)
// ranks and a bit at or above `world_size` is set on all `world_size` ranks. XOR keeps a bit
// iff it is set an odd number of times, so:
//   - even world size: only the lowest `world_size` bits survive;
//   - odd world size:  only the bits at or above `world_size` survive (the complement).
// The expectation has to follow the parity, otherwise the test fails on an odd number of GPUs.
// Assumes world_size < 64 so that the shift below is well defined.
void VerifyAllReduceBitwiseXOR() {
  auto const world_size = collective::GetWorldSize();
  auto const rank = collective::GetRank();
  std::bitset<64> original{~0ULL};
  original[rank] = false;
  HostDeviceVector<uint64_t> buffer({original.to_ullong()}, rank);
  collective::AllReduce<collective::Operation::kBitwiseXOR>(rank, buffer.DevicePointer(), 1);
  collective::Synchronize(rank);

  auto const low_bits = (1ULL << world_size) - 1;
  auto const expected = (world_size % 2 == 0) ? low_bits : ~low_bits;
  EXPECT_EQ(buffer.HostVector()[0], expected);
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
// Runs the bitwise-XOR check with one worker per visible GPU; skipped unless more than one GPU
// is visible.
TEST(NcclDeviceCommunicator, MGPUAllReduceBitwiseXOR) {
  auto const n_gpus = common::AllVisibleGPUs();
  if (n_gpus > 1) {
    RunWithInMemoryCommunicator(n_gpus, VerifyAllReduceBitwiseXOR);
  } else {
    GTEST_SKIP() << "Skipping MGPUAllReduceBitwiseXOR test with # GPUs = " << n_gpus;
  }
}
|
||||||
|
|
||||||
} // namespace collective
|
} // namespace collective
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user