From fa92aa56eef8f087cd79a951096f5826274beeae Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:26:31 +0100 Subject: [PATCH] enable rocm, fix device_adapter.cuh --- src/data/device_adapter.cuh | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 56c494dd1..78d5f79b5 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -111,7 +111,13 @@ class CudfAdapter : public detail::SingleBatchDataIter { device_idx_ = dh::CudaGetPointerDevice(first_column.data); CHECK_NE(device_idx_, Context::kCpuId); + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_idx_)); +#else dh::safe_cuda(cudaSetDevice(device_idx_)); +#endif + for (auto& json_col : json_columns) { auto column = ArrayInterface<1>(get(json_col)); columns.push_back(column); @@ -195,7 +201,13 @@ class CupyAdapter : public detail::SingleBatchDataIter { template size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, int device_idx, float missing) { + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_idx)); +#else dh::safe_cuda(cudaSetDevice(device_idx)); +#endif + IsValidFunctor is_valid(missing); // Count elements per row dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { @@ -206,11 +218,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, static_cast(1)); // NOLINT } }); + dh::XGBCachingDeviceAllocator alloc; + +#if defined(XGBOOST_USE_HIP) + dh::Reduce(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()), + thrust::device_pointer_cast(offset.data()) + offset.size(), + static_cast(0), thrust::maximum()); +#else size_t row_stride = dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()), thrust::device_pointer_cast(offset.data()) + offset.size(), static_cast(0), thrust::maximum()); +#endif + return row_stride; } }; // namespace data