Some comments for row partitioner. (#4832)

This commit is contained in:
Jiaming Yuan 2019-09-06 03:01:42 -04:00 committed by GitHub
parent a5f232feb8
commit f90e7f9aa8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 13 deletions

View File

@ -19,7 +19,9 @@ struct IndicateLeftTransform {
return x == left_nidx ? 1 : 0;
}
};
/*
* position: Position of rows belonged to current split node.
*/
void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
common::Span<TreePositionT> position_out,
common::Span<RowIndexT> ridx,
@ -27,27 +29,37 @@ void RowPartitioner::SortPosition(common::Span<TreePositionT> position,
TreePositionT left_nidx,
TreePositionT right_nidx,
int64_t* d_left_count, cudaStream_t stream) {
// radix sort over 1 bit, see:
// https://developer.nvidia.com/gpugems/GPUGems3/gpugems3_ch39.html
auto d_position_out = position_out.data();
auto d_position_in = position.data();
auto d_ridx_out = ridx_out.data();
auto d_ridx_in = ridx.data();
auto write_results = [=] __device__(size_t idx, int ex_scan_result) {
// the ex_scan_result represents how many rows have been assigned to left node so far
// during scan.
int scatter_address;
if (d_position_in[idx] == left_nidx) {
scatter_address = ex_scan_result;
} else {
// current number of rows belong to right node + total number of rows belong to left
// node
scatter_address = (idx - ex_scan_result) + *d_left_count;
}
// copy the node id to output
d_position_out[scatter_address] = d_position_in[idx];
d_ridx_out[scatter_address] = d_ridx_in[idx];
}; // NOLINT
IndicateLeftTransform conversion_op(left_nidx);
IndicateLeftTransform is_left(left_nidx);
// an iterator that given a old position returns whether it belongs to left or right
// node.
cub::TransformInputIterator<TreePositionT, IndicateLeftTransform,
TreePositionT*>
in_itr(d_position_in, conversion_op);
in_itr(d_position_in, is_left);
dh::DiscardLambdaItr<decltype(write_results)> out_itr(write_results);
size_t temp_storage_bytes = 0;
// position is of the same size with current split node's row segment
cub::DeviceScan::ExclusiveSum(nullptr, temp_storage_bytes, in_itr, out_itr,
position.size(), stream);
dh::caching_device_vector<uint8_t> temp_storage(temp_storage_bytes);
@ -125,11 +137,15 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
int64_t* d_left_count,
cudaStream_t stream) {
SortPosition(
// position_in
common::Span<TreePositionT>(position.Current() + segment.begin,
segment.Size()),
// position_out
common::Span<TreePositionT>(position.other() + segment.begin,
segment.Size()),
// row index in
common::Span<RowIndexT>(ridx.Current() + segment.begin, segment.Size()),
// row index out
common::Span<RowIndexT>(ridx.other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value

View File

@ -30,19 +30,32 @@ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment
* partition training rows into different leaf nodes. */
class RowPartitioner {
public:
using TreePositionT = int;
using TreePositionT = int32_t;
using RowIndexT = bst_uint;
struct Segment;
private:
int device_idx;
/*! \brief Range of rows for each node. */
/*! \brief In here if you want to find the rows belong to a node nid, first you need to
* get the indices segment from ridx_segments[nid], then get the row index that
* represents position of row in input data X. `RowPartitioner::GetRows` would be a
* good starting place to get a sense what are these vector storing.
*
* node id -> segment -> indices of rows belonging to node
*/
/*! \brief Range of row index for each node, pointers into ridx below. */
std::vector<Segment> ridx_segments;
dh::caching_device_vector<RowIndexT> ridx_a;
dh::caching_device_vector<RowIndexT> ridx_b;
dh::caching_device_vector<TreePositionT> position_a;
dh::caching_device_vector<TreePositionT> position_b;
/*! \brief mapping for node id -> rows.
* This looks like:
* node id | 1 | 2 |
* rows idx | 3, 5, 1 | 13, 31 |
*/
dh::DoubleBuffer<RowIndexT> ridx;
/*! \brief mapping for row -> node id. */
dh::DoubleBuffer<TreePositionT> position;
dh::caching_device_vector<int64_t>
left_counts; // Useful to keep a bunch of zeroed memory for sort position
@ -95,20 +108,22 @@ class RowPartitioner {
void UpdatePosition(TreePositionT nidx, TreePositionT left_nidx,
TreePositionT right_nidx, UpdatePositionOpT op) {
dh::safe_cuda(cudaSetDevice(device_idx));
Segment segment = ridx_segments.at(nidx);
Segment segment = ridx_segments.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx.CurrentSpan();
auto d_position = position.CurrentSpan();
if (left_counts.size() <= nidx) {
left_counts.resize((nidx * 2) + 1);
thrust::fill(left_counts.begin(), left_counts.end(), 0);
}
// Now we divide the row segment into left and right node.
int64_t* d_left_count = left_counts.data().get() + nidx;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(device_idx, segment.Size(), [=] __device__(size_t idx) {
// LaunchN starts from zero, so we restore the row index by adding segment.begin
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
// Missing value
TreePositionT new_position = op(ridx);
TreePositionT new_position = op(ridx); // new node id
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;

View File

@ -152,8 +152,8 @@ struct ELLPackMatrix {
XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
// Get a matrix element, uses binary search for look up
// Return NaN if missing
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
@ -832,14 +832,15 @@ struct DeviceShard {
row_partitioner->UpdatePosition(
nidx, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(bst_uint ridx) {
bst_float element =
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetElement(ridx, split_node.SplitIndex());
// Missing value
int new_position = 0;
if (isnan(element)) {
if (isnan(cut_value)) {
new_position = split_node.DefaultChild();
} else {
if (element <= split_node.SplitCond()) {
if (cut_value <= split_node.SplitCond()) {
new_position = split_node.LeftChild();
} else {
new_position = split_node.RightChild();