enable rocm, fix quantile.cuh
This commit is contained in:
parent
2eb0b6aae4
commit
d3be67ad8e
@ -175,7 +175,13 @@ class SketchContainer {
|
||||
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||
timer_.Start(__func__);
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
dh::safe_cuda(hipSetDevice(device_));
|
||||
#else
|
||||
dh::safe_cuda(cudaSetDevice(device_));
|
||||
#endif
|
||||
|
||||
this->columns_ptr_.SetDevice(device_);
|
||||
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
||||
@ -186,11 +192,21 @@ class SketchContainer {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||
|
||||
#if defined(XGBOOST_USE_HIP)
|
||||
size_t n_uniques = dh::SegmentedUnique(
|
||||
thrust::hip::par(alloc), d_column_scan.data(),
|
||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||
entries.data(), detail::SketchUnique{}, key_comp);
|
||||
#else
|
||||
size_t n_uniques = dh::SegmentedUnique(
|
||||
thrust::cuda::par(alloc), d_column_scan.data(),
|
||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||
entries.data(), detail::SketchUnique{}, key_comp);
|
||||
#endif
|
||||
|
||||
this->columns_ptr_.Copy(scan_out);
|
||||
CHECK(!this->columns_ptr_.HostCanRead());
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user