From 62c4efac51c7b821fff9e104abacb6c4ce0d1e92 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Wed, 8 Mar 2023 06:37:34 +0100 Subject: [PATCH] enable rocm, fix transform.h --- src/common/transform.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/common/transform.h b/src/common/transform.h index 5f9c3f1bf..974ee86d6 100644 --- a/src/common/transform.h +++ b/src/common/transform.h @@ -140,7 +140,13 @@ class Transform { // granularity is used in data vector. size_t shard_size = range_size; Range shard_range {0, static_cast(shard_size)}; + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_)); +#else dh::safe_cuda(cudaSetDevice(device_)); +#endif + const int kGrids = static_cast(DivRoundUp(*(range_.end()), kBlockThreads)); if (kGrids == 0) {