enable RCCL

This commit is contained in:
Hui Liu 2023-10-30 14:05:04 -07:00
parent d7f1235b7d
commit b6b5218245
5 changed files with 11 additions and 5 deletions

View File

@ -87,7 +87,7 @@ namespace xgboost::collective {
} }
} }
#if !defined(XGBOOST_USE_NCCL) #if !defined(XGBOOST_USE_NCCL) && !defined(XGBOOST_USE_RCCL)
Coll* Coll::MakeCUDAVar() { Coll* Coll::MakeCUDAVar() {
LOG(FATAL) << "NCCL is required for device communication."; LOG(FATAL) << "NCCL is required for device communication.";
return nullptr; return nullptr;

View File

@ -1,7 +1,7 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023, XGBoost Contributors
*/ */
#if defined(XGBOOST_USE_NCCL) #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <cstdint> // for int8_t, int64_t #include <cstdint> // for int8_t, int64_t
#include "../common/cuda_context.cuh" #include "../common/cuda_context.cuh"

View File

@ -1,7 +1,7 @@
/** /**
* Copyright 2023, XGBoost Contributors * Copyright 2023, XGBoost Contributors
*/ */
#if defined(XGBOOST_USE_NCCL) #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include <algorithm> // for sort #include <algorithm> // for sort
#include <cstddef> // for size_t #include <cstddef> // for size_t
#include <cstdint> // for uint64_t, int8_t #include <cstdint> // for uint64_t, int8_t

View File

@ -3,7 +3,7 @@
*/ */
#pragma once #pragma once
#ifdef XGBOOST_USE_NCCL #if defined(XGBOOST_USE_NCCL) || defined(XGBOOST_USE_RCCL)
#include "nccl.h" #include "nccl.h"
#endif // XGBOOST_USE_NCCL #endif // XGBOOST_USE_NCCL
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"

View File

@ -1092,7 +1092,13 @@ class CUDAStreamView {
operator hipStream_t() const { // NOLINT operator hipStream_t() const { // NOLINT
return stream_; return stream_;
} }
void Sync() { dh::safe_cuda(hipStreamSynchronize(stream_)); } hipError_t Sync(bool error = true) {
if (error) {
dh::safe_cuda(hipStreamSynchronize(stream_));
return hipSuccess;
}
return hipStreamSynchronize(stream_);
}
}; };
inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT