enable rocm, fix simple_dmatrix.cuh

This commit is contained in:
amdsc21 2023-03-08 06:24:34 +01:00
parent 270c7b4802
commit 427f6c2a1a

View File

@ -41,7 +41,13 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
int device_idx, float missing) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx));
#else
dh::safe_cuda(cudaSetDevice(device_idx));
#endif
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
@ -54,10 +60,18 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
thrust::exclusive_scan(thrust::hip::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#else
thrust::exclusive_scan(thrust::cuda::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
#endif
}
template <typename AdapterBatchT>