Fix device communicator dependency (#9346)

This commit is contained in:
Rong Ou
2023-06-28 19:34:30 -07:00
committed by GitHub
parent f4798718c7
commit f90771eec6
10 changed files with 107 additions and 123 deletions

View File

@@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
old_world_size = communicator_->GetWorldSize();
#ifdef XGBOOST_USE_NCCL
if (type_ != CommunicatorType::kFederated) {
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
} else {
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
}
#else
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
#endif
}
return device_communicator_.get();

View File

@@ -11,21 +11,18 @@ namespace collective {
class DeviceCommunicatorAdapter : public DeviceCommunicator {
public:
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
explicit DeviceCommunicatorAdapter(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
}
~DeviceCommunicatorAdapter() override = default;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
@@ -33,37 +30,34 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
Allreduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) override {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
host_buffer_.reserve(total_bytes);
size_t offset = 0;
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
if (i == rank) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
if (i == rank_) {
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
cudaMemcpyDefault));
}
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
Broadcast(host_buffer_.data() + offset, as_bytes, i);
offset += as_bytes;
}
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
@@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
private:
int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
/// Host buffer used to call communicator functions.
std::vector<char> host_buffer_{};
};

View File

@@ -7,31 +7,24 @@
namespace xgboost {
namespace collective {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
: device_ordinal_{device_ordinal}, communicator_{communicator} {
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
if (device_ordinal_ < 0) {
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
}
if (communicator_ == nullptr) {
LOG(FATAL) << "Communicator cannot be null.";
}
int32_t const rank = communicator_->GetRank();
int32_t const world = communicator_->GetWorldSize();
if (world == 1) {
if (world_size_ == 1) {
return;
}
std::vector<uint64_t> uuids(world * kUuidLength, 0);
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
GetCudaUUID(s_this_uuid);
// TODO(rongou): replace this with allgather.
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
size_t j = 0;
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
@@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
auto iter = std::unique(converted.begin(), converted.end());
auto n_uniques = std::distance(converted.begin(), iter);
CHECK_EQ(n_uniques, world)
CHECK_EQ(n_uniques, world_size_)
<< "Multiple processes within communication group running on same CUDA "
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
nccl_unique_id_ = GetUniqueId();
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
}
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
if (cuda_stream_) {
@@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
auto const world_size = communicator_->GetWorldSize();
auto const size = count * GetTypeSize(data_type);
dh::caching_device_vector<char> buffer(size * world_size);
dh::caching_device_vector<char> buffer(size * world_size_);
auto *device_buffer = buffer.data().get();
// First gather data from all the workers.
@@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
auto *out_buffer = static_cast<char *>(send_receive_buffer);
switch (op) {
case Operation::kBitwiseAND:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
cuda_stream_);
break;
case Operation::kBitwiseXOR:
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
cuda_stream_);
break;
default:
@@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
DataType data_type, Operation op) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
@@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
std::vector<std::size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));
int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank();
segments->clear();
segments->resize(world_size, 0);
segments->at(rank) = length_bytes;
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
segments->resize(world_size_, 0);
segments->at(rank_) = length_bytes;
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
receive_buffer->resize(total_bytes);
size_t offset = 0;
dh::safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world_size; ++i) {
for (int32_t i = 0; i < world_size_; ++i) {
size_t as_bytes = segments->at(i);
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
ncclChar, i, nccl_comm_, cuda_stream_));
@@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
}
void NcclDeviceCommunicator::Synchronize() {
if (communicator_->GetWorldSize() == 1) {
if (world_size_ == 1) {
return;
}
dh::safe_cuda(cudaSetDevice(device_ordinal_));

View File

@@ -12,7 +12,7 @@ namespace collective {
class NcclDeviceCommunicator : public DeviceCommunicator {
public:
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
explicit NcclDeviceCommunicator(int device_ordinal);
~NcclDeviceCommunicator() override;
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override;
@@ -49,11 +49,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
ncclUniqueId GetUniqueId() {
static const int kRootRank = 0;
ncclUniqueId id;
if (communicator_->GetRank() == kRootRank) {
if (rank_ == kRootRank) {
dh::safe_nccl(ncclGetUniqueId(&id));
}
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
static_cast<int>(kRootRank));
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
return id;
}
@@ -61,7 +60,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
Operation op);
int const device_ordinal_;
Communicator *communicator_;
int const world_size_;
int const rank_;
ncclComm_t nccl_comm_{};
cudaStream_t cuda_stream_{};
ncclUniqueId nccl_unique_id_{};