Skip optional synchronization in thrust. (#9212)

This commit is contained in:
Jiaming Yuan 2023-05-30 17:23:09 +08:00 committed by GitHub
parent ddec0f378c
commit ae7450ce54
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2022 by XGBoost Contributors * Copyright 2022-2023, XGBoost Contributors
*/ */
#ifndef XGBOOST_COMMON_CUDA_CONTEXT_CUH_ #ifndef XGBOOST_COMMON_CUDA_CONTEXT_CUH_
#define XGBOOST_COMMON_CUDA_CONTEXT_CUH_ #define XGBOOST_COMMON_CUDA_CONTEXT_CUH_
@ -17,11 +17,23 @@ struct CUDAContext {
/** /**
* \brief Caching thrust policy. * \brief Caching thrust policy.
*/ */
auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); } auto CTP() const {
#if THRUST_MAJOR_VERSION >= 2
return thrust::cuda::par_nosync(caching_alloc_).on(dh::DefaultStream());
#else
return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream());
#endif // THRUST_MAJOR_VERSION >= 2
}
/** /**
* \brief Thrust policy without caching allocator. * \brief Thrust policy without caching allocator.
*/ */
auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); } auto TP() const {
#if THRUST_MAJOR_VERSION >= 2
return thrust::cuda::par_nosync(alloc_).on(dh::DefaultStream());
#else
return thrust::cuda::par(alloc_).on(dh::DefaultStream());
#endif // THRUST_MAJOR_VERSION >= 2
}
auto Stream() const { return dh::DefaultStream(); } auto Stream() const { return dh::DefaultStream(); }
}; };
} // namespace xgboost } // namespace xgboost