Use default allocator in sketching. (#6182)
This commit is contained in:
@@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
|
||||
return sketch_batch_num_elements;
|
||||
}
|
||||
|
||||
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
|
||||
dh::caching_device_vector<float>* weights,
|
||||
dh::caching_device_vector<Entry>* sorted_entries) {
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
dh::device_vector<Entry>* sorted_entries) {
|
||||
// Sort both entries and wegihts.
|
||||
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
|
||||
sorted_entries->end(), weights->begin(),
|
||||
detail::EntryCompareOp());
|
||||
|
||||
// Scan weights
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
|
||||
dh::XGBCachingDeviceAllocator<char> caching;
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
|
||||
sorted_entries->begin(), sorted_entries->end(),
|
||||
weights->begin(), weights->begin(),
|
||||
[=] __device__(const Entry& a, const Entry& b) {
|
||||
@@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
const auto& host_data = page.data.ConstHostVector();
|
||||
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
|
||||
host_data.begin() + end);
|
||||
|
||||
// Binary search to assign weights to each element
|
||||
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
|
||||
dh::device_vector<float> temp_weights(sorted_entries.size());
|
||||
auto d_temp_weights = temp_weights.data().get();
|
||||
page.offset.SetDevice(device);
|
||||
auto row_ptrs = page.offset.ConstDeviceSpan();
|
||||
@@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
|
||||
d_temp_weights[idx] = weights[ridx + base_rowid];
|
||||
});
|
||||
}
|
||||
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
|
||||
detail::SortByWeight(&temp_weights, &sorted_entries);
|
||||
|
||||
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
|
||||
dh::caching_device_vector<size_t> column_sizes_scan;
|
||||
|
||||
Reference in New Issue
Block a user