enable rocm, fix quantile.cuh
This commit is contained in:
parent
2eb0b6aae4
commit
d3be67ad8e
@ -175,7 +175,13 @@ class SketchContainer {
|
|||||||
template <typename KeyComp = thrust::equal_to<size_t>>
|
template <typename KeyComp = thrust::equal_to<size_t>>
|
||||||
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
|
||||||
timer_.Start(__func__);
|
timer_.Start(__func__);
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipSetDevice(device_));
|
||||||
|
#else
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(device_));
|
||||||
|
#endif
|
||||||
|
|
||||||
this->columns_ptr_.SetDevice(device_);
|
this->columns_ptr_.SetDevice(device_);
|
||||||
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||||
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
|
||||||
@ -186,11 +192,21 @@ class SketchContainer {
|
|||||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||||
|
|
||||||
d_column_scan = this->columns_ptr_.DeviceSpan();
|
d_column_scan = this->columns_ptr_.DeviceSpan();
|
||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
|
thrust::hip::par(alloc), d_column_scan.data(),
|
||||||
|
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||||
|
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||||
|
entries.data(), detail::SketchUnique{}, key_comp);
|
||||||
|
#else
|
||||||
size_t n_uniques = dh::SegmentedUnique(
|
size_t n_uniques = dh::SegmentedUnique(
|
||||||
thrust::cuda::par(alloc), d_column_scan.data(),
|
thrust::cuda::par(alloc), d_column_scan.data(),
|
||||||
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
d_column_scan.data() + d_column_scan.size(), entries.data(),
|
||||||
entries.data() + entries.size(), scan_out.DevicePointer(),
|
entries.data() + entries.size(), scan_out.DevicePointer(),
|
||||||
entries.data(), detail::SketchUnique{}, key_comp);
|
entries.data(), detail::SketchUnique{}, key_comp);
|
||||||
|
#endif
|
||||||
|
|
||||||
this->columns_ptr_.Copy(scan_out);
|
this->columns_ptr_.Copy(scan_out);
|
||||||
CHECK(!this->columns_ptr_.HostCanRead());
|
CHECK(!this->columns_ptr_.HostCanRead());
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user