Use default allocator in sketching. (#6182)

This commit is contained in:
Jiaming Yuan
2020-09-30 14:55:59 +08:00
committed by GitHub
parent 444131a2e6
commit f0c63902ff
5 changed files with 33 additions and 53 deletions

View File

@@ -106,16 +106,17 @@ size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
return sketch_batch_num_elements;
}
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries) {
void SortByWeight(dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries) {
// Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
dh::XGBDeviceAllocator<char> alloc;
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(),
sorted_entries->end(), weights->begin(),
detail::EntryCompareOp());
// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
dh::XGBCachingDeviceAllocator<char> caching;
thrust::inclusive_scan_by_key(thrust::cuda::par(caching),
sorted_entries->begin(), sorted_entries->end(),
weights->begin(), weights->begin(),
[=] __device__(const Entry& a, const Entry& b) {
@@ -216,11 +217,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
// Binary search to assign weights to each element
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
dh::device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = temp_weights.data().get();
page.offset.SetDevice(device);
auto row_ptrs = page.offset.ConstDeviceSpan();
@@ -243,7 +244,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
detail::SortByWeight(&temp_weights, &sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;