enable rocm, fix quantile.cuh

This commit is contained in:
amdsc21 2023-03-08 06:32:09 +01:00
parent 2eb0b6aae4
commit d3be67ad8e

View File

@ -175,7 +175,13 @@ class SketchContainer {
template <typename KeyComp = thrust::equal_to<size_t>> template <typename KeyComp = thrust::equal_to<size_t>>
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) { size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
timer_.Start(__func__); timer_.Start(__func__);
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device_));
#else
dh::safe_cuda(cudaSetDevice(device_)); dh::safe_cuda(cudaSetDevice(device_));
#endif
this->columns_ptr_.SetDevice(device_); this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan(); Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1); CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
@ -186,11 +192,21 @@ class SketchContainer {
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
d_column_scan = this->columns_ptr_.DeviceSpan(); d_column_scan = this->columns_ptr_.DeviceSpan();
#if defined(XGBOOST_USE_HIP)
size_t n_uniques = dh::SegmentedUnique(
thrust::hip::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp);
#else
size_t n_uniques = dh::SegmentedUnique( size_t n_uniques = dh::SegmentedUnique(
thrust::cuda::par(alloc), d_column_scan.data(), thrust::cuda::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(), d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(), entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp); entries.data(), detail::SketchUnique{}, key_comp);
#endif
this->columns_ptr_.Copy(scan_out); this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead()); CHECK(!this->columns_ptr_.HostCanRead());