Use batched copy if. (#6826)
This commit is contained in:
@@ -1290,6 +1290,21 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
|
||||
num_items, nullptr, false)));
|
||||
}
|
||||
|
||||
template <typename InIt, typename OutIt, typename Predicate>
|
||||
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
|
||||
// We loop over batches because thrust::copy_if cant deal with sizes > 2^31
|
||||
// See thrust issue #1302, #6822
|
||||
size_t max_copy_size = std::numeric_limits<int>::max() / 2;
|
||||
size_t length = std::distance(in_first, in_second);
|
||||
XGBCachingDeviceAllocator<char> alloc;
|
||||
for (size_t offset = 0; offset < length; offset += max_copy_size) {
|
||||
auto begin_input = in_first + offset;
|
||||
auto end_input = in_first + std::min(offset + max_copy_size, length);
|
||||
out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input,
|
||||
end_input, out_first, pred);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
|
||||
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
|
||||
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
|
||||
@@ -1311,14 +1326,14 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
|
||||
|
||||
if (accending) {
|
||||
void *d_temp_storage = nullptr;
|
||||
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false);
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
dh::TemporaryArray<char> storage(bytes);
|
||||
d_temp_storage = storage.data().get();
|
||||
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
|
||||
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
|
||||
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
|
||||
sizeof(KeyT) * 8, false, nullptr, false);
|
||||
sizeof(KeyT) * 8, false, nullptr, false)));
|
||||
} else {
|
||||
void *d_temp_storage = nullptr;
|
||||
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
|
||||
|
||||
@@ -118,9 +118,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
|
||||
size_t num_valid = column_sizes_scan->back();
|
||||
// Copy current subset of valid elements into temporary storage and sort
|
||||
sorted_entries->resize(num_valid);
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(),
|
||||
entry_iter + range.end(), sorted_entries->begin(), is_valid);
|
||||
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
|
||||
sorted_entries->begin(), is_valid);
|
||||
}
|
||||
|
||||
void SortByWeight(dh::device_vector<float>* weights,
|
||||
|
||||
Reference in New Issue
Block a user