// Source file: xgboost/src/tree/gpu_hist/row_partitioner.cuh
// (file-viewer metadata: 342 lines, 13 KiB, plain text)

/*!
* Copyright 2017-2022 XGBoost contributors
*/
#pragma once
#include <thrust/execution_policy.h>
#include <limits>
#include <vector>
#include "../../common/device_helpers.cuh"
#include "xgboost/base.h"
#include "xgboost/context.h"
#include "xgboost/task.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace tree {
/**
 * \brief Half-open range [begin, end) marking a contiguous run of row indices that
 *        belongs to one tree node.
 */
struct Segment {
  bst_uint begin{0};  // first position covered by the range
  bst_uint end{0};    // one past the last position covered by the range

  Segment() = default;

  // Not annotated for device use; validates with CHECK_GE that end does not precede begin.
  Segment(bst_uint begin, bst_uint end) : begin{begin}, end{end} { CHECK_GE(end, begin); }

  // Number of positions in the range. The difference is taken in bst_uint and then
  // widened to size_t on return.
  __host__ __device__ size_t Size() const {
    const bst_uint length = end - begin;
    return length;
  }
};
// Upper bound on the number of batch entries scanned by AssignBatch.
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
static constexpr int kMaxUpdatePositionBatchSize = 32;
/**
 * \brief One batch entry: the row segment of a node together with a payload of type OpDataT.
 */
template <class OpDataT>
struct PerNodeData {
  Segment segment;  // range of positions holding the node's rows
  OpDataT data;     // payload carried alongside the segment
};
/**
 * \brief Map a flat work-item index onto the batch entry whose segment contains it.
 *
 * The segments reachable through batch_info are treated as if they were concatenated in
 * batch order; global_thread_idx is an index into that concatenation. At most
 * kMaxUpdatePositionBatchSize entries are examined, and entries past the matching one are
 * never read.
 *
 * \param batch_info        Indexable sequence whose elements expose a `.segment` member.
 * \param global_thread_idx Index into the concatenated segments. Must be smaller than the
 *                          combined size of the examined segments.
 * \param batch_idx         [out] Index of the entry that owns the work item.
 * \param item_idx          [out] segment.begin of that entry plus the offset of the work item
 *                          inside it.
 *
 * If no examined segment contains global_thread_idx the outputs are not written and
 * KERNEL_CHECK fires, instead of silently handing uninitialised values back to the caller.
 */
template <typename BatchIterT>
__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
                                            int* batch_idx, std::size_t* item_idx) {
  // Accumulate in std::size_t (the type Size() returns) so the running total is never narrowed.
  std::size_t rows_before = 0;
  for (int i = 0; i < kMaxUpdatePositionBatchSize; ++i) {
    // Read the entry once per iteration rather than once per field access.
    const auto segment = batch_info[i].segment;
    const std::size_t rows_through = rows_before + segment.Size();
    if (global_thread_idx < rows_through) {
      *batch_idx = i;
      *item_idx = (global_thread_idx - rows_before) + segment.begin;
      return;
    }
    rows_before = rows_through;
  }
  // Precondition violated: the index lies beyond every examined segment.
  KERNEL_CHECK(false);
}
/**
 * \brief Copy every position covered by the batch segments from ridx_tmp back into d_ridx.
 *
 * Work items in [0, total_rows) are distributed with dh::GridStrideRange; each one is
 * translated by AssignBatch into a position that is used to index both spans.
 *
 * \tparam kBlockSize Used only for __launch_bounds__.
 * \param batch_info  Batch entries providing the segments.
 * \param d_ridx      Destination span.
 * \param ridx_tmp    Source span, addressed with the same positions as d_ridx.
 * \param total_rows  Number of work items to process.
 */
template <int kBlockSize, typename RowIndexT, typename OpDataT>
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
    dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<RowIndexT> d_ridx,
    const common::Span<const RowIndexT> ridx_tmp, std::size_t total_rows) {
  for (auto work_item : dh::GridStrideRange<std::size_t>(0, total_rows)) {
    int owning_batch;  // filled in by AssignBatch; the copy itself does not need it
    std::size_t position;
    AssignBatch(batch_info, work_item, &owning_batch, &position);
    d_ridx[position] = ridx_tmp[position];
  }
}
/**
 * \brief Element type of the partitioning scan.
 *
 * Scanning over these tuples yields, per item, the information needed to partition the
 * inputs according to the flag.
 */
struct IndexFlagTuple {
  // The location of the item we are working on in ridx_
  bst_uint idx;
  // This gets populated after scanning
  bst_uint flag_scan;
  // Which node in the batch does this item belong to
  int batch_idx;
  // Result of op (is this item going left?)
  bool flag;
};
/**
 * \brief Binary operator of the segmented scan over IndexFlagTuple.
 *
 * Within one batch entry the flag_scan values are summed while every other field is taken
 * from the right-hand operand; when the operands belong to different batch entries the
 * right-hand operand is returned unchanged, which restarts the running sum.
 */
struct IndexFlagOp {
  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a, const IndexFlagTuple& b) const {
    // Crossing a batch boundary: the scan resets
    if (a.batch_idx != b.batch_idx) {
      return b;
    }
    // Same batch entry: only the running flag count changes
    IndexFlagTuple combined = b;
    combined.flag_scan = a.flag_scan + b.flag_scan;
    return combined;
  }
};
/**
 * \brief Consumes scanned IndexFlagTuple values and scatters the corresponding rows.
 *
 * Flagged items are written from the front of their segment in order of appearance;
 * unflagged items are written from the back of the segment, so they end up in reverse
 * order of appearance. The item sitting at the last position of a segment also stores the
 * segment's flagged count into counts[batch_idx].
 */
template <typename OpDataT>
struct WriteResultsFunctor {
  dh::LDGIterator<PerNodeData<OpDataT>> batch_info;  // segments of the batch
  const bst_uint* ridx_in;                           // rows are read from here
  bst_uint* ridx_out;                                // rows are scattered into here
  bst_uint* counts;                                  // one flagged count per batch entry

  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
    const Segment& seg = batch_info[x.batch_idx].segment;
    std::size_t destination;
    if (!x.flag) {
      // Items seen so far in this segment minus the flagged ones among them
      const bst_uint offset_in_segment = x.idx - seg.begin;
      const bst_uint unflagged_before = offset_in_segment - x.flag_scan;
      destination = seg.end - unflagged_before - 1;
    } else {
      // The scan is inclusive, so the current item is part of flag_scan
      const bst_uint flagged_before = x.flag_scan - 1;
      destination = seg.begin + flagged_before;
    }
    ridx_out[destination] = ridx_in[x.idx];

    const bool is_last_in_segment = (x.idx == (seg.end - 1));
    if (is_last_in_segment) {
      // The inclusive scan value at the last item is the flagged total of the segment
      counts[x.batch_idx] = x.flag_scan;
    }
    // The returned value is discarded by the output iterator
    return IndexFlagTuple{};
  }
};
/**
 * \brief Partition the rows of every node in a batch according to op.
 *
 * A segmented inclusive scan over (row, flag) tuples drives WriteResultsFunctor, which
 * scatters the rows from ridx into ridx_tmp: within each segment, rows for which op is true
 * are packed at the front and the remaining rows at the back. SortPositionCopyKernel then
 * copies the affected positions from ridx_tmp back into ridx. Neither the scan nor the
 * kernel launch is given an explicit stream.
 *
 * \param d_batch_info Segments and op payloads of the batch; at most
 *                     kMaxUpdatePositionBatchSize entries, because AssignBatch does not look
 *                     further than that.
 * \param ridx         Row index buffer the segments refer to; partitioned in place.
 * \param ridx_tmp     Staging buffer addressed with the same positions as ridx.
 * \param d_counts     Receives, per batch entry, the number of rows for which op was true.
 *                     Entries belonging to empty segments are not written.
 * \param total_rows   Combined size of all segments in d_batch_info.
 * \param op           Device callable invoked as op(row_index, data); its result is the flag.
 * \param tmp          Scratch storage for the scan. Grown whenever it is smaller than what
 *                     the scan reports it needs, and otherwise reused as is.
 */
template <typename RowIndexT, typename OpT, typename OpDataT>
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
                       common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
                       common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
                       dh::device_vector<int8_t>* tmp) {
  // AssignBatch inspects at most this many entries; rows of later entries would be lost.
  CHECK_LE(d_batch_info.size(), static_cast<std::size_t>(kMaxUpdatePositionBatchSize));
  // Nothing to move. Returning here also avoids launching a grid with zero blocks, which is
  // an invalid launch configuration.
  if (total_rows == 0) {
    return;
  }

  dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
  WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
                                             d_counts.data()};
  auto discard_write_iterator =
      thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
  auto counting = thrust::make_counting_iterator(0llu);
  auto input_iterator =
      dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) {
        int batch_idx;
        std::size_t item_idx;
        AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
        auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
        return IndexFlagTuple{static_cast<bst_uint>(item_idx), op_res, batch_idx, op_res};
      });

  // Ask the scan how much scratch space this problem size needs on every call: a buffer sized
  // by an earlier, smaller call must not be reused for a larger one.
  size_t required_bytes = 0;
  dh::safe_cuda(cub::DeviceScan::InclusiveScan(nullptr, required_bytes, input_iterator,
                                               discard_write_iterator, IndexFlagOp(),
                                               total_rows));
  if (tmp->size() < required_bytes) {
    tmp->resize(required_bytes);
  }
  size_t temp_bytes = tmp->size();
  dh::safe_cuda(cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
                                               discard_write_iterator, IndexFlagOp(),
                                               total_rows));

  constexpr int kBlockSize = 256;
  // Value found by experimentation
  const int kItemsThread = 12;
  const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
  SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
      <<<grid_size, kBlockSize, 0>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
  // Kernel launches do not return a status; pick up launch failures explicitly.
  dh::safe_cuda(cudaGetLastError());
}
/**
 * \brief Row segment of a tree node together with the indices of its children.
 *
 * Both child indices default to -1; IsLeaf() treats a left_child of -1 as "no children".
 */
struct NodePositionInfo {
  Segment segment;              // range of positions holding the node's rows
  bst_node_t left_child = -1;   // index of the left child, -1 when there is none
  bst_node_t right_child = -1;  // index of the right child, -1 when there is none

  // Const-qualified because it only inspects left_child, so it is callable on const objects.
  __device__ bool IsLeaf() const { return left_child == -1; }
};
/**
 * \brief Find the leaf whose segment contains a given position of the row index buffer.
 *
 * Starts at entry 0 of d_node_info and repeatedly steps into whichever child has
 * begin <= idx < end, until an entry reporting IsLeaf() is reached. KERNEL_CHECK(false) is
 * raised when neither child of an inner node contains idx.
 *
 * \param idx         Position to look up.
 * \param d_node_info Node table indexed by node id.
 * \return Index of the leaf entry that was reached.
 */
__device__ __forceinline__ int GetPositionFromSegments(std::size_t idx,
                                                       const NodePositionInfo* d_node_info) {
  int current_id = 0;
  NodePositionInfo current = d_node_info[current_id];
  while (!current.IsLeaf()) {
    NodePositionInfo left_info = d_node_info[current.left_child];
    NodePositionInfo right_info = d_node_info[current.right_child];
    const bool in_left = left_info.segment.begin <= idx && idx < left_info.segment.end;
    const bool in_right = right_info.segment.begin <= idx && idx < right_info.segment.end;
    if (in_left) {
      current_id = current.left_child;
      current = left_info;
    } else if (in_right) {
      current_id = current.right_child;
      current = right_info;
    } else {
      // The children are expected to cover every position of their parent
      KERNEL_CHECK(false);
    }
  }
  return current_id;
}
/**
 * \brief Compute the final position of every row and store it, keyed by row index.
 *
 * For each position in d_ridx (distributed with dh::GridStrideRange) the leaf containing that
 * position is looked up via GetPositionFromSegments, op(row_index, leaf) is evaluated, and
 * the result is written to d_out_position[row_index].
 *
 * \tparam kBlockSize    Used only for __launch_bounds__.
 * \param d_node_info    Node table handed to GetPositionFromSegments.
 * \param d_ridx         Row indices, laid out as described by the node segments.
 * \param d_out_position Output, indexed by row index.
 * \param op             Device callable invoked as op(row_index, position).
 */
template <int kBlockSize, typename RowIndexT, typename OpT>
__global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
    const common::Span<const NodePositionInfo> d_node_info,
    const common::Span<const RowIndexT> d_ridx, common::Span<bst_node_t> d_out_position, OpT op) {
  const NodePositionInfo* node_table = d_node_info.data();
  for (auto slot : dh::GridStrideRange<std::size_t>(0, d_ridx.size())) {
    auto leaf = GetPositionFromSegments(slot, node_table);
    RowIndexT row = d_ridx[slot];
    bst_node_t final_position = op(row, leaf);
    d_out_position[row] = final_position;
  }
}
/** \brief Class responsible for tracking subsets of rows as we add splits and
 * partition training rows into different leaf nodes. */
class RowPartitioner {
 public:
  using RowIndexT = bst_uint;
  static constexpr bst_node_t kIgnoredTreePosition = -1;

 private:
  int device_idx_;
  /*! \brief In here if you want to find the rows belong to a node nid, first you need to
   * get the indices segment from ridx_segments[nid], then get the row index that
   * represents position of row in input data X. `RowPartitioner::GetRows` would be a
   * good starting place to get a sense what are these vector storing.
   *
   * node id -> segment -> indices of rows belonging to node
   */
  /*! \brief Range of row index for each node, pointers into ridx below. */
  std::vector<NodePositionInfo> ridx_segments_;
  /*! \brief mapping for node id -> rows.
   * This looks like:
   * node id  |    1    |    2   |
   * rows idx | 3, 5, 1 | 13, 31 |
   */
  dh::TemporaryArray<RowIndexT> ridx_;
  // Staging area for sorting ridx
  dh::TemporaryArray<RowIndexT> ridx_tmp_;
  // Scratch storage handed to SortPositionBatch
  dh::device_vector<int8_t> tmp_;
  // Host staging buffer for the per-node left counts
  dh::PinnedMemory pinned_;
  // Host staging buffer for the per-node batch info
  dh::PinnedMemory pinned2_;

 public:
  RowPartitioner(int device_idx, size_t num_rows);
  ~RowPartitioner();
  RowPartitioner(const RowPartitioner&) = delete;
  RowPartitioner& operator=(const RowPartitioner&) = delete;

  /**
   * \brief Gets the row indices of training instances in a given node.
   */
  common::Span<const RowIndexT> GetRows(bst_node_t nidx);

  /**
   * \brief Gets all training rows in the set.
   */
  common::Span<const RowIndexT> GetRows();

  /**
   * \brief Convenience method for testing
   */
  std::vector<RowIndexT> GetRowsHost(bst_node_t nidx);

  /**
   * \brief Updates the tree position for set of training instances being split
   * into left and right child nodes. Accepts a user-defined lambda specifying
   * which branch each training instance should go down.
   *
   * The four input vectors must have the same length (enforced with CHECK_EQ); an empty
   * nidx is a no-op. The call blocks on dh::DefaultStream().Sync() before the host-side
   * segments are updated.
   *
   * \tparam UpdatePositionOpT
   * \tparam OpDataT
   * \param nidx        The index of the nodes being split.
   * \param left_nidx   The left child indices.
   * \param right_nidx  The right child indices.
   * \param op_data     User-defined data provided as the second argument to op
   * \param op          Device lambda with the row index as the first argument and op_data as the
   * second. Returns true if this training instance goes on the left partition.
   */
  template <typename UpdatePositionOpT, typename OpDataT>
  void UpdatePositionBatch(const std::vector<bst_node_t>& nidx,
                           const std::vector<bst_node_t>& left_nidx,
                           const std::vector<bst_node_t>& right_nidx,
                           const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
    if (nidx.empty()) return;
    CHECK_EQ(nidx.size(), left_nidx.size());
    CHECK_EQ(nidx.size(), right_nidx.size());
    CHECK_EQ(nidx.size(), op_data.size());

    // Stage the segment and payload of every node being split, then ship them to the device
    auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
    dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());

    std::size_t total_rows = 0;
    for (size_t i = 0; i < nidx.size(); i++) {
      h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
      total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
    }
    dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
                                  h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
                                  cudaMemcpyDefault));

    // Temporary arrays
    auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
    dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);

    // Partition the rows according to the operator
    SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
        dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
        total_rows, op, &tmp_);
    dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
                                  cudaMemcpyDefault));
    // TODO(Rory): this synchronisation hurts performance a lot
    // Future optimisation should find a way to skip this
    dh::DefaultStream().Sync();

    // Update segments
    for (size_t i = 0; i < nidx.size(); i++) {
      // Copy the parent segment by value: the resize below may reallocate the vector
      auto segment = ridx_segments_.at(nidx[i]).segment;
      auto left_count = h_counts[i];
      CHECK_LE(left_count, segment.Size());
      ridx_segments_.resize(std::max(static_cast<bst_node_t>(ridx_segments_.size()),
                                     std::max(left_nidx[i], right_nidx[i]) + 1));
      ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};
      ridx_segments_[left_nidx[i]] =
          NodePositionInfo{Segment(segment.begin, segment.begin + left_count)};
      ridx_segments_[right_nidx[i]] =
          NodePositionInfo{Segment(segment.begin + left_count, segment.end)};
    }
  }

  /**
   * \brief Finalise the position of all training instances after tree construction is
   * complete. Does not update any other meta information in this data structure, so
   * should only be used at the end of training.
   *
   * For every tracked row the leaf holding it is looked up from the node segments, and the
   * value returned by op for that row is written to d_out_position[row index]. Nothing is
   * written, and no kernel is launched, when no rows are tracked.
   *
   * \param d_out_position Output positions, indexed by row index.
   * \param op Device lambda. Should provide the row index and current position as an
   * argument and return the new position for this training instance.
   */
  template <typename FinalisePositionOpT>
  void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) {
    // With no rows the grid size below would be zero, which is an invalid launch
    // configuration; there is nothing to finalise in that case anyway.
    if (ridx_.size() == 0) {
      return;
    }

    dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
    dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
                                  sizeof(NodePositionInfo) * ridx_segments_.size(),
                                  cudaMemcpyDefault));

    constexpr int kBlockSize = 512;
    const int kItemsThread = 8;
    const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
    common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
    FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0>>>(
        dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
    // Kernel launches do not return a status; pick up launch failures explicitly.
    dh::safe_cuda(cudaGetLastError());
  }
};
}; // namespace tree
}; // namespace xgboost