enable rocm, fix cuda_context.cuh
This commit is contained in:
parent
fa92aa56ee
commit
327f1494f1
@ -17,11 +17,21 @@ struct CUDAContext {
|
||||
/**
|
||||
* \brief Caching thrust policy.
|
||||
*/
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
auto CTP() const { return thrust::hip::par(caching_alloc_).on(dh::DefaultStream()); }
|
||||
#else
|
||||
auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
|
||||
#endif
|
||||
|
||||
/**
|
||||
* \brief Thrust policy without caching allocator.
|
||||
*/
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
auto TP() const { return thrust::hip::par(alloc_).on(dh::DefaultStream()); }
|
||||
#else
|
||||
auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
|
||||
#endif
|
||||
|
||||
auto Stream() const { return dh::DefaultStream(); }
|
||||
};
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user