enable rocm, fix device_adapter.cuh

This commit is contained in:
amdsc21 2023-03-08 06:26:31 +01:00
parent 427f6c2a1a
commit fa92aa56ee

View File

@ -111,7 +111,13 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
CHECK_NE(device_idx_, Context::kCpuId);
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx_));
#else
dh::safe_cuda(cudaSetDevice(device_idx_));
#endif
for (auto& json_col : json_columns) {
auto column = ArrayInterface<1>(get<Object const>(json_col));
columns.push_back(column);
@ -195,7 +201,13 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
template <typename AdapterBatchT>
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
int device_idx, float missing) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_idx));
#else
dh::safe_cuda(cudaSetDevice(device_idx));
#endif
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
@ -206,11 +218,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
static_cast<unsigned long long>(1)); // NOLINT
}
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
dh::Reduce(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<std::size_t>(0), thrust::maximum<size_t>());
#else
size_t row_stride =
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data()) + offset.size(),
static_cast<std::size_t>(0), thrust::maximum<size_t>());
#endif
return row_stride;
}
}; // namespace data