From 04fedefd4d2815e5b3b7ab24054ecb58de24f210 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 7 Apr 2021 04:50:52 +0800 Subject: [PATCH] [back port] Use batched copy if. (#6826) (#6834) --- src/common/device_helpers.cuh | 23 +++++++++++++++++++---- src/common/hist_util.cuh | 5 ++--- src/data/simple_dmatrix.cu | 13 ++----------- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index b1ddfdb20..a66711a78 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1290,6 +1290,21 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op, num_items, nullptr, false))); } +template +void CopyIf(InIt in_first, InIt in_second, OutIt out_first, Predicate pred) { + // We loop over batches because thrust::copy_if cant deal with sizes > 2^31 + // See thrust issue #1302, #6822 + size_t max_copy_size = std::numeric_limits::max() / 2; + size_t length = std::distance(in_first, in_second); + XGBCachingDeviceAllocator alloc; + for (size_t offset = 0; offset < length; offset += max_copy_size) { + auto begin_input = in_first + offset; + auto end_input = in_first + std::min(offset + max_copy_size, length); + out_first = thrust::copy_if(thrust::cuda::par(alloc), begin_input, + end_input, out_first, pred); + } +} + template void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) { InclusiveScan(d_in, d_out, cub::Sum(), num_items); @@ -1311,14 +1326,14 @@ void ArgSort(xgboost::common::Span keys, xgboost::common::Span sorted_i if (accending) { void *d_temp_storage = nullptr; - cub::DispatchRadixSort::Dispatch( + safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, - sizeof(KeyT) * 8, false, nullptr, false); + sizeof(KeyT) * 8, false, nullptr, false))); dh::TemporaryArray storage(bytes); d_temp_storage = storage.data().get(); - cub::DispatchRadixSort::Dispatch( + safe_cuda((cub::DispatchRadixSort::Dispatch( d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0, - sizeof(KeyT) * 8, false, nullptr, false); + sizeof(KeyT) * 8, false, nullptr, false))); } else { void *d_temp_storage = nullptr; safe_cuda((cub::DispatchRadixSort::Dispatch( diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 022d92a80..898b198f8 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -118,9 +118,8 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, size_t num_valid = column_sizes_scan->back(); // Copy current subset of valid elements into temporary storage and sort sorted_entries->resize(num_valid); - dh::XGBCachingDeviceAllocator alloc; - thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(), - entry_iter + range.end(), sorted_entries->begin(), is_valid); + dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), + sorted_entries->begin(), is_valid); } void SortByWeight(dh::device_vector* weights, diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 43e75bb21..ff58c6bad 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -55,18 +55,9 @@ void CopyDataToDMatrix(AdapterT* adapter, common::Span data, COOToEntryOp transform_op{batch}; thrust::transform_iterator transform_iter(counting, transform_op); - // We loop over batches because thrust::copy_if cant deal with sizes > 2^31 - // See thrust issue #1302 - size_t max_copy_size = std::numeric_limits::max() / 2; auto begin_output = thrust::device_pointer_cast(data.data()); - for (size_t offset = 0; offset < batch.Size(); offset += max_copy_size) { - auto begin_input = transform_iter + offset; - auto end_input = - transform_iter + std::min(offset + max_copy_size, batch.Size()); - begin_output = - thrust::copy_if(thrust::cuda::par(alloc), begin_input, end_input, - begin_output, IsValidFunctor(missing)); - } + dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output, + IsValidFunctor(missing)); } // Does not currently support metainfo as no on-device data source contains this