enable rocm, fix nccl_device_communicator.cuh

This commit is contained in:
amdsc21 2023-03-08 06:18:13 +01:00
parent 762fd9028d
commit 0fc1f640a9

View File

@ -52,7 +52,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
nccl_unique_id_ = GetUniqueId(); nccl_unique_id_ = GetUniqueId();
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank)); dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamCreate(&cuda_stream_));
#else
dh::safe_cuda(cudaStreamCreate(&cuda_stream_)); dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
#endif
} }
~NcclDeviceCommunicator() override { ~NcclDeviceCommunicator() override {
@ -60,7 +65,11 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return; return;
} }
if (cuda_stream_) { if (cuda_stream_) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamDestroy(cuda_stream_));
#else
dh::safe_cuda(cudaStreamDestroy(cuda_stream_)); dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
#endif
} }
if (nccl_comm_) { if (nccl_comm_) {
dh::safe_nccl(ncclCommDestroy(nccl_comm_)); dh::safe_nccl(ncclCommDestroy(nccl_comm_));
@ -94,7 +103,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return; return;
} }
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
#else
dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_cuda(cudaSetDevice(device_ordinal_));
#endif
int const world_size = communicator_->GetWorldSize(); int const world_size = communicator_->GetWorldSize();
int const rank = communicator_->GetRank(); int const rank = communicator_->GetRank();
@ -121,17 +135,33 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
if (communicator_->GetWorldSize() == 1) { if (communicator_->GetWorldSize() == 1) {
return; return;
} }
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
dh::safe_cuda(hipStreamSynchronize(cuda_stream_));
#else
dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_)); dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
#endif
} }
private: private:
static constexpr std::size_t kUuidLength = static constexpr std::size_t kUuidLength =
#if defined(XGBOOST_USE_HIP)
sizeof(std::declval<hipDeviceProp>().uuid) / sizeof(uint64_t);
#else
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t); sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
#endif
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const { void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
#if defined(XGBOOST_USE_HIP)
hipDeviceProp prob{};
dh::safe_cuda(hipGetDeviceProperties(&prob, device_ordinal_));
#else
cudaDeviceProp prob{}; cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_)); dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
#endif
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid)); std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
} }
@ -168,7 +198,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return; return;
} }
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_ordinal_));
#else
dh::safe_cuda(cudaSetDevice(device_ordinal_)); dh::safe_cuda(cudaSetDevice(device_ordinal_));
#endif
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum, dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
nccl_comm_, cuda_stream_)); nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(T); allreduce_bytes_ += count * sizeof(T);
@ -178,7 +213,13 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
int const device_ordinal_; int const device_ordinal_;
Communicator *communicator_; Communicator *communicator_;
ncclComm_t nccl_comm_{}; ncclComm_t nccl_comm_{};
#if defined(XGBOOST_USE_HIP)
hipStream_t cuda_stream_{};
#else
cudaStream_t cuda_stream_{}; cudaStream_t cuda_stream_{};
#endif
ncclUniqueId nccl_unique_id_{}; ncclUniqueId nccl_unique_id_{};
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated. size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls. size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.