Use non-synchronising scan (#5560)

This commit is contained in:
Rory Mitchell 2020-04-20 15:51:34 +12:00 committed by GitHub
parent d6d1035950
commit b2827a80e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -40,7 +40,7 @@ struct WriteResultsFunctor {
common::Span<RowPartitioner::RowIndexT> ridx_out;
int64_t* d_left_count;
__device__ int operator()(const IndexFlagTuple& x) {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
// the ex_scan_result represents how many rows have been assigned to left
// node so far during scan.
int scatter_address;
@ -56,10 +56,18 @@ struct WriteResultsFunctor {
ridx_out[scatter_address] = ridx_in[x.idx];
// Discard
return 0;
return {};
}
};
// Change the value type of thrust discard iterator so we can use it with cub
class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
public:
using value_type = IndexFlagTuple; // NOLINT
};
// Implement partitioning via single scan operation using transform output to
// write the result
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
int64_t* d_left_count, cudaStream_t stream) {
WriteResultsFunctor write_results{left_nidx, position, position_out,
ridx, ridx_out, d_left_count};
auto discard_write_iterator = thrust::make_transform_output_iterator(
thrust::discard_iterator<int>(), write_results);
auto discard_write_iterator =
thrust::make_transform_output_iterator(DiscardOverload(), write_results);
auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
counting, [=] __device__(size_t idx) {
return IndexFlagTuple{idx, position[idx] == left_nidx};
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
input_iterator + position.size(),
discard_write_iterator,
[=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
return IndexFlagTuple{b.idx, a.flag + b.flag};
});
size_t temp_bytes = 0;
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
dh::TemporaryArray<int8_t> temp(temp_bytes);
cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
}
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)