enable rocm, fix nccl_device_communicator.cuh
This commit is contained in:
parent
762fd9028d
commit
0fc1f640a9
@ -52,7 +52,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
|
|
||||||
nccl_unique_id_ = GetUniqueId();
|
nccl_unique_id_ = GetUniqueId();
|
||||||
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
dh::safe_nccl(ncclCommInitRank(&nccl_comm_, world, nccl_unique_id_, rank));
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipStreamCreate(&cuda_stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
dh::safe_cuda(cudaStreamCreate(&cuda_stream_));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
~NcclDeviceCommunicator() override {
|
~NcclDeviceCommunicator() override {
|
||||||
@ -60,7 +65,11 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (cuda_stream_) {
|
if (cuda_stream_) {
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipStreamDestroy(cuda_stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
dh::safe_cuda(cudaStreamDestroy(cuda_stream_));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
if (nccl_comm_) {
|
if (nccl_comm_) {
|
||||||
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
dh::safe_nccl(ncclCommDestroy(nccl_comm_));
|
||||||
@ -94,7 +103,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
#endif
|
||||||
|
|
||||||
int const world_size = communicator_->GetWorldSize();
|
int const world_size = communicator_->GetWorldSize();
|
||||||
int const rank = communicator_->GetRank();
|
int const rank = communicator_->GetRank();
|
||||||
|
|
||||||
@ -121,17 +135,33 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
if (communicator_->GetWorldSize() == 1) {
|
if (communicator_->GetWorldSize() == 1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||||
|
dh::safe_cuda(hipStreamSynchronize(cuda_stream_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
dh::safe_cuda(cudaStreamSynchronize(cuda_stream_));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr std::size_t kUuidLength =
|
static constexpr std::size_t kUuidLength =
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
sizeof(std::declval<hipDeviceProp>().uuid) / sizeof(uint64_t);
|
||||||
|
#else
|
||||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||||
|
#endif
|
||||||
|
|
||||||
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
hipDeviceProp prob{};
|
||||||
|
dh::safe_cuda(hipGetDeviceProperties(&prob, device_ordinal_));
|
||||||
|
#else
|
||||||
cudaDeviceProp prob{};
|
cudaDeviceProp prob{};
|
||||||
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
|
||||||
|
#endif
|
||||||
|
|
||||||
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -168,7 +198,12 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_ordinal_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||||
|
#endif
|
||||||
|
|
||||||
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
|
||||||
nccl_comm_, cuda_stream_));
|
nccl_comm_, cuda_stream_));
|
||||||
allreduce_bytes_ += count * sizeof(T);
|
allreduce_bytes_ += count * sizeof(T);
|
||||||
@ -178,7 +213,13 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
|
|||||||
int const device_ordinal_;
|
int const device_ordinal_;
|
||||||
Communicator *communicator_;
|
Communicator *communicator_;
|
||||||
ncclComm_t nccl_comm_{};
|
ncclComm_t nccl_comm_{};
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
hipStream_t cuda_stream_{};
|
||||||
|
#else
|
||||||
cudaStream_t cuda_stream_{};
|
cudaStream_t cuda_stream_{};
|
||||||
|
#endif
|
||||||
|
|
||||||
ncclUniqueId nccl_unique_id_{};
|
ncclUniqueId nccl_unique_id_{};
|
||||||
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
size_t allreduce_bytes_{0}; // Keep statistics of the number of bytes communicated.
|
||||||
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
size_t allreduce_calls_{0}; // Keep statistics of the number of reduce calls.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user