From b6b5218245a11c6f6c804608cc89318d985a4329 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+amdsc21@users.noreply.github.com> Date: Mon, 30 Oct 2023 14:05:04 -0700 Subject: [PATCH] enable RCCL --- src/collective/coll.cc | 2 +- src/collective/coll.cu | 2 +- src/collective/comm.cu | 2 +- src/collective/comm.cuh | 2 +- src/common/device_helpers.hip.h | 8 +++++++- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 598e6129d..d977f5e58 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -87,7 +87,7 @@ namespace xgboost::collective { } } -#if !defined(XGBOOST_USE_NCCL) +#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL) Coll* Coll::MakeCUDAVar() { LOG(FATAL) << "NCCL is required for device communication."; return nullptr; diff --git a/src/collective/coll.cu b/src/collective/coll.cu index bac9fb094..9802dc096 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -1,7 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ -#if defined(XGBOOST_USE_NCCL) +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include // for int8_t, int64_t #include "../common/cuda_context.cuh" diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 31a06e124..2fff9e71b 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -1,7 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ -#if defined(XGBOOST_USE_NCCL) +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include // for sort #include // for size_t #include // for uint64_t, int8_t diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index ea15c50f3..559e4ad01 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -3,7 +3,7 @@ */ #pragma once -#ifdef XGBOOST_USE_NCCL +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include "nccl.h" #endif // XGBOOST_USE_NCCL #include "../common/device_helpers.cuh" diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 710e61eeb..9f55d6ef8 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -1092,7 +1092,13 @@ class CUDAStreamView { operator hipStream_t() const { // NOLINT return stream_; } - void Sync() { dh::safe_cuda(hipStreamSynchronize(stream_)); } + hipError_t Sync(bool error = true) { + if (error) { + dh::safe_cuda(hipStreamSynchronize(stream_)); + return hipSuccess; + } + return hipStreamSynchronize(stream_); + } }; inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT