From b2827a80e112cdb9e28a56da0612810693ea4405 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 20 Apr 2020 15:51:34 +1200 Subject: [PATCH] Use non-synchronising scan (#5560) --- src/tree/gpu_hist/row_partitioner.cu | 34 ++++++++++++++++++---------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index e8f55fee2..25a72940c 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -40,7 +40,7 @@ struct WriteResultsFunctor { common::Span ridx_out; int64_t* d_left_count; - __device__ int operator()(const IndexFlagTuple& x) { + __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) { // the ex_scan_result represents how many rows have been assigned to left // node so far during scan. int scatter_address; @@ -56,10 +56,18 @@ struct WriteResultsFunctor { ridx_out[scatter_address] = ridx_in[x.idx]; // Discard - return 0; + return {}; } }; +// Change the value type of thrust discard iterator so we can use it with cub +class DiscardOverload : public thrust::discard_iterator { + public: + using value_type = IndexFlagTuple; // NOLINT +}; + +// Implement partitioning via single scan operation using transform output to +// write the result void RowPartitioner::SortPosition(common::Span position, common::Span position_out, common::Span ridx, @@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span position, int64_t* d_left_count, cudaStream_t stream) { WriteResultsFunctor write_results{left_nidx, position, position_out, ridx, ridx_out, d_left_count}; - auto discard_write_iterator = thrust::make_transform_output_iterator( - thrust::discard_iterator(), write_results); + auto discard_write_iterator = + thrust::make_transform_output_iterator(DiscardOverload(), write_results); + auto counting = thrust::make_counting_iterator(0llu); auto input_iterator = dh::MakeTransformIterator( - thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { + counting, [=] __device__(size_t idx) { return IndexFlagTuple{idx, position[idx] == left_nidx}; }); - dh::XGBCachingDeviceAllocator alloc; - thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator, - input_iterator + position.size(), - discard_write_iterator, - [=] __device__(IndexFlagTuple a, IndexFlagTuple b) { - return IndexFlagTuple{b.idx, a.flag + b.flag}; - }); + size_t temp_bytes = 0; + cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, + discard_write_iterator, IndexFlagOp(), + position.size(), stream); + dh::TemporaryArray temp(temp_bytes); + cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator, + discard_write_iterator, IndexFlagOp(), + position.size(), stream); } RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)