enable rocm, fix nccl_device_communicator.cuh
This commit is contained in:
parent
762fd9028d
commit
0fc1f640a9
@ -52,7 +52,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
|
||||
nccl_unique_id_ = GetUniqueId();
|
||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipStreamCreate(&cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
|
||||
~NcclDeviceCommunicator() override {
|
||||
@ -60,7 +65,11 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
if (cuda_stream_) {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipStreamDestroy(cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
if (nccl_comm_) {
|
||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||
@ -94,7 +103,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
int const world_size = communicator_->GetWorldSize();
|
||||
int const rank = communicator_->GetRank();
|
||||
|
||||
@ -121,17 +135,33 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
if (communicator_->GetWorldSize() == 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(hipStreamSynchronize(cuda_stream_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr std::size_t kUuidLength =
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
sizeof(std::declval<hipDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
#else
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
#endif
|
||||
|
||||
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
hipDeviceProp prob{};
|
||||
dh::safe_cuda(hipGetDeviceProperties(&prob, device_ordinal_));
|
||||
#else
|
||||
cudaDeviceProp prob{};
|
||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||
#endif
|
||||
|
||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||
}
|
||||
|
||||
@ -168,7 +198,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
return;
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
#endif
|
||||
|
||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||
nccl_comm_, cuda_stream_));
|
||||
allreduce_bytes_ += count * sizeof(T);
|
||||
@ -178,7 +213,13 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
||||
int const device_ordinal_;
|
||||
Communicator *communicator_;
|
||||
ncclComm_t nccl_comm_{};
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
hipStream_t cuda_stream_{};
|
||||
#else
|
||||
cudaStream_t cuda_stream_{};
|
||||
#endif
|
||||
|
||||
ncclUniqueId nccl_unique_id_{};
|
||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user