enable rocm, fix simple_dmatrix.cuh
This commit is contained in:
parent
270c7b4802
commit
427f6c2a1a
@ -41,7 +41,13 @@ void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
|
|||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
||||||
int device_idx, float missing) {
|
int device_idx, float missing) {
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_idx));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||||
|
#endif
|
||||||
|
|
||||||
IsValidFunctor is_valid(missing);
|
IsValidFunctor is_valid(missing);
|
||||||
// Count elements per row
|
// Count elements per row
|
||||||
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
||||||
@ -54,10 +60,18 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
|
|||||||
});
|
});
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
thrust::exclusive_scan(thrust::hip::par(alloc),
|
||||||
|
thrust::device_pointer_cast(offset.data()),
|
||||||
|
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||||
|
thrust::device_pointer_cast(offset.data()));
|
||||||
|
#else
|
||||||
thrust::exclusive_scan(thrust::cuda::par(alloc),
|
thrust::exclusive_scan(thrust::cuda::par(alloc),
|
||||||
thrust::device_pointer_cast(offset.data()),
|
thrust::device_pointer_cast(offset.data()),
|
||||||
thrust::device_pointer_cast(offset.data() + offset.size()),
|
thrust::device_pointer_cast(offset.data() + offset.size()),
|
||||||
thrust::device_pointer_cast(offset.data()));
|
thrust::device_pointer_cast(offset.data()));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user