From 0fc1f640a95faa2a28cd323ca449f3d45afd58b7 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:18:13 +0100 Subject: [PATCH] enable rocm, fix nccl_device_communicator.cuh --- src/collective/nccl_device_communicator.cuh | 41 +++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/src/collective/nccl_device_communicator.cuh b/src/collective/nccl_device_communicator.cuh index e14a2e446..05e2155f5 100644 --- a/src/collective/nccl_device_communicator.cuh +++ b/src/collective/nccl_device_communicator.cuh @@ -52,7 +52,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator { nccl_unique_id_ = GetUniqueId(); dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank)); + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipStreamCreate(&cuda_stream_)); +#else dh::safe_cuda(cudaStreamCreate(&cuda_stream_)); +#endif } ~NcclDeviceCommunicator() override { @@ -60,7 +65,11 @@ class NcclDeviceCommunicator : public DeviceCommunicator { return; } if (cuda_stream_) { +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipStreamDestroy(cuda_stream_)); +#else dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); +#endif } if (nccl_comm_) { dh::safe_nccl(ncclCommDestroy(nccl_comm_)); @@ -94,7 +103,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator { return; } +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_ordinal_)); +#else dh::safe_cuda(cudaSetDevice(device_ordinal_)); +#endif + int const world_size = communicator_->GetWorldSize(); int const rank = communicator_->GetRank(); @@ -121,17 +135,33 @@ class NcclDeviceCommunicator : public DeviceCommunicator { if (communicator_->GetWorldSize() == 1) { return; } + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_ordinal_)); + dh::safe_cuda(hipStreamSynchronize(cuda_stream_)); +#else dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_cuda(cudaStreamSynchronize(cuda_stream_)); +#endif } private: static constexpr std::size_t kUuidLength = +#if defined(XGBOOST_USE_HIP) + sizeof(std::declval().uuid) / sizeof(uint64_t); +#else sizeof(std::declval().uuid) / sizeof(uint64_t); +#endif void GetCudaUUID(xgboost::common::Span const &uuid) const { +#if defined(XGBOOST_USE_HIP) + hipDeviceProp prob{}; + dh::safe_cuda(hipGetDeviceProperties(&prob, device_ordinal_)); +#else cudaDeviceProp prob{}; dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); +#endif + std::memcpy(uuid.data(), static_cast(&(prob.uuid)), sizeof(prob.uuid)); } @@ -168,7 +198,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator { return; } +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_ordinal_)); +#else dh::safe_cuda(cudaSetDevice(device_ordinal_)); +#endif + dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum, nccl_comm_, cuda_stream_)); allreduce_bytes_ += count * sizeof(T); @@ -178,7 +213,13 @@ class NcclDeviceCommunicator : public DeviceCommunicator { int const device_ordinal_; Communicator *communicator_; ncclComm_t nccl_comm_{}; + +#if defined(XGBOOST_USE_HIP) + hipStream_t cuda_stream_{}; +#else cudaStream_t cuda_stream_{}; +#endif + ncclUniqueId nccl_unique_id_{}; size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.