enable rocm, fix cuda_context.cuh
This commit is contained in:
parent
fa92aa56ee
commit
327f1494f1
@ -17,11 +17,21 @@ struct CUDAContext {
|
|||||||
/**
|
/**
|
||||||
* \brief Caching thrust policy.
|
* \brief Caching thrust policy.
|
||||||
*/
|
*/
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
auto CTP() const { return thrust::hip::par(caching_alloc_).on(dh::DefaultStream()); }
|
||||||
|
#else
|
||||||
auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
|
auto CTP() const { return thrust::cuda::par(caching_alloc_).on(dh::DefaultStream()); }
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Thrust policy without caching allocator.
|
* \brief Thrust policy without caching allocator.
|
||||||
*/
|
*/
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
auto TP() const { return thrust::hip::par(alloc_).on(dh::DefaultStream()); }
|
||||||
|
#else
|
||||||
auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
|
auto TP() const { return thrust::cuda::par(alloc_).on(dh::DefaultStream()); }
|
||||||
|
#endif
|
||||||
|
|
||||||
auto Stream() const { return dh::DefaultStream(); }
|
auto Stream() const { return dh::DefaultStream(); }
|
||||||
};
|
};
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user