Use thrust functions instead of custom functions (#5544)

This commit is contained in:
Rory Mitchell
2020-04-16 21:41:16 +12:00
committed by GitHub
parent 6a169cd41a
commit e268fb0093
6 changed files with 82 additions and 306 deletions

View File

@@ -1,6 +1,8 @@
/*!
* Copyright 2017-2019 XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/sequence.h>
#include <vector>
#include "../../common/device_helpers.cuh"
@@ -11,58 +13,74 @@ namespace tree {
struct IndicateLeftTransform {
bst_node_t left_nidx;
explicit IndicateLeftTransform(bst_node_t left_nidx)
: left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ int operator()(const bst_node_t& x) const {
explicit IndicateLeftTransform(bst_node_t left_nidx) : left_nidx(left_nidx) {}
__host__ __device__ __forceinline__ size_t
operator()(const bst_node_t& x) const {
return x == left_nidx ? 1 : 0;
}
};
/*
* position: Position of rows belonged to current split node.
*/
struct IndexFlagTuple {
size_t idx;
size_t flag;
};
struct IndexFlagOp {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
const IndexFlagTuple& b) const {
return {b.idx, a.flag + b.flag};
}
};
struct WriteResultsFunctor {
bst_node_t left_nidx;
common::Span<bst_node_t> position_in;
common::Span<bst_node_t> position_out;
common::Span<RowPartitioner::RowIndexT> ridx_in;
common::Span<RowPartitioner::RowIndexT> ridx_out;
int64_t* d_left_count;
__device__ int operator()(const IndexFlagTuple& x) {
// the ex_scan_result represents how many rows have been assigned to left
// node so far during scan.
int scatter_address;
if (position_in[x.idx] == left_nidx) {
scatter_address = x.flag - 1; // -1 because inclusive scan
} else {
// current number of rows belong to right node + total number of rows
// belong to left node
scatter_address = (x.idx - x.flag) + *d_left_count;
}
// copy the node id to output
position_out[scatter_address] = position_in[x.idx];
ridx_out[scatter_address] = ridx_in[x.idx];
// Discard
return 0;
}
};
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
bst_node_t left_nidx,
bst_node_t right_nidx,
bst_node_t left_nidx, bst_node_t right_nidx,
int64_t* d_left_count, cudaStream_t stream) {
// radix sort over 1 bit, see:
// https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
// the ex_scan_result represents how many rows have been assigned to left node so far
// during scan.
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
// current number of rows belong to right node + total number of rows belong to left
// node
scatter_address = (idx - ex_scan_result) + *d_left_count;
}
// copy the node id to output
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT
IndicateLeftTransform is_left(left_nidx);
// an iterator that given a old position returns whether it belongs to left or right
// node.
cub::TransformInputIterator<bst_node_t, IndicateLeftTransform,
bst_node_t*>
in_itr(d_position_in, is_left);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
// position is of the same size with current split node's row segment
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size(), stream);
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
in_itr, out_itr, position.size(), stream);
WriteResultsFunctor write_results{left_nidx, position, position_out,
ridx, ridx_out, d_left_count};
auto discard_write_iterator = thrust::make_transform_output_iterator(
thrust::discard_iterator<int>(), write_results);
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) {
return IndexFlagTuple{idx, position[idx] == left_nidx};
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream), input_iterator,
input_iterator + position.size(),
discard_write_iterator,
[=] __device__(IndexFlagTuple a, IndexFlagTuple b) {
return IndexFlagTuple{b.idx, a.flag + b.flag};
});
}
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
@@ -137,7 +155,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
SortPosition(
// position_in
common::Span<bst_node_t>(position_.Current() + segment.begin,
segment.Size()),
segment.Size()),
// position_out
common::Span<bst_node_t>(position_.Other() + segment.begin,
segment.Size()),