This commit is contained in:
Hui Liu
2023-10-23 22:29:48 -07:00
parent 558352afc9
commit 79319dfd4d
8 changed files with 115 additions and 100 deletions

View File

@@ -36,21 +36,21 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
private:
static constexpr std::size_t kUuidLength =
#if defined(XGBOOST_USE_HIP)
sizeof(hipUUID) / sizeof(uint64_t);
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
#elif defined(XGBOOST_USE_HIP)
sizeof(hipUUID) / sizeof(uint64_t);
#endif
void GetCudaUUID(xgboost::common::Span<uint64_t, kUuidLength> const &uuid) const {
#if defined(XGBOOST_USE_HIP)
hipUUID id;
hipDeviceGetUuid(&id, device_ordinal_);
std::memcpy(uuid.data(), static_cast<void *>(&id), sizeof(id));
#elif defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA)
cudaDeviceProp prob{};
dh::safe_cuda(cudaGetDeviceProperties(&prob, device_ordinal_));
std::memcpy(uuid.data(), static_cast<void *>(&(prob.uuid)), sizeof(prob.uuid));
#elif defined(XGBOOST_USE_HIP)
hipUUID id;
hipDeviceGetUuid(&id, device_ordinal_);
std::memcpy(uuid.data(), static_cast<void *>(&id), sizeof(id));
#endif
}