enable rocm, fix device_adapter.cuh
This commit is contained in:
parent
427f6c2a1a
commit
fa92aa56ee
@ -111,7 +111,13 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
|
|||||||
|
|
||||||
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
device_idx_ = dh::CudaGetPointerDevice(first_column.data);
|
||||||
CHECK_NE(device_idx_, Context::kCpuId);
|
CHECK_NE(device_idx_, Context::kCpuId);
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_idx_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||||
|
#endif
|
||||||
|
|
||||||
for (auto& json_col : json_columns) {
|
for (auto& json_col : json_columns) {
|
||||||
auto column = ArrayInterface<1>(get<Object const>(json_col));
|
auto column = ArrayInterface<1>(get<Object const>(json_col));
|
||||||
columns.push_back(column);
|
columns.push_back(column);
|
||||||
@ -195,7 +201,13 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
|
|||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
||||||
int device_idx, float missing) {
|
int device_idx, float missing) {
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_idx));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_idx));
|
dh::safe_cuda(cudaSetDevice(device_idx));
|
||||||
|
#endif
|
||||||
|
|
||||||
IsValidFunctor is_valid(missing);
|
IsValidFunctor is_valid(missing);
|
||||||
// Count elements per row
|
// Count elements per row
|
||||||
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
|
||||||
@ -206,11 +218,20 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
|
|||||||
static_cast<unsigned long long>(1)); // NOLINT
|
static_cast<unsigned long long>(1)); // NOLINT
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::Reduce(thrust::hip::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||||
|
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
||||||
|
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
||||||
|
#else
|
||||||
size_t row_stride =
|
size_t row_stride =
|
||||||
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
dh::Reduce(thrust::cuda::par(alloc), thrust::device_pointer_cast(offset.data()),
|
||||||
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
thrust::device_pointer_cast(offset.data()) + offset.size(),
|
||||||
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
static_cast<std::size_t>(0), thrust::maximum<size_t>());
|
||||||
|
#endif
|
||||||
|
|
||||||
return row_stride;
|
return row_stride;
|
||||||
}
|
}
|
||||||
}; // namespace data
|
}; // namespace data
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user