enable rocm, fix transform.h
This commit is contained in:
parent
ba9e00d911
commit
62c4efac51
@ -140,7 +140,13 @@ class Transform {
|
|||||||
// granularity is used in data vector.
|
// granularity is used in data vector.
|
||||||
size_t shard_size = range_size;
|
size_t shard_size = range_size;
|
||||||
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
|
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
|
#endif
|
||||||
|
|
||||||
const int kGrids =
|
const int kGrids =
|
||||||
static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
|
static_cast<int>(DivRoundUp(*(range_.end()), kBlockThreads));
|
||||||
if (kGrids == 0) {
|
if (kGrids == 0) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user