enable rocm, fix transform.h

This commit is contained in:
amdsc21 2023-03-08 06:37:34 +01:00
parent ba9e00d911
commit 62c4efac51

View File

@ -140,7 +140,13 @@ class Transform {
// granularity is used in data vector. // granularity is used in data vector.
size_t shard_size = range_size; size_t shard_size = range_size;
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)}; Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_));
#else
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
#endif
const int kGrids = const int kGrids =
static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads)); static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
if (kGrids == 0) { if (kGrids == 0) {