Use non-synchronising scan (#5560)
parent d6d1035950
commit b2827a80e1
@@ -40,7 +40,7 @@ struct WriteResultsFunctor {
   common::Span<RowPartitioner::RowIndexT> ridx_out;
   int64_t* d_left_count;
 
-  __device__ int operator()(const IndexFlagTuple& x) {
+  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     // the ex_scan_result represents how many rows have been assigned to left
     // node so far during scan.
     int scatter_address;
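A note on the signature change: the functor now acts as the transform of the scan's output iterator, so it must return the scan's value type (IndexFlagTuple) even though the returned value is discarded; the useful work is the side effect of writing ridx_out. The scatter arithmetic itself sits in lines elided from this hunk, so the host-side sketch below reconstructs it from the comment above under that assumption: after an inclusive scan, x.flag counts the left-bound rows in [0, x.idx], which pins down a stable left/right partition in a single pass.

// Host-side sketch (assumption: reconstructed from the comment above;
// the actual branch lives in the elided lines). Left rows scatter to
// x.flag - 1 (the -1 because the scan is inclusive); right rows scatter
// to left_count + (x.idx - x.flag).
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // 1 = row goes to the left child, 0 = right child.
  std::vector<std::size_t> flags{1, 0, 1, 1, 0};
  std::size_t left_count = 0;
  for (std::size_t f : flags) left_count += f;  // 3 rows go left

  std::size_t running = 0;  // inclusive scan of the flags
  for (std::size_t idx = 0; idx < flags.size(); ++idx) {
    running += flags[idx];
    std::size_t scatter_address =
        flags[idx] ? running - 1 : left_count + (idx - running);
    std::printf("row %zu -> slot %zu\n", idx, scatter_address);
  }
  // Prints a stable partition: rows 0, 2, 3 land in slots 0, 1, 2 and
  // rows 1, 4 land in slots 3, 4.
  return 0;
}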
@@ -56,10 +56,18 @@ struct WriteResultsFunctor {
     ridx_out[scatter_address] = ridx_in[x.idx];
 
     // Discard
-    return 0;
+    return {};
   }
 };
 
+// Change the value type of thrust discard iterator so we can use it with cub
+class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
+ public:
+  using value_type = IndexFlagTuple;  // NOLINT
+};
+
+// Implement partitioning via single scan operation using transform output to
+// write the result
 void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   common::Span<bst_node_t> position_out,
                                   common::Span<RowIndexT> ridx,
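Why the DiscardOverload shim is needed: cub deduces the scan's accumulator type from std::iterator_traits<OutputIteratorT>::value_type, and a plain thrust::discard_iterator reports an internal placeholder type there rather than IndexFlagTuple (its template parameter is a system tag, not a value type). A minimal compile-time sketch, with IndexFlagTuple stubbed out:

#include <thrust/iterator/discard_iterator.h>
#include <cstddef>
#include <iterator>
#include <type_traits>

struct IndexFlagTuple {  // stub of the real tuple (row index + scanned flag)
  std::size_t idx;
  std::size_t flag;
};

class DiscardOverload : public thrust::discard_iterator<IndexFlagTuple> {
 public:
  using value_type = IndexFlagTuple;  // NOLINT
};

// The subclass re-exports IndexFlagTuple as the value_type that cub's
// type deduction reads through std::iterator_traits.
static_assert(std::is_same<std::iterator_traits<DiscardOverload>::value_type,
                           IndexFlagTuple>::value,
              "cub now sees IndexFlagTuple as the output value type");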
@@ -68,19 +76,21 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                   int64_t* d_left_count, cudaStream_t stream) {
   WriteResultsFunctor write_results{left_nidx, position, position_out,
                                     ridx, ridx_out, d_left_count};
-  auto discard_write_iterator = thrust::make_transform_output_iterator(
-      thrust::discard_iterator<int>(), write_results);
+  auto discard_write_iterator =
+      thrust::make_transform_output_iterator(DiscardOverload(), write_results);
+  auto counting = thrust::make_counting_iterator(0llu);
   auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
-      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
+      counting, [=] __device__(size_t idx) {
         return IndexFlagTuple{idx, position[idx] == left_nidx};
       });
-  dh::XGBCachingDeviceAllocator<char> alloc;
-  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
-                         input_iterator + position.size(),
-                         discard_write_iterator,
-                         [=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
-                           return IndexFlagTuple{b.idx, a.flag + b.flag};
-                         });
+  size_t temp_bytes = 0;
+  cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
+  dh::TemporaryArray<int8_t> temp(temp_bytes);
+  cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
+                                 discard_write_iterator, IndexFlagOp(),
+                                 position.size(), stream);
 }
 
 RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
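The swap from thrust::inclusive_scan to cub::DeviceScan::InclusiveScan is the non-synchronising part: the thrust algorithm blocks the calling host thread, while the cub calls only enqueue work on stream. cub's API is two-phase, which is why the call appears twice above: the first call, with a null temporary-storage pointer, only writes the required scratch size into temp_bytes; the second call performs the scan using that scratch buffer. Below is a standalone sketch of the same pattern with a plain integer sum scan (all names here are illustrative, not from the diff):

// Two-phase cub::DeviceScan::InclusiveScan call pattern on the default
// stream: query scratch size, allocate, then scan.
#include <cub/cub.cuh>
#include <cstdio>

int main() {
  int h_in[4] = {1, 2, 3, 4};
  int *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_in));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  // Phase 1: null temp pointer => only the scratch size is computed.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceScan::InclusiveScan(d_temp, temp_bytes, d_in, d_out,
                                 cub::Sum(), 4);
  cudaMalloc(&d_temp, temp_bytes);

  // Phase 2: run the scan. This only launches kernels and does not
  // block the host; the cudaMemcpy below synchronises implicitly.
  cub::DeviceScan::InclusiveScan(d_temp, temp_bytes, d_in, d_out,
                                 cub::Sum(), 4);

  int h_out[4];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int v : h_out) std::printf("%d ", v);  // prints: 1 3 6 10
  std::printf("\n");

  cudaFree(d_temp);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}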