Use default allocator in sketching. (#6182)
This commit is contained in:
parent
444131a2e6
commit
f0c63902ff
@ -430,38 +430,18 @@ struct XGBCachingDeviceAllocatorImpl : XGBBaseDeviceAllocator<T> {
|
||||
return *allocator;
|
||||
}
|
||||
pointer allocate(size_t n) { // NOLINT
|
||||
pointer ptr;
|
||||
if (use_cub_allocator_) {
|
||||
T* raw_ptr{nullptr};
|
||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void**>(&raw_ptr), n * sizeof(T));
|
||||
ptr = pointer(raw_ptr);
|
||||
} else {
|
||||
ptr = SuperT::allocate(n);
|
||||
}
|
||||
GlobalMemoryLogger().RegisterAllocation(ptr.get(), n * sizeof(T));
|
||||
return ptr;
|
||||
T* ptr;
|
||||
GetGlobalCachingAllocator().DeviceAllocate(reinterpret_cast<void **>(&ptr),
|
||||
n * sizeof(T));
|
||||
pointer thrust_ptr{ ptr };
|
||||
GlobalMemoryLogger().RegisterAllocation(thrust_ptr.get(), n * sizeof(T));
|
||||
return thrust_ptr;
|
||||
}
|
||||
void deallocate(pointer ptr, size_t n) { // NOLINT
|
||||
GlobalMemoryLogger().RegisterDeallocation(ptr.get(), n * sizeof(T));
|
||||
if (use_cub_allocator_) {
|
||||
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
||||
} else {
|
||||
SuperT::deallocate(ptr, n);
|
||||
}
|
||||
GetGlobalCachingAllocator().DeviceFree(ptr.get());
|
||||
}
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
XGBCachingDeviceAllocatorImpl()
|
||||
: SuperT(rmm::mr::get_current_device_resource(), cudaStream_t{nullptr}) {
|
||||
std::string symbol{typeid(*SuperT::resource()).name()};
|
||||
if (symbol.find("pool_memory_resource") != std::string::npos
|
||||
|| symbol.find("binning_memory_resource") != std::string::npos
|
||||
|| symbol.find("arena_memory_resource") != std::string::npos) {
|
||||
use_cub_allocator_ = false;
|
||||
}
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
private:
|
||||
bool use_cub_allocator_{true};
|
||||
XGBOOST_DEVICE void construct(T *) {} // NOLINT
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
|
||||
@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
return sketch_batch_num_elements;
|
||||
}
|
||||
|
||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
dh::caching_device_vector<float>* weights,
|
||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
dh::device_vector<Entry>* sorted_entries) {
|
||||
// Sort both entries and wegihts.
|
||||
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
|
||||
sorted_entries->end(), weights->begin(),
|
||||
detail::EntryCompareOp());
|
||||
|
||||
// Scan weights
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
|
||||
dh::XGBCachingDeviceAllocator<char> caching;
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
|
||||
sorted_entries->begin(), sorted_entries->end(),
|
||||
weights->begin(), weights->begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) {
|
||||
@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
}
|
||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
|
||||
@ -102,7 +102,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
|
||||
size_t columns, size_t cuts_per_feature, int device,
|
||||
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
|
||||
dh::caching_device_vector<size_t>* column_sizes_scan,
|
||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
||||
dh::device_vector<Entry>* sorted_entries) {
|
||||
auto entry_iter = dh::MakeTransformIterator<Entry>(
|
||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||
return Entry(batch.GetElement(idx).column_idx,
|
||||
@ -123,9 +123,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
|
||||
entry_iter + range.end(), sorted_entries->begin(), is_valid);
|
||||
}
|
||||
|
||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
dh::caching_device_vector<float>* weights,
|
||||
dh::caching_device_vector<Entry>* sorted_entries);
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
dh::device_vector<Entry>* sorted_entries);
|
||||
} // namespace detail
|
||||
|
||||
// Compute sketch on DMatrix.
|
||||
@ -138,7 +137,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
||||
size_t begin, size_t end, float missing,
|
||||
SketchContainer* sketch_container, int num_cuts) {
|
||||
// Copy current subset of valid elements into temporary storage and sort
|
||||
dh::caching_device_vector<Entry> sorted_entries;
|
||||
dh::device_vector<Entry> sorted_entries;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
@ -149,7 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
|
||||
&cuts_ptr,
|
||||
&column_sizes_scan,
|
||||
&sorted_entries);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
|
||||
sorted_entries.end(), detail::EntryCompareOp());
|
||||
|
||||
@ -179,7 +178,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
|
||||
thrust::make_counting_iterator(0llu),
|
||||
[=] __device__(size_t idx) { return batch.GetElement(idx); });
|
||||
dh::caching_device_vector<Entry> sorted_entries;
|
||||
dh::device_vector<Entry> sorted_entries;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
detail::MakeEntriesFromAdapter(batch, batch_iter,
|
||||
@ -190,7 +189,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
&sorted_entries);
|
||||
data::IsValidFunctor is_valid(missing);
|
||||
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = dh::ToSpan(temp_weights);
|
||||
|
||||
if (is_ranking) {
|
||||
@ -222,7 +221,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
|
||||
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
|
||||
}
|
||||
|
||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||
|
||||
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
|
||||
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
|
||||
|
||||
@ -310,7 +310,7 @@ void SketchContainer::Push(Span<Entry const> entries, Span<size_t> columns_ptr,
|
||||
common::Span<OffsetT const> cuts_ptr,
|
||||
size_t total_cuts, Span<float> weights) {
|
||||
Span<SketchEntry> out;
|
||||
dh::caching_device_vector<SketchEntry> cuts;
|
||||
dh::device_vector<SketchEntry> cuts;
|
||||
bool first_window = this->Current().empty();
|
||||
if (!first_window) {
|
||||
cuts.resize(total_cuts);
|
||||
|
||||
@ -36,31 +36,31 @@ class SketchContainer {
|
||||
int32_t device_;
|
||||
|
||||
// Double buffer as neither prune nor merge can be performed inplace.
|
||||
dh::caching_device_vector<SketchEntry> entries_a_;
|
||||
dh::caching_device_vector<SketchEntry> entries_b_;
|
||||
dh::device_vector<SketchEntry> entries_a_;
|
||||
dh::device_vector<SketchEntry> entries_b_;
|
||||
bool current_buffer_ {true};
|
||||
// The container is just a CSC matrix.
|
||||
HostDeviceVector<OffsetT> columns_ptr_;
|
||||
HostDeviceVector<OffsetT> columns_ptr_b_;
|
||||
|
||||
dh::caching_device_vector<SketchEntry>& Current() {
|
||||
dh::device_vector<SketchEntry>& Current() {
|
||||
if (current_buffer_) {
|
||||
return entries_a_;
|
||||
} else {
|
||||
return entries_b_;
|
||||
}
|
||||
}
|
||||
dh::caching_device_vector<SketchEntry>& Other() {
|
||||
dh::device_vector<SketchEntry>& Other() {
|
||||
if (!current_buffer_) {
|
||||
return entries_a_;
|
||||
} else {
|
||||
return entries_b_;
|
||||
}
|
||||
}
|
||||
dh::caching_device_vector<SketchEntry> const& Current() const {
|
||||
dh::device_vector<SketchEntry> const& Current() const {
|
||||
return const_cast<SketchContainer*>(this)->Current();
|
||||
}
|
||||
dh::caching_device_vector<SketchEntry> const& Other() const {
|
||||
dh::device_vector<SketchEntry> const& Other() const {
|
||||
return const_cast<SketchContainer*>(this)->Other();
|
||||
}
|
||||
void Alternate() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user