Use non-synchronising scan (#5560)

Author: Rory Mitchell (committed by GitHub)
Date:   2020-04-20 15:51:34 +12:00
parent  d6d1035950
commit  b2827a80e1


@@ -40,7 +40,7 @@ struct WriteResultsFunctor {
   common::Span<RowPartitioner::RowIndexT> ridx_out;
   int64_t* d_left_count;
 
-  __device__ int operator()(const IndexFlagTuple& x) {
+  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     // the ex_scan_result represents how many rows have been assigned to left
     // node so far during scan.
     int scatter_address;
@@ -56,10 +56,18 @@ struct WriteResultsFunctor {
     ridx_out[scatter_address] = ridx_in[x.idx];
     // Discard
-    return 0;
+    return {};
   }
 };
 
+// Change the value type of thrust discard iterator so we can use it with cub
+class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
+ public:
+  using value_type = IndexFlagTuple;  // NOLINT
+};
+
+// Implement partitioning via single scan operation using transform output to
+// write the result
 void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   common::Span<bst_node_t> position_out,
                                   common::Span<RowIndexT> ridx,
@@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   int64_t* d_left_count, cudaStream_t stream) {
   WriteResultsFunctor write_results{left_nidx, position, position_out,
                                     ridx, ridx_out, d_left_count};
-  auto discard_write_iterator = thrust::make_transform_output_iterator(
-      thrust::discard_iterator<int>(), write_results);
+  auto discard_write_iterator =
+      thrust::make_transform_output_iterator(DiscardOverload(), write_results);
+  auto counting = thrust::make_counting_iterator(0llu);
   auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
-      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
+      counting, [=] __device__(size_t idx) {
         return IndexFlagTuple{idx, position[idx] == left_nidx};
       });
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
-                         input_iterator + position.size(),
-                         discard_write_iterator,
-                         [=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
-                           return IndexFlagTuple{b.idx, a.flag + b.flag};
-                         });
+  size_t temp_bytes = 0;
+  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
+  dh::TemporaryArray<int8_t> temp(temp_bytes);
+  cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
 }
 
 RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
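For context, here is a minimal standalone sketch (not part of this commit) of the two-phase, stream-ordered CUB pattern the new code relies on: a first call with a null temporary-storage pointer only reports the scratch size, and a second call enqueues the scan on the given stream without blocking the host, unlike the thrust::inclusive_scan path being replaced, which synchronises on completion. To stay self-contained it uses cub::DeviceScan::InclusiveSum over plain ints rather than the IndexFlagTuple/IndexFlagOp and transform output iterator used above; names such as kN are illustrative only.

// sketch.cu -- illustrative only; build with `nvcc sketch.cu`
#include <cub/cub.cuh>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <cstdio>

int main() {
  const int kN = 8;                      // illustrative problem size
  thrust::device_vector<int> in(kN, 1);  // input: eight ones
  thrust::device_vector<int> out(kN);    // inclusive prefix sums go here

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Phase 1: with a null temp-storage pointer, CUB only writes the number of
  // scratch bytes it needs into temp_bytes; no device work is launched.
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveSum(nullptr, temp_bytes, in.data().get(),
                                out.data().get(), kN, stream);

  // Phase 2: allocate the scratch buffer once, then run the scan. This call
  // only enqueues kernels on `stream` and returns to the host immediately,
  // whereas thrust::inclusive_scan with an execution policy blocks until the
  // scan has finished.
  thrust::device_vector<char> temp(temp_bytes);
  cub::DeviceScan::InclusiveSum(temp.data().get(), temp_bytes, in.data().get(),
                                out.data().get(), kN, stream);

  cudaStreamSynchronize(stream);     // wait before reading the result
  thrust::host_vector<int> h = out;  // expect: 1 2 3 4 5 6 7 8
  for (int v : h) { printf("%d ", v); }
  printf("\n");

  cudaStreamDestroy(stream);
  return 0;
}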