Use default allocator in sketching. (#6182)

This commit is contained in:
Jiaming Yuan
2020-09-30 14:55:59 +08:00
committed by GitHub
parent 444131a2e6
commit f0c63902ff
5 changed files with 33 additions and 53 deletions

View File

@@ -102,7 +102,7 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
size_t columns, size_t cuts_per_feature, int device,
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
dh::caching_device_vector<size_t>* column_sizes_scan,
dh::caching_device_vector<Entry>* sorted_entries) {
dh::device_vector<Entry>* sorted_entries) {
auto entry_iter = dh::MakeTransformIterator<Entry>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return Entry(batch.GetElement(idx).column_idx,
@@ -123,9 +123,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
entry_iter + range.end(), sorted_entries->begin(), is_valid);
}
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries);
void SortByWeight(dh::device_vector<float>* weights,
dh::device_vector<Entry>* sorted_entries);
} // namespace detail
// Compute sketch on DMatrix.
@@ -138,7 +137,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
size_t begin, size_t end, float missing,
SketchContainer* sketch_container, int num_cuts) {
// Copy current subset of valid elements into temporary storage and sort
dh::caching_device_vector<Entry> sorted_entries;
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
@@ -149,7 +148,7 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
&cuts_ptr,
&column_sizes_scan,
&sorted_entries);
dh::XGBCachingDeviceAllocator<char> alloc;
dh::XGBDeviceAllocator<char> alloc;
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
@@ -179,7 +178,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
dh::caching_device_vector<Entry> sorted_entries;
dh::device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(batch, batch_iter,
@@ -190,7 +189,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
&sorted_entries);
data::IsValidFunctor is_valid(missing);
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
dh::device_vector<float> temp_weights(sorted_entries.size());
auto d_temp_weights = dh::ToSpan(temp_weights);
if (is_ranking) {
@@ -222,7 +221,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
}
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
detail::SortByWeight(&temp_weights, &sorted_entries);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();