enable rocm, fix device_adapter.cuh
This commit is contained in:
parent
427f6c2a1a
commit
fa92aa56ee
@ -111,7 +111,13 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
||||
|
||||
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
||||
CHECK_NE(device_idx_, Context::kCpuId);
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_idx_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||
#endif
|
||||
|
||||
for (auto& json_col : json_columns) {
|
||||
auto column = ArrayInterface<1>(get<Object const>(json_col));
|
||||
columns.push_back(column);
|
||||
@ -195,7 +201,13 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
||||
template <typename AdapterBatchT>
|
||||
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||
int device_idx, float missing) {
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_idx));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||
#endif
|
||||
|
||||
IsValidFunctor is_valid(missing);
|
||||
// Count elements per row
|
||||
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
||||
@ -206,11 +218,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||
static_cast<unsigned long long>(1)); // NOLINT
|
||||
}
|
||||
});
|
||||
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::Reduce(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
||||
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
||||
#else
|
||||
size_t row_stride =
|
||||
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
||||
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
||||
#endif
|
||||
|
||||
return row_stride;
|
||||
}
|
||||
}; // namespace data
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user