diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 57201deb4..d9c4215db 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -430,38 +430,18 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
     return *allocator;
   }
   pointer allocate(size_t n) {  // NOLINT
-    pointer ptr;
-    if (use_cub_allocator_) {
-      T* raw_ptr{nullptr};
-      GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
-      ptr = pointer(raw_ptr);
-    } else {
-      ptr = SuperT::allocate(n);
-    }
-    GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
-    return ptr;
+    T* ptr;
+    GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&ptr),
+                                               n * sizeof(T));
+    pointer thrust_ptr{ ptr };
+    GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
+    return thrust_ptr;
   }
   void deallocate(pointer ptr, size_t n) {  // NOLINT
     GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
-    if (use_cub_allocator_) {
-      GetGlobalCachingAllocator().DeviceFree(ptr.get());
-    } else {
-      SuperT::deallocate(ptr, n);
-    }
+    GetGlobalCachingAllocator().DeviceFree(ptr.get());
   }
 
-#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
-  XGBCachingDeviceAllocatorImpl()
-      : SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
-    std::string symbol{typeid(*SuperT::resource()).name()};
-    if (symbol.find("pool_memory_resource") != std::string::npos
-        || symbol.find("binning_memory_resource") != std::string::npos
-        || symbol.find("arena_memory_resource") != std::string::npos) {
-      use_cub_allocator_ = false;
-    }
-  }
-#endif  // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
-
- private:
-  bool use_cub_allocator_{true};
+  XGBOOST_DEVICE void construct(T *) {}  // NOLINT
 };
 }  // namespace detail
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index 376a64742..4aa1c38ea 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
   return sketch_batch_num_elements;
 }
 
-void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
-                  dh::caching_device_vector<float>* weights,
-                  dh::caching_device_vector<Entry>* sorted_entries) {
+void SortByWeight(dh::device_vector<float>* weights,
+                  dh::device_vector<Entry>* sorted_entries) {
   // Sort both entries and wegihts.
-  thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
+  dh::XGBDeviceAllocator<char> alloc;
+  thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
                       sorted_entries->end(), weights->begin(),
                       detail::EntryCompareOp());
 
   // Scan weights
-  thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
+  dh::XGBCachingDeviceAllocator<char> caching;
+  thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
                                 sorted_entries->begin(), sorted_entries->end(),
                                 weights->begin(), weights->begin(),
                                 [=] __device__(const Entry& a, const Entry& b) {
@@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
                           bool is_ranking, Span<bst_group_t const> d_group_ptr) {
   dh::XGBCachingDeviceAllocator<char> alloc;
   const auto& host_data = page.data.ConstHostVector();
-  dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
-                                                  host_data.begin() + end);
+  dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
+                                          host_data.begin() + end);
 
   // Binary search to assign weights to each element
-  dh::caching_device_vector<float> temp_weights(sorted_entries.size());
+  dh::device_vector<float> temp_weights(sorted_entries.size());
   auto d_temp_weights = temp_weights.data().get();
   page.offset.SetDevice(device);
   auto row_ptrs = page.offset.ConstDeviceSpan();
@@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
         d_temp_weights[idx] = weights[ridx + base_rowid];
       });
   }
-  detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
+  detail::SortByWeight(&temp_weights, &sorted_entries);
 
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   dh::caching_device_vector<size_t> column_sizes_scan;
diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh
index 7047716bf..922e75b95 100644
--- a/src/common/hist_util.cuh
+++ b/src/common/hist_util.cuh
@@ -102,7 +102,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
                             size_t columns, size_t cuts_per_feature, int device,
                             HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
                             dh::caching_device_vector<size_t>* column_sizes_scan,
-                            dh::caching_device_vector<Entry>* sorted_entries) {
+                            dh::device_vector<Entry>* sorted_entries) {
   auto entry_iter = dh::MakeTransformIterator<Entry>(
       thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
         return Entry(batch.GetElement(idx).column_idx,
@@ -123,9 +123,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
                   entry_iter + range.end(), sorted_entries->begin(), is_valid);
 }
 
-void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
-                  dh::caching_device_vector<float>* weights,
-                  dh::caching_device_vector<Entry>* sorted_entries);
+void SortByWeight(dh::device_vector<float>* weights,
+                  dh::device_vector<Entry>* sorted_entries);
 }  // namespace detail
 
 // Compute sketch on DMatrix.
@@ -138,7 +137,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
                           size_t begin, size_t end, float missing,
                           SketchContainer* sketch_container, int num_cuts) {
   // Copy current subset of valid elements into temporary storage and sort
-  dh::caching_device_vector<Entry> sorted_entries;
+  dh::device_vector<Entry> sorted_entries;
   dh::caching_device_vector<size_t> column_sizes_scan;
   auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
       thrust::make_counting_iterator(0llu),
@@ -149,7 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
                                  &cuts_ptr,
                                  &column_sizes_scan,
                                  &sorted_entries);
-  dh::XGBCachingDeviceAllocator<char> alloc;
+  dh::XGBDeviceAllocator<char> alloc;
   thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
                sorted_entries.end(), detail::EntryCompareOp());
 
@@ -179,7 +178,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
   auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
       thrust::make_counting_iterator(0llu),
       [=] __device__(size_t idx) { return batch.GetElement(idx); });
-  dh::caching_device_vector<Entry> sorted_entries;
+  dh::device_vector<Entry> sorted_entries;
   dh::caching_device_vector<size_t> column_sizes_scan;
   HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
   detail::MakeEntriesFromAdapter(batch, batch_iter,
@@ -190,7 +189,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
                                  &sorted_entries);
   data::IsValidFunctor is_valid(missing);
 
-  dh::caching_device_vector<float> temp_weights(sorted_entries.size());
+  dh::device_vector<float> temp_weights(sorted_entries.size());
   auto d_temp_weights = dh::ToSpan(temp_weights);
 
   if (is_ranking) {
@@ -222,7 +221,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
     CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
   }
 
-  detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
+  detail::SortByWeight(&temp_weights, &sorted_entries);
 
   auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
   auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index 42dd8837a..aa9667c54 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -310,7 +310,7 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
                            common::Span<OffsetT> cuts_ptr,
                            size_t total_cuts, Span<float> weights) {
   Span<SketchEntry> out;
-  dh::caching_device_vector<SketchEntry> cuts;
+  dh::device_vector<SketchEntry> cuts;
   bool first_window = this->Current().empty();
   if (!first_window) {
     cuts.resize(total_cuts);
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh
index 81ba92de7..3a8b3fcb7 100644
--- a/src/common/quantile.cuh
+++ b/src/common/quantile.cuh
@@ -36,31 +36,31 @@ class SketchContainer {
   int32_t device_;
 
   // Double buffer as neither prune nor merge can be performed inplace.
-  dh::caching_device_vector<SketchEntry> entries_a_;
-  dh::caching_device_vector<SketchEntry> entries_b_;
+  dh::device_vector<SketchEntry> entries_a_;
+  dh::device_vector<SketchEntry> entries_b_;
   bool current_buffer_ {true};
   // The container is just a CSC matrix.
   HostDeviceVector<OffsetT> columns_ptr_;
   HostDeviceVector<OffsetT> columns_ptr_b_;
 
-  dh::caching_device_vector<SketchEntry>& Current() {
+  dh::device_vector<SketchEntry>& Current() {
     if (current_buffer_) {
       return entries_a_;
     } else {
       return entries_b_;
     }
   }
-  dh::caching_device_vector<SketchEntry>& Other() {
+  dh::device_vector<SketchEntry>& Other() {
     if (!current_buffer_) {
       return entries_a_;
    } else {
      return entries_b_;
    }
  }
-  dh::caching_device_vector<SketchEntry> const& Current() const {
+  dh::device_vector<SketchEntry> const& Current() const {
    return const_cast<SketchContainer*>(this)->Current();
  }
-  dh::caching_device_vector<SketchEntry> const& Other() const {
+  dh::device_vector<SketchEntry> const& Other() const {
    return const_cast<SketchContainer*>(this)->Other();
  }
  void Alternate() {
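
Not part of the patch above: a minimal usage sketch of the allocation pattern the change settles on, assuming the dh:: helpers from device_helpers.cuh (dh::device_vector, dh::XGBDeviceAllocator, dh::XGBCachingDeviceAllocator) behave as shown in the hunks. The helper name SortAndScan and the include path are illustrative only; they do not exist in the tree. Long-lived entry buffers become plain device_vectors, while thrust's internal temporaries are routed through an allocator handed to thrust::cuda::par, with the cub-backed caching allocator reserved for small scratch blocks that benefit from reuse.

// Usage sketch only; SortAndScan is a hypothetical helper mirroring
// detail::SortByWeight in the diff.
#include <thrust/execution_policy.h>
#include <thrust/scan.h>
#include <thrust/sort.h>

#include "device_helpers.cuh"  // assumed include path for the dh:: helpers

void SortAndScan(dh::device_vector<float>* values) {
  // Bulk temporary memory for the sort comes from the plain device allocator.
  dh::XGBDeviceAllocator<char> alloc;
  thrust::sort(thrust::cuda::par(alloc), values->begin(), values->end());

  // Small scratch allocations for the scan go through the cub caching
  // allocator so repeated calls reuse previously freed blocks.
  dh::XGBCachingDeviceAllocator<char> caching;
  thrust::inclusive_scan(thrust::cuda::par(caching), values->begin(),
                         values->end(), values->begin());
}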