Fix device communicator dependency (#9346)
This commit is contained in:
@@ -30,12 +30,12 @@ DeviceCommunicator* Communicator::GetDevice(int device_ordinal) {
|
||||
old_world_size = communicator_->GetWorldSize();
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
if (type_ != CommunicatorType::kFederated) {
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal, Get()));
|
||||
device_communicator_.reset(new NcclDeviceCommunicator(device_ordinal));
|
||||
} else {
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
}
|
||||
#else
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal, Get()));
|
||||
device_communicator_.reset(new DeviceCommunicatorAdapter(device_ordinal));
|
||||
#endif
|
||||
}
|
||||
return device_communicator_.get();
|
||||
|
||||
@@ -11,21 +11,18 @@ namespace collective {
|
||||
|
||||
class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
public:
|
||||
DeviceCommunicatorAdapter(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
explicit DeviceCommunicatorAdapter(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
}
|
||||
|
||||
~DeviceCommunicatorAdapter() override = default;
|
||||
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -33,37 +30,34 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
auto size = count * GetTypeSize(data_type);
|
||||
host_buffer_.reserve(size);
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
|
||||
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
|
||||
Allreduce(host_buffer_.data(), count, data_type, op);
|
||||
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) override {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64,
|
||||
Operation::kMax);
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
host_buffer_.reserve(total_bytes);
|
||||
size_t offset = 0;
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
if (i == rank) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank),
|
||||
if (i == rank_) {
|
||||
dh::safe_cuda(cudaMemcpy(host_buffer_.data() + offset, send_buffer, segments->at(rank_),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
communicator_->Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
Broadcast(host_buffer_.data() + offset, as_bytes, i);
|
||||
offset += as_bytes;
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpy(receive_buffer->data().get(), host_buffer_.data(), total_bytes,
|
||||
@@ -76,7 +70,8 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
|
||||
|
||||
private:
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
/// Host buffer used to call communicator functions.
|
||||
std::vector<char> host_buffer_{};
|
||||
};
|
||||
|
||||
@@ -7,31 +7,24 @@
|
||||
namespace xgboost {
|
||||
namespace collective {
|
||||
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator *communicator)
|
||||
: device_ordinal_{device_ordinal}, communicator_{communicator} {
|
||||
NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal)
|
||||
: device_ordinal_{device_ordinal}, world_size_{GetWorldSize()}, rank_{GetRank()} {
|
||||
if (device_ordinal_ < 0) {
|
||||
LOG(FATAL) << "Invalid device ordinal: " << device_ordinal_;
|
||||
}
|
||||
if (communicator_ == nullptr) {
|
||||
LOG(FATAL) << "Communicator cannot be null.";
|
||||
}
|
||||
|
||||
int32_t const rank = communicator_->GetRank();
|
||||
int32_t const world = communicator_->GetWorldSize();
|
||||
|
||||
if (world == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
std::vector<uint64_t> uuids(world_size_ * kUuidLength, 0);
|
||||
auto s_uuid = xgboost::common::Span<uint64_t>{uuids.data(), uuids.size()};
|
||||
auto s_this_uuid = s_uuid.subspan(rank * kUuidLength, kUuidLength);
|
||||
auto s_this_uuid = s_uuid.subspan(rank_ * kUuidLength, kUuidLength);
|
||||
GetCudaUUID(s_this_uuid);
|
||||
|
||||
// TODO(rongou): replace this with allgather.
|
||||
communicator_->AllReduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
Allreduce(uuids.data(), uuids.size(), DataType::kUInt64, Operation::kSum);
|
||||
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world);
|
||||
std::vector<xgboost::common::Span<uint64_t, kUuidLength>> converted(world_size_);
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < uuids.size(); i += kUuidLength) {
|
||||
converted[j] = xgboost::common::Span<uint64_t, kUuidLength>{uuids.data() + i, kUuidLength};
|
||||
@@ -41,18 +34,18 @@ NcclDeviceCommunicator::NcclDeviceCommunicator(int device_ordinal, Communicator
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
CHECK_EQ(n_uniques, world_size_)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world_size_, nccl_unique_id_, rank_));
|
||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||
}
|
||||
|
||||
NcclDeviceCommunicator::~NcclDeviceCommunicator() {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
@@ -139,9 +132,8 @@ void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func,
|
||||
|
||||
void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
auto const world_size = communicator_->GetWorldSize();
|
||||
auto const size = count * GetTypeSize(data_type);
|
||||
dh::caching_device_vector<char> buffer(size * world_size);
|
||||
dh::caching_device_vector<char> buffer(size * world_size_);
|
||||
auto *device_buffer = buffer.data().get();
|
||||
|
||||
// First gather data from all the workers.
|
||||
@@ -152,15 +144,15 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
auto *out_buffer = static_cast<char *>(send_receive_buffer);
|
||||
switch (op) {
|
||||
case Operation::kBitwiseAND:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size, size,
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_and<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
break;
|
||||
case Operation::kBitwiseOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size, size,
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_or<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
break;
|
||||
case Operation::kBitwiseXOR:
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size, size,
|
||||
RunBitwiseAllreduce(out_buffer, device_buffer, thrust::bit_xor<char>(), world_size_, size,
|
||||
cuda_stream_);
|
||||
break;
|
||||
default:
|
||||
@@ -170,7 +162,7 @@ void NcclDeviceCommunicator::BitwiseAllReduce(void *send_receive_buffer, std::si
|
||||
|
||||
void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t count,
|
||||
DataType data_type, Operation op) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -189,24 +181,22 @@ void NcclDeviceCommunicator::AllReduce(void *send_receive_buffer, std::size_t co
|
||||
void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_bytes,
|
||||
std::vector<std::size_t> *segments,
|
||||
dh::caching_device_vector<char> *receive_buffer) {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
segments->clear();
|
||||
segments->resize(world_size, 0);
|
||||
segments->at(rank) = length_bytes;
|
||||
communicator_->AllReduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
segments->resize(world_size_, 0);
|
||||
segments->at(rank_) = length_bytes;
|
||||
Allreduce(segments->data(), segments->size(), DataType::kUInt64, Operation::kMax);
|
||||
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0UL);
|
||||
receive_buffer->resize(total_bytes);
|
||||
|
||||
size_t offset = 0;
|
||||
dh::safe_nccl(ncclGroupStart());
|
||||
for (int32_t i = 0; i < world_size; ++i) {
|
||||
for (int32_t i = 0; i < world_size_; ++i) {
|
||||
size_t as_bytes = segments->at(i);
|
||||
dh::safe_nccl(ncclBroadcast(send_buffer, receive_buffer->data().get() + offset, as_bytes,
|
||||
ncclChar, i, nccl_comm_, cuda_stream_));
|
||||
@@ -216,7 +206,7 @@ void NcclDeviceCommunicator::AllGatherV(void const *send_buffer, size_t length_b
|
||||
}
|
||||
|
||||
void NcclDeviceCommunicator::Synchronize() {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
if (world_size_ == 1) {
|
||||
return;
|
||||
}
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
|
||||
explicit NcclDeviceCommunicator(int device_ordinal);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
@@ -49,11 +49,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (communicator_->GetRank() == kRootRank) {
|
||||
if (rank_ == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
|
||||
static_cast<int>(kRootRank));
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -61,7 +60,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
cudaStream_t cuda_stream_{};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
|
||||
Reference in New Issue
Block a user