Use non-synchronising scan (#5560)
This commit is contained in:
parent
d6d1035950
commit
b2827a80e1
@ -40,7 +40,7 @@ struct WriteResultsFunctor {
|
||||
common::Span<RowPartitioner::RowIndexT> ridx_out;
|
||||
int64_t* d_left_count;
|
||||
|
||||
__device__ int operator()(const IndexFlagTuple& x) {
|
||||
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
|
||||
// the ex_scan_result represents how many rows have been assigned to left
|
||||
// node so far during scan.
|
||||
int scatter_address;
|
||||
@ -56,10 +56,18 @@ struct WriteResultsFunctor {
|
||||
ridx_out[scatter_address] = ridx_in[x.idx];
|
||||
|
||||
// Discard
|
||||
return 0;
|
||||
return {};
|
||||
}
|
||||
};
|
||||
|
||||
// Change the value type of thrust discard iterator so we can use it with cub
|
||||
class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
|
||||
public:
|
||||
using value_type = IndexFlagTuple; // NOLINT
|
||||
};
|
||||
|
||||
// Implement partitioning via single scan operation using transform output to
|
||||
// write the result
|
||||
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
|
||||
common::Span<bst_node_t> position_out,
|
||||
common::Span<RowIndexT> ridx,
|
||||
@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
|
||||
int64_t* d_left_count, cudaStream_t stream) {
|
||||
WriteResultsFunctor write_results{left_nidx, position, position_out,
|
||||
ridx, ridx_out, d_left_count};
|
||||
auto discard_write_iterator = thrust::make_transform_output_iterator(
|
||||
thrust::discard_iterator<int>(), write_results);
|
||||
auto discard_write_iterator =
|
||||
thrust::make_transform_output_iterator(DiscardOverload(), write_results);
|
||||
auto counting = thrust::make_counting_iterator(0llu);
|
||||
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
|
||||
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
|
||||
counting, [=] __device__(size_t idx) {
|
||||
return IndexFlagTuple{idx, position[idx] == left_nidx};
|
||||
});
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
|
||||
input_iterator + position.size(),
|
||||
discard_write_iterator,
|
||||
[=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
|
||||
return IndexFlagTuple{b.idx, a.flag + b.flag};
|
||||
});
|
||||
size_t temp_bytes = 0;
|
||||
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
|
||||
discard_write_iterator, IndexFlagOp(),
|
||||
position.size(), stream);
|
||||
dh::TemporaryArray<int8_t> temp(temp_bytes);
|
||||
cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
|
||||
discard_write_iterator, IndexFlagOp(),
|
||||
position.size(), stream);
|
||||
}
|
||||
|
||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user