Fix device communicator dependency (#9346)
This commit is contained in:
@@ -12,7 +12,7 @@ namespace collective {
|
||||
|
||||
class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
public:
|
||||
NcclDeviceCommunicator(int device_ordinal, Communicator *communicator);
|
||||
explicit NcclDeviceCommunicator(int device_ordinal);
|
||||
~NcclDeviceCommunicator() override;
|
||||
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
|
||||
Operation op) override;
|
||||
@@ -49,11 +49,10 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
ncclUniqueId GetUniqueId() {
|
||||
static const int kRootRank = 0;
|
||||
ncclUniqueId id;
|
||||
if (communicator_->GetRank() == kRootRank) {
|
||||
if (rank_ == kRootRank) {
|
||||
dh::safe_nccl(ncclGetUniqueId(&id));
|
||||
}
|
||||
communicator_->Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId),
|
||||
static_cast<int>(kRootRank));
|
||||
Broadcast(static_cast<void *>(&id), sizeof(ncclUniqueId), static_cast<int>(kRootRank));
|
||||
return id;
|
||||
}
|
||||
|
||||
@@ -61,7 +60,8 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
Operation op);
|
||||
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
int const world_size_;
|
||||
int const rank_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
cudaStream_t cuda_stream_{};
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
|
||||
Reference in New Issue
Block a user