diff --git a/src/collective/coll.cc b/src/collective/coll.cc index 598e6129d..d977f5e58 100644 --- a/src/collective/coll.cc +++ b/src/collective/coll.cc @@ -87,7 +87,7 @@ namespace xgboost::collective { } } -#if !defined(XGBOOST_USE_NCCL) +#if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL) Coll* Coll::MakeCUDAVar() { LOG(FATAL) << "NCCL is required for device communication."; return nullptr; diff --git a/src/collective/coll.cu b/src/collective/coll.cu index bac9fb094..9802dc096 100644 --- a/src/collective/coll.cu +++ b/src/collective/coll.cu @@ -1,7 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ -#if defined(XGBOOST_USE_NCCL) +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include // for int8_t, int64_t #include "../common/cuda_context.cuh" diff --git a/src/collective/comm.cu b/src/collective/comm.cu index 31a06e124..2fff9e71b 100644 --- a/src/collective/comm.cu +++ b/src/collective/comm.cu @@ -1,7 +1,7 @@ /** * Copyright 2023, XGBoost Contributors */ -#if defined(XGBOOST_USE_NCCL) +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include // for sort #include // for size_t #include // for uint64_t, int8_t diff --git a/src/collective/comm.cuh b/src/collective/comm.cuh index ea15c50f3..559e4ad01 100644 --- a/src/collective/comm.cuh +++ b/src/collective/comm.cuh @@ -3,7 +3,7 @@ */ #pragma once -#ifdef XGBOOST_USE_NCCL +#if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL) #include "nccl.h" #endif // XGBOOST_USE_NCCL #include "../common/device_helpers.cuh" diff --git a/src/common/device_helpers.hip.h b/src/common/device_helpers.hip.h index 710e61eeb..9f55d6ef8 100644 --- a/src/common/device_helpers.hip.h +++ b/src/common/device_helpers.hip.h @@ -1092,7 +1092,13 @@ class CUDAStreamView { operator hipStream_t() const { // NOLINT return stream_; } - void Sync() { dh::safe_cuda(hipStreamSynchronize(stream_)); } + hipError_t Sync(bool error = true) { + if (error) { + dh::safe_cuda(hipStreamSynchronize(stream_)); + return hipSuccess; + } + return hipStreamSynchronize(stream_); + } }; inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT