[back port] Use batched copy if. (#6826) (#6834)

This commit is contained in:
Jiaming Yuan 2021-04-07 04:50:52 +08:00 committed by GitHub
parent f814d4027a
commit 04fedefd4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 18 deletions

View File

@ -1290,6 +1290,21 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
num_items, nullptr, false)));
}
/**
 * \brief Batched wrapper around thrust::copy_if.
 *
 * Copies every element of [in_first, in_second) for which `pred` holds to the
 * range starting at `out_first`.  The input is handed to thrust::copy_if in
 * chunks of at most std::numeric_limits<int>::max() / 2 elements, and each
 * chunk continues writing where the previous one stopped.
 *
 * \param in_first  Iterator to the first input element.
 * \param in_second Iterator one past the last input element.
 * \param out_first Iterator to the start of the output range.
 * \param pred      Predicate selecting the elements to copy.
 */
template <typename InIt, typename OutIt, typename Predicate>
void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) {
  // thrust::copy_if can't deal with sizes > 2^31, hence the chunked loop.
  // See thrust issue #1302, #6822
  size_t const batch_limit = std::numeric_limits<int>::max() / 2;
  size_t const n_items = std::distance(in_first, in_second);
  XGBCachingDeviceAllocator<char> alloc;

  size_t batch_begin = 0;
  while (batch_begin < n_items) {
    size_t const batch_end = std::min(batch_begin + batch_limit, n_items);
    // copy_if returns the end of what it wrote; the next chunk appends there.
    out_first = thrust::copy_if(thrust::cuda::par(alloc),
                                in_first + batch_begin, in_first + batch_end,
                                out_first, pred);
    batch_begin += batch_limit;
  }
}
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
@ -1311,14 +1326,14 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
if (accending) {
void *d_temp_storage = nullptr;
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
safe_cuda((cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
sizeof(KeyT) * 8, false, nullptr, false)));
} else {
void *d_temp_storage = nullptr;
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(

View File

@ -118,9 +118,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(),
entry_iter + range.end(), sorted_entries->begin(), is_valid);
dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(),
sorted_entries->begin(), is_valid);
}
void SortByWeight(dh::device_vector<float>* weights,

View File

@ -55,18 +55,9 @@ void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
COOToEntryOp<decltype(batch)> transform_op{batch};
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
transform_iter(counting, transform_op);
// We loop over batches because thrust::copy_if cant deal with sizes > 2^31
// See thrust issue #1302
size_t max_copy_size = std::numeric_limits<int>::max() / 2;
auto begin_output = thrust::device_pointer_cast(data.data());
for (size_t offset = 0; offset < batch.Size(); offset += max_copy_size) {
auto begin_input = transform_iter + offset;
auto end_input =
transform_iter + std::min(offset + max_copy_size, batch.Size());
begin_output =
thrust::copy_if(thrust::cuda::par(alloc), begin_input, end_input,
begin_output, IsValidFunctor(missing));
}
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
}
// Does not currently support metainfo as no on-device data source contains this