diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 7ebd4ff51..de7f84dc4 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -175,7 +175,13 @@ class SketchContainer { template > size_t Unique(KeyComp key_comp = thrust::equal_to{}) { timer_.Start(__func__); + +#if defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(device_)); +#else dh::safe_cuda(cudaSetDevice(device_)); +#endif + this->columns_ptr_.SetDevice(device_); Span d_column_scan = this->columns_ptr_.DeviceSpan(); CHECK_EQ(d_column_scan.size(), num_columns_ + 1); @@ -186,11 +192,21 @@ class SketchContainer { dh::XGBCachingDeviceAllocator alloc; d_column_scan = this->columns_ptr_.DeviceSpan(); + +#if defined(XGBOOST_USE_HIP) + size_t n_uniques = dh::SegmentedUnique( + thrust::hip::par(alloc), d_column_scan.data(), + d_column_scan.data() + d_column_scan.size(), entries.data(), + entries.data() + entries.size(), scan_out.DevicePointer(), + entries.data(), detail::SketchUnique{}, key_comp); +#else size_t n_uniques = dh::SegmentedUnique( thrust::cuda::par(alloc), d_column_scan.data(), d_column_scan.data() + d_column_scan.size(), entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(), entries.data(), detail::SketchUnique{}, key_comp); +#endif + this->columns_ptr_.Copy(scan_out); CHECK(!this->columns_ptr_.HostCanRead());