From 427f6c2a1a357816b52019b6ab410351e30f3827 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:24:34 +0100 Subject: [PATCH] enable rocm, fix simple_dmatrix.cuh --- src/data/simple_dmatrix.cuh | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index c71a52b67..f3d4d953f 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -41,7 +41,13 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span data, template void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, int device_idx, float missing) { + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_idx)); +#else dh::safe_cuda(cudaSetDevice(device_idx)); +#endif + IsValidFunctor is_valid(missing); // Count elements per row dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { @@ -54,10 +60,18 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, }); dh::XGBCachingDeviceAllocator alloc; + +#if defined(XGBOOST_USE_HIP) + thrust::exclusive_scan(thrust::hip::par(alloc), + thrust::device_pointer_cast(offset.data()), + thrust::device_pointer_cast(offset.data() + offset.size()), + thrust::device_pointer_cast(offset.data())); +#else thrust::exclusive_scan(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()), thrust::device_pointer_cast(offset.data() + offset.size()), thrust::device_pointer_cast(offset.data())); +#endif } template