Add an API guard to prevent global variables being changed. (#6891)
This commit is contained in:
@@ -12,7 +12,6 @@
|
||||
|
||||
namespace dh {
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ > 9
|
||||
constexpr std::size_t kUuidLength =
|
||||
sizeof(std::declval<cudaDeviceProp>().uuid) / sizeof(uint64_t);
|
||||
|
||||
@@ -31,16 +30,13 @@ std::string PrintUUID(xgboost::common::Span<uint64_t, kUuidLength> uuid) {
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
#endif // __CUDACC_VER_MAJOR__ > 9
|
||||
|
||||
void AllReducer::Init(int _device_ordinal) {
|
||||
#ifdef XGBOOST_USE_NCCL
|
||||
LOG(DEBUG) << "Running nccl init on: " << __CUDACC_VER_MAJOR__ << "." << __CUDACC_VER_MINOR__;
|
||||
|
||||
device_ordinal_ = _device_ordinal;
|
||||
int32_t const rank = rabit::GetRank();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
|
||||
#if __CUDACC_VER_MAJOR__ > 9
|
||||
int32_t const rank = rabit::GetRank();
|
||||
int32_t const world = rabit::GetWorldSize();
|
||||
|
||||
std::vector<uint64_t> uuids(world * kUuidLength, 0);
|
||||
@@ -61,13 +57,13 @@ void AllReducer::Init(int _device_ordinal) {
|
||||
|
||||
auto iter = std::unique(converted.begin(), converted.end());
|
||||
auto n_uniques = std::distance(converted.begin(), iter);
|
||||
|
||||
CHECK_EQ(n_uniques, world)
|
||||
<< "Multiple processes within communication group running on same CUDA "
|
||||
<< "device is not supported";
|
||||
#endif // __CUDACC_VER_MAJOR__ > 9
|
||||
<< "device is not supported. " << PrintUUID(s_this_uuid) << "\n";
|
||||
|
||||
|
||||
id_ = GetUniqueId();
|
||||
dh::safe_cuda(cudaSetDevice(device_ordinal_));
|
||||
dh::safe_nccl(ncclCommInitRank(&comm_, rabit::GetWorldSize(), id_, rank));
|
||||
safe_cuda(cudaStreamCreate(&stream_));
|
||||
initialised_ = true;
|
||||
|
||||
Reference in New Issue
Block a user