Use thrust functions instead of custom functions (#5544)
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
/*!
|
||||
* Copyright 2017-2019 XGBoost contributors
|
||||
*/
|
||||
#include <thrust/iterator/discard_iterator.h>
|
||||
#include <thrust/iterator/transform_output_iterator.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <vector>
|
||||
#include "../../common/device_helpers.cuh"
|
||||
@@ -11,58 +13,74 @@ namespace tree {
|
||||
|
||||
/*!
 * \brief Unary functor that flags whether a node index equals `left_nidx`.
 *
 * Callable from both host and device code.
 */
struct IndicateLeftTransform {
  // Node index every input is compared against.
  bst_node_t left_nidx;

  explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {}

  /*!
   * \param x Node index to test.
   * \return 1 if `x` equals `left_nidx`, otherwise 0.
   */
  __host__ __device__ __forceinline__ size_t
  operator()(const bst_node_t& x) const {
    return x == left_nidx ? 1 : 0;
  }
};
|
||||
/*
|
||||
* position: Positions of the rows that belong to the current split node.
|
||||
*/
|
||||
|
||||
/*!
 * \brief Element type of the scan performed in RowPartitioner::SortPosition.
 *
 * Kept as a plain aggregate so it can be brace-initialised as
 * `IndexFlagTuple{idx, flag}`.
 */
struct IndexFlagTuple {
  // Offset of the row inside the span being partitioned.
  size_t idx;
  // Before the scan: 1 if the row sits in the left node, 0 otherwise.
  // After the inclusive scan: number of left-node rows at offsets <= idx.
  size_t flag;
};
|
||||
|
||||
/*!
 * \brief Binary scan operator over IndexFlagTuple.
 *
 * The result carries the index of the right-hand operand and the sum of both
 * flags.
 */
struct IndexFlagOp {
  /*!
   * \param a Left operand.
   * \param b Right operand; its `idx` is propagated to the result.
   * \return `{b.idx, a.flag + b.flag}`.
   */
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
                                       const IndexFlagTuple& b) const {
    return {b.idx, a.flag + b.flag};
  }
};
|
||||
|
||||
/*!
 * \brief Scatters one row to its partitioned location, given its scan result.
 *
 * Called with the inclusive-scan result for a row. Rows whose position equals
 * `left_nidx` are packed at the front of the output spans; all other rows are
 * placed after the first `*d_left_count` slots.
 */
struct WriteResultsFunctor {
  // Node index identifying the left child.
  bst_node_t left_nidx;
  // Node id of each row before partitioning.
  common::Span<bst_node_t> position_in;
  // Receives the node ids in partitioned order.
  common::Span<bst_node_t> position_out;
  // Row indices before partitioning.
  common::Span<RowPartitioner::RowIndexT> ridx_in;
  // Receives the row indices in partitioned order.
  common::Span<RowPartitioner::RowIndexT> ridx_out;
  // Total number of rows that belong to the left node; dereferenced in device
  // code, so it must already hold the final count when the scan runs.
  int64_t* d_left_count;

  /*!
   * \param x Inclusive-scan result for the row at offset `x.idx`; `x.flag` is
   *          the number of left-node rows at offsets <= `x.idx`.
   * \return Always 0; the value only exists so the functor can feed an output
   *         iterator that discards it.
   */
  __device__ int operator()(const IndexFlagTuple& x) {
    // Keep the address in size_t: x.idx and x.flag are size_t, and narrowing
    // them into an int would truncate once the values exceed 2^31.
    size_t scatter_address;
    if (position_in[x.idx] == left_nidx) {
      // x.flag counts this row as well (inclusive scan), so it is >= 1 here.
      scatter_address = x.flag - 1;
    } else {
      // Number of right-node rows seen before this one, offset by the total
      // number of rows that belong to the left node.
      scatter_address =
          (x.idx - x.flag) + static_cast<size_t>(*d_left_count);
    }
    // Copy the node id and the row index to their partitioned location.
    position_out[scatter_address] = position_in[x.idx];
    ridx_out[scatter_address] = ridx_in[x.idx];

    // Discarded by the output iterator.
    return 0;
  }
};
|
||||
|
||||
/*!
 * \brief Partition rows so that those in `left_nidx` come first.
 *
 * Runs an inclusive scan over (offset, is-left) pairs and scatters every row
 * from the scan's output iterator through WriteResultsFunctor.
 *
 * \param position      Node id of every row in the segment (input).
 * \param position_out  Receives the node ids in partitioned order.
 * \param ridx          Row indices of the segment (input).
 * \param ridx_out      Receives the row indices in partitioned order.
 * \param left_nidx     Node index of the left child.
 * \param right_nidx    Node index of the right child (not read here).
 * \param d_left_count  Total number of rows that belong to the left node;
 *                      dereferenced in device code while scattering.
 * \param stream        CUDA stream the scan is issued on.
 */
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
                                  common::Span<bst_node_t> position_out,
                                  common::Span<RowIndexT> ridx,
                                  common::Span<RowIndexT> ridx_out,
                                  bst_node_t left_nidx, bst_node_t right_nidx,
                                  int64_t* d_left_count, cudaStream_t stream) {
  // radix sort over 1 bit, see:
  // https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
  WriteResultsFunctor write_results{left_nidx, position, position_out,
                                    ridx,      ridx_out, d_left_count};
  // Each scan result is passed to write_results, which performs the scatter;
  // the int it returns is thrown away by the discard iterator.
  auto discard_write_iterator = thrust::make_transform_output_iterator(
      thrust::discard_iterator<int>(), write_results);
  // Input element i is {i, 1 if row i is in the left node else 0}.
  auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
      thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
        return IndexFlagTuple{idx, position[idx] == left_nidx};
      });
  dh::XGBCachingDeviceAllocator<char> alloc;
  // position is of the same size as the current split node's row segment.
  thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
                         input_iterator + position.size(),
                         discard_write_iterator,
                         [=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
                           // Keep the right-hand index, accumulate the flags.
                           return IndexFlagTuple{b.idx, a.flag + b.flag};
                         });
}
|
||||
|
||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||
@@ -137,7 +155,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
|
||||
SortPosition(
|
||||
// position_in
|
||||
common::Span<bst_node_t>(position_.Current() + segment.begin,
|
||||
segment.Size()),
|
||||
segment.Size()),
|
||||
// position_out
|
||||
common::Span<bst_node_t>(position_.Other() + segment.begin,
|
||||
segment.Size()),
|
||||
|
||||
Reference in New Issue
Block a user