enable rocm, fix simple_dmatrix.cuh
This commit is contained in:
parent
270c7b4802
commit
427f6c2a1a
@ -41,7 +41,13 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
|
||||
template <typename AdapterBatchT>
|
||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
int device_idx, float missing) {
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_idx));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
#endif
|
||||
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
||||
@ -54,10 +60,18 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||
});
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
thrust::exclusive_scan(thrust::hip::par(alloc),
|
||||
thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||
thrust::device_pointer_cast(offset.data()));
|
||||
#else
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc),
|
||||
thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||
thrust::device_pointer_cast(offset.data()));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user