Batch UpdatePosition using cudaMemcpy (#7964)

This commit is contained in:
Rory Mitchell 2022-06-30 17:52:40 +02:00 committed by GitHub
parent 2407381c3d
commit bc4f802b17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 441 additions and 516 deletions

View File

@ -1939,4 +1939,25 @@ class CUDAStream {
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
void Sync() { this->View().Sync(); }
};
// Force nvcc to load data as constant
template <typename T>
class LDGIterator {
using DeviceWordT = typename cub::UnitWord<T>::DeviceWord;
static constexpr std::size_t kNumWords = sizeof(T) / sizeof(DeviceWordT);
const T *ptr_;
public:
explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
__device__ T operator[](std::size_t idx) const {
DeviceWordT tmp[kNumWords];
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
#pragma unroll
for (int i = 0; i < kNumWords; i++) {
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
}
return *reinterpret_cast<const T *>(tmp);
}
};
} // namespace dh

View File

@ -1,174 +1,46 @@
/*!
* Copyright 2017-2021 XGBoost contributors
* Copyright 2017-2022 XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/sequence.h>
#include <vector>
#include "../../common/device_helpers.cuh"
#include "row_partitioner.cuh"
namespace xgboost {
namespace tree {
struct IndexFlagTuple {
size_t idx;
size_t flag;
};
struct IndexFlagOp {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
const IndexFlagTuple& b) const {
return {b.idx, a.flag + b.flag};
}
};
struct WriteResultsFunctor {
bst_node_t left_nidx;
common::Span<bst_node_t> position_in;
common::Span<bst_node_t> position_out;
common::Span<RowPartitioner::RowIndexT> ridx_in;
common::Span<RowPartitioner::RowIndexT> ridx_out;
int64_t* d_left_count;
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
// the ex_scan_result represents how many rows have been assigned to left
// node so far during scan.
int scatter_address;
if (position_in[x.idx] == left_nidx) {
scatter_address = x.flag - 1; // -1 because inclusive scan
} else {
// current number of rows belong to right node + total number of rows
// belong to left node
scatter_address = (x.idx - x.flag) + *d_left_count;
}
// copy the node id to output
position_out[scatter_address] = position_in[x.idx];
ridx_out[scatter_address] = ridx_in[x.idx];
// Discard
return {};
}
};
// Implement partitioning via single scan operation using transform output to
// write the result
void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out,
bst_node_t left_nidx, bst_node_t,
int64_t* d_left_count, cudaStream_t stream) {
WriteResultsFunctor write_results{left_nidx, position, position_out,
ridx, ridx_out, d_left_count};
auto discard_write_iterator =
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator = dh::MakeTransformIterator<IndexFlagTuple>(
counting, [=] __device__(size_t idx) {
return IndexFlagTuple{idx, static_cast<size_t>(position[idx] == left_nidx)};
});
size_t temp_bytes = 0;
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
dh::TemporaryArray<int8_t> temp(temp_bytes);
cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(),
position.size(), stream);
}
void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
common::Span<bst_node_t> position) {
dh::safe_cuda(cudaSetDevice(device_idx));
CHECK_EQ(ridx.size(), position.size());
dh::LaunchN(ridx.size(), [=] __device__(size_t idx) {
ridx[idx] = idx;
position[idx] = 0;
});
}
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
ridx_b_(num_rows), position_b_(num_rows) {
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
ridx_segments_.emplace_back(static_cast<size_t>(0), num_rows);
Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
left_counts_.resize(256);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
streams_.resize(2);
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
dh::safe_cuda(cudaStreamCreate(&stream_));
}
RowPartitioner::~RowPartitioner() {
dh::safe_cuda(cudaSetDevice(device_idx_));
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
dh::safe_cuda(cudaStreamDestroy(stream_));
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
bst_node_t nidx) {
auto segment = ridx_segments_.at(nidx);
// Return empty span here as a valid result
// Will error if we try to construct a span from a pointer with size 0
if (segment.Size() == 0) {
return {};
}
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
auto segment = ridx_segments_.at(nidx).segment;
return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return ridx_.CurrentSpan();
return dh::ToSpan(ridx_);
}
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
return position_.CurrentSpan();
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
bst_node_t nidx) {
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(bst_node_t nidx) {
auto span = GetRows(nidx);
std::vector<RowIndexT> rows(span.size());
dh::CopyDeviceSpanToVector(&rows, span);
return rows;
}
std::vector<bst_node_t> RowPartitioner::GetPositionHost() {
auto span = GetPosition();
std::vector<bst_node_t> position(span.size());
dh::CopyDeviceSpanToVector(&position, span);
return position;
}
void RowPartitioner::SortPositionAndCopy(const Segment& segment,
bst_node_t left_nidx,
bst_node_t right_nidx,
int64_t* d_left_count,
cudaStream_t stream) {
SortPosition(
// position_in
common::Span<bst_node_t>(position_.Current() + segment.begin,
segment.Size()),
// position_out
common::Span<bst_node_t>(position_.Other() + segment.begin,
segment.Size()),
// row index in
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
// row index out
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position_.Current() + segment.begin;
const auto d_position_other = position_.Other() + segment.begin;
const auto d_ridx_current = ridx_.Current() + segment.begin;
const auto d_ridx_other = ridx_.Other() + segment.begin;
dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
}
}; // namespace tree
}; // namespace xgboost

View File

@ -2,33 +2,193 @@
* Copyright 2017-2022 XGBoost contributors
*/
#pragma once
#include <thrust/execution_policy.h>
#include <limits>
#include <vector>
#include "xgboost/base.h"
#include "../../common/device_helpers.cuh"
#include "xgboost/base.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/task.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace tree {
/*! \brief Count how many rows are assigned to left node. */
__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) {
#if __CUDACC_VER_MAJOR__ > 8
int mask = __activemask();
unsigned ballot = __ballot_sync(mask, increment);
int leader = __ffs(mask) - 1;
if (threadIdx.x % 32 == leader) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
/** \brief Used to demarcate a contiguous set of row indices associated with
* some tree node. */
struct Segment {
bst_uint begin{0};
bst_uint end{0};
Segment() = default;
Segment(bst_uint begin, bst_uint end) : begin(begin), end(end) { CHECK_GE(end, begin); }
__host__ __device__ size_t Size() const { return end - begin; }
};
// TODO(Rory): Can be larger. To be tuned alongside other batch operations.
static const int kMaxUpdatePositionBatchSize = 32;
template <typename OpDataT>
struct PerNodeData {
Segment segment;
OpDataT data;
};
template <typename BatchIterT>
__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
int* batch_idx, std::size_t* item_idx) {
bst_uint sum = 0;
for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) {
if (sum + batch_info[i].segment.Size() > global_thread_idx) {
*batch_idx = i;
*item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin;
break;
}
sum += batch_info[i].segment.Size();
}
#else
unsigned ballot = __ballot(increment);
if (threadIdx.x % 32 == 0) {
atomicAdd(reinterpret_cast<unsigned long long*>(d_count), // NOLINT
static_cast<unsigned long long>(__popc(ballot))); // NOLINT
}
template <int kBlockSize, typename RowIndexT, typename OpDataT>
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<RowIndexT> d_ridx,
const common::Span<const RowIndexT> ridx_tmp, std::size_t total_rows) {
for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
int batch_idx;
std::size_t item_idx;
AssignBatch(batch_info, idx, &batch_idx, &item_idx);
d_ridx[item_idx] = ridx_tmp[item_idx];
}
}
// We can scan over this tuple, where the scan gives us information on how to partition inputs
// according to the flag
struct IndexFlagTuple {
bst_uint idx; // The location of the item we are working on in ridx_
bst_uint flag_scan; // This gets populated after scanning
int batch_idx; // Which node in the batch does this item belong to
bool flag; // Result of op (is this item going left?)
};
struct IndexFlagOp {
__device__ IndexFlagTuple operator()(const IndexFlagTuple& a, const IndexFlagTuple& b) const {
// Segmented scan - resets if we cross batch boundaries
if (a.batch_idx == b.batch_idx) {
// Accumulate the flags, everything else stays the same
return {b.idx, a.flag_scan + b.flag_scan, b.batch_idx, b.flag};
} else {
return b;
}
}
};
template <typename OpDataT>
struct WriteResultsFunctor {
dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
const bst_uint* ridx_in;
bst_uint* ridx_out;
bst_uint* counts;
__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
std::size_t scatter_address;
const Segment& segment = batch_info[x.batch_idx].segment;
if (x.flag) {
bst_uint num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan
scatter_address = segment.begin + num_previous_flagged;
} else {
bst_uint num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan;
scatter_address = segment.end - num_previous_unflagged - 1;
}
ridx_out[scatter_address] = ridx_in[x.idx];
if (x.idx == (segment.end - 1)) {
// Write out counts
counts[x.batch_idx] = x.flag_scan;
}
// Discard
return {};
}
};
template <typename RowIndexT, typename OpT, typename OpDataT>
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
d_counts.data()};
auto discard_write_iterator =
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator =
dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) {
int batch_idx;
std::size_t item_idx;
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data);
return IndexFlagTuple{bst_uint(item_idx), op_res, batch_idx, op_res};
});
size_t temp_bytes = 0;
if (tmp->empty()) {
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
IndexFlagOp(), total_rows, stream);
tmp->resize(temp_bytes);
}
temp_bytes = tmp->size();
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(), total_rows, stream);
constexpr int kBlockSize = 256;
// Value found by experimentation
const int kItemsThread = 12;
const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
<<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
}
struct NodePositionInfo {
Segment segment;
bst_node_t left_child = -1;
bst_node_t right_child = -1;
__device__ bool IsLeaf() { return left_child == -1; }
};
__device__ __forceinline__ int GetPositionFromSegments(std::size_t idx,
const NodePositionInfo* d_node_info) {
int position = 0;
NodePositionInfo node = d_node_info[position];
while (!node.IsLeaf()) {
NodePositionInfo left = d_node_info[node.left_child];
NodePositionInfo right = d_node_info[node.right_child];
if (idx >= left.segment.begin && idx < left.segment.end) {
position = node.left_child;
node = left;
} else if (idx >= right.segment.begin && idx < right.segment.end) {
position = node.right_child;
node = right;
} else {
KERNEL_CHECK(false);
}
}
return position;
}
template <int kBlockSize, typename RowIndexT, typename OpT>
__global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel(
const common::Span<const NodePositionInfo> d_node_info,
const common::Span<const RowIndexT> d_ridx, common::Span<bst_node_t> d_out_position, OpT op) {
for (auto idx : dh::GridStrideRange<std::size_t>(0, d_ridx.size())) {
auto position = GetPositionFromSegments(idx, d_node_info.data());
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx, position);
d_out_position[ridx] = new_position;
}
#endif
}
/** \brief Class responsible for tracking subsets of rows as we add splits and
@ -36,7 +196,6 @@ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment
class RowPartitioner {
public:
using RowIndexT = bst_uint;
struct Segment;
static constexpr bst_node_t kIgnoredTreePosition = -1;
private:
@ -49,23 +208,20 @@ class RowPartitioner {
* node id -> segment -> indices of rows belonging to node
*/
/*! \brief Range of row index for each node, pointers into ridx below. */
std::vector<Segment> ridx_segments_;
dh::TemporaryArray<RowIndexT> ridx_a_;
dh::TemporaryArray<RowIndexT> ridx_b_;
dh::TemporaryArray<bst_node_t> position_a_;
dh::TemporaryArray<bst_node_t> position_b_;
std::vector<NodePositionInfo> ridx_segments_;
/*! \brief mapping for node id -> rows.
* This looks like:
* node id | 1 | 2 |
* rows idx | 3, 5, 1 | 13, 31 |
*/
dh::DoubleBuffer<RowIndexT> ridx_;
/*! \brief mapping for row -> node id. */
dh::DoubleBuffer<bst_node_t> position_;
dh::caching_device_vector<int64_t>
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams_;
dh::TemporaryArray<RowIndexT> ridx_;
// Staging area for sorting ridx
dh::TemporaryArray<RowIndexT> ridx_tmp_;
dh::device_vector<int8_t> tmp_;
dh::PinnedMemory pinned_;
dh::PinnedMemory pinned2_;
cudaStream_t stream_;
public:
RowPartitioner(int device_idx, size_t num_rows);
@ -83,73 +239,74 @@ class RowPartitioner {
*/
common::Span<const RowIndexT> GetRows();
/**
* \brief Gets the tree position of all training instances.
*/
common::Span<const bst_node_t> GetPosition();
/**
* \brief Convenience method for testing
*/
std::vector<RowIndexT> GetRowsHost(bst_node_t nidx);
/**
* \brief Convenience method for testing
*/
std::vector<bst_node_t> GetPositionHost();
/**
* \brief Updates the tree position for set of training instances being split
* into left and right child nodes. Accepts a user-defined lambda specifying
* which branch each training instance should go down.
*
* \tparam UpdatePositionOpT
* \param nidx The index of the node being split.
* \param left_nidx The left child index.
* \param right_nidx The right child index.
* \param op Device lambda. Should provide the row index as an
* argument and return the new position for this training instance.
* \tparam OpDataT
* \param nidx The index of the nodes being split.
* \param left_nidx The left child indices.
* \param right_nidx The right child indices.
* \param op_data User-defined data provided as the second argument to op
* \param op Device lambda with the row index as the first argument and op_data as the
* second. Returns true if this training instance goes on the left partition.
*/
template <typename UpdatePositionOpT>
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
bst_node_t right_nidx, UpdatePositionOpT op) {
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx_.CurrentSpan();
auto d_position = position_.CurrentSpan();
if (left_counts_.size() <= static_cast<size_t>(nidx)) {
left_counts_.resize((nidx * 2) + 1);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
template <typename UpdatePositionOpT, typename OpDataT>
void UpdatePositionBatch(const std::vector<bst_node_t>& nidx,
const std::vector<bst_node_t>& left_nidx,
const std::vector<bst_node_t>& right_nidx,
const std::vector<OpDataT>& op_data, UpdatePositionOpT op) {
if (nidx.empty()) return;
CHECK_EQ(nidx.size(), left_nidx.size());
CHECK_EQ(nidx.size(), right_nidx.size());
CHECK_EQ(nidx.size(), op_data.size());
auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
std::size_t total_rows = 0;
for (int i = 0; i < nidx.size(); i++) {
h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
}
// Now we divide the row segment into left and right node.
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
cudaMemcpyDefault, stream_));
int64_t* d_left_count = left_counts_.data().get() + nidx;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count,
d_position] __device__(size_t idx) {
// LaunchN starts from zero, so we restore the row index by adding segment.begin
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx); // new node id
KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx);
AtomicIncrement(d_left_count, new_position == left_nidx);
d_position[idx] = new_position;
});
// Overlap device to host memory copy (left_count) with sort
int64_t &left_count = pinned_.GetSpan<int64_t>(1)[0];
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams_[0]));
// Temporary arrays
auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, streams_[1]);
// Partition the rows according to the operator
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
total_rows, op, &tmp_, stream_);
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
cudaMemcpyDefault, stream_));
// TODO(Rory): this synchronisation hurts performance a lot
// Future optimisation should find a way to skip this
dh::safe_cuda(cudaStreamSynchronize(stream_));
dh::safe_cuda(cudaStreamSynchronize(streams_[0]));
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
ridx_segments_.resize(std::max(static_cast<bst_node_t>(ridx_segments_.size()),
std::max(left_nidx, right_nidx) + 1));
ridx_segments_[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments_[right_nidx] =
Segment(segment.begin + left_count, segment.end);
// Update segments
for (int i = 0; i < nidx.size(); i++) {
auto segment = ridx_segments_.at(nidx[i]).segment;
auto left_count = h_counts[i];
CHECK_LE(left_count, segment.Size());
ridx_segments_.resize(std::max(static_cast<bst_node_t>(ridx_segments_.size()),
std::max(left_nidx[i], right_nidx[i]) + 1));
ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};
ridx_segments_[left_nidx[i]] =
NodePositionInfo{Segment(segment.begin, segment.begin + left_count)};
ridx_segments_[right_nidx[i]] =
NodePositionInfo{Segment(segment.begin + left_count, segment.end)};
}
}
/**
@ -165,69 +322,21 @@ class RowPartitioner {
* argument and return the new position for this training instance.
* \param sampled A device lambda to inform the partitioner whether a row is sampled.
*/
template <typename FinalisePositionOpT, typename Sampledp>
void FinalisePosition(Context const* ctx, ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position, FinalisePositionOpT op,
Sampledp sampledp) {
auto d_position = position_.Current();
const auto d_ridx = ridx_.Current();
if (!task.UpdateTreeLeaf()) {
dh::LaunchN(position_.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx, position);
if (new_position == kIgnoredTreePosition) {
return;
}
d_position[idx] = new_position;
});
return;
}
template <typename FinalisePositionOpT>
void FinalisePosition(common::Span<bst_node_t> d_out_position, FinalisePositionOpT op) {
dh::TemporaryArray<NodePositionInfo> d_node_info_storage(ridx_segments_.size());
dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(),
sizeof(NodePositionInfo) * ridx_segments_.size(),
cudaMemcpyDefault, stream_));
p_out_position->SetDevice(ctx->gpu_id);
p_out_position->Resize(position_.Size());
auto sorted_position = p_out_position->DevicePointer();
dh::LaunchN(position_.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx, position);
sorted_position[ridx] = sampledp(ridx) ? ~new_position : new_position;
if (new_position == kIgnoredTreePosition) {
return;
}
d_position[idx] = new_position;
});
constexpr int kBlockSize = 512;
const int kItemsThread = 8;
const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread);
common::Span<const RowIndexT> d_ridx(ridx_.data().get(), ridx_.size());
FinalisePositionKernel<kBlockSize><<<grid_size, kBlockSize, 0, stream_>>>(
dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op);
}
/**
* \brief Optimised routine for sorting key value pairs into left and right
* segments. Based on a single pass of exclusive scan, uses iterators to
* redirect inputs and outputs.
*/
void SortPosition(common::Span<bst_node_t> position,
common::Span<bst_node_t> position_out,
common::Span<RowIndexT> ridx,
common::Span<RowIndexT> ridx_out, bst_node_t left_nidx,
bst_node_t right_nidx, int64_t* d_left_count,
cudaStream_t stream = nullptr);
/*! \brief Sort row indices according to position. */
void SortPositionAndCopy(const Segment& segment, bst_node_t left_nidx,
bst_node_t right_nidx, int64_t* d_left_count,
cudaStream_t stream);
/** \brief Used to demarcate a contiguous set of row indices associated with
* some tree node. */
struct Segment {
size_t begin { 0 };
size_t end { 0 };
Segment() = default;
Segment(size_t begin, size_t end) : begin(begin), end(end) {
CHECK_GE(end, begin);
}
size_t Size() const { return end - begin; }
};
};
}; // namespace tree
}; // namespace xgboost

View File

@ -182,10 +182,11 @@ struct GPUHistMakerDevice {
std::unique_ptr<RowPartitioner> row_partitioner;
DeviceHistogramStorage<GradientSumT> hist{};
dh::caching_device_vector<GradientPair> d_gpair; // storage for gpair;
dh::device_vector<GradientPair> d_gpair; // storage for gpair;
common::Span<GradientPair> gpair;
dh::caching_device_vector<int> monotone_constraints;
dh::device_vector<int> monotone_constraints;
dh::device_vector<float> update_predictions;
/*! \brief Sum gradient for each node. */
std::vector<GradientPairPrecise> node_sum_gradients;
@ -356,36 +357,49 @@ struct GPUHistMakerDevice {
return true;
}
void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
RegTree::Node split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(e.nid);
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
auto node_cats = e.split.split_cats.Bits();
// Extra data for each node that is passed
// to the update position function
struct NodeSplitData {
RegTree::Node split_node;
FeatureType split_type;
common::CatBitField node_cats;
};
row_partitioner->UpdatePosition(
e.nid, split_node.LeftChild(), split_node.RightChild(),
[=] __device__(bst_uint ridx) {
void UpdatePosition(const std::vector<GPUExpandEntry>& candidates, RegTree* p_tree) {
if (candidates.empty()) return;
std::vector<int> nidx(candidates.size());
std::vector<int> left_nidx(candidates.size());
std::vector<int> right_nidx(candidates.size());
std::vector<NodeSplitData> split_data(candidates.size());
for (int i = 0; i < candidates.size(); i++) {
auto& e = candidates[i];
RegTree::Node split_node = (*p_tree)[e.nid];
auto split_type = p_tree->NodeSplitType(e.nid);
nidx.at(i) = e.nid;
left_nidx.at(i) = split_node.LeftChild();
right_nidx.at(i) = split_node.RightChild();
split_data.at(i) = NodeSplitData{split_node, split_type, e.split.split_cats};
}
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
row_partitioner->UpdatePositionBatch(
nidx, left_nidx, right_nidx, split_data,
[=] __device__(bst_uint ridx, const NodeSplitData& data) {
// given a row index, returns the node id it belongs to
bst_float cut_value =
d_matrix.GetFvalue(ridx, split_node.SplitIndex());
bst_float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex());
// Missing value
bst_node_t new_position = 0;
bool go_left = true;
if (isnan(cut_value)) {
new_position = split_node.DefaultChild();
go_left = data.split_node.DefaultLeft();
} else {
bool go_left = true;
if (split_type == FeatureType::kCategorical) {
go_left = common::Decision<false>(node_cats, cut_value, split_node.DefaultLeft());
if (data.split_type == FeatureType::kCategorical) {
go_left = common::Decision<false>(data.node_cats.Bits(), cut_value,
data.split_node.DefaultLeft());
} else {
go_left = cut_value <= split_node.SplitCond();
}
if (go_left) {
new_position = split_node.LeftChild();
} else {
new_position = split_node.RightChild();
go_left = cut_value <= data.split_node.SplitCond();
}
}
return new_position;
return go_left;
});
}
@ -394,6 +408,16 @@ struct GPUHistMakerDevice {
// prediction cache
void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position) {
// Prediction cache will not be used with external memory
if (!p_fmat->SingleColBlock()) {
if (task.UpdateTreeLeaf()) {
LOG(FATAL) << "Current objective function can not be used with external memory.";
}
p_out_position->Resize(0);
update_predictions.clear();
return;
}
dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
d_nodes.size() * sizeof(RegTree::Node),
@ -412,25 +436,9 @@ struct GPUHistMakerDevice {
dh::CopyToD(categories_segments, &d_categories_segments);
}
if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
row_partitioner.reset(); // Release the device memory first before reallocating
row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, p_fmat->Info().num_row_));
}
if (task.UpdateTreeLeaf() && !p_fmat->SingleColBlock() && param.subsample != 1.0) {
// see comment in the `FinalisePositionInPage`.
LOG(FATAL) << "Current objective function can not be used with subsampled external memory.";
}
if (page->n_rows == p_fmat->Info().num_row_) {
FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task,
p_out_position);
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task,
p_out_position);
}
}
FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types),
dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments),
p_out_position);
}
void FinalisePositionInPage(EllpackPageImpl const *page,
@ -438,79 +446,73 @@ struct GPUHistMakerDevice {
common::Span<FeatureType const> d_feature_types,
common::Span<uint32_t const> categories,
common::Span<RegTree::Segment> categories_segments,
ObjInfo task,
HostDeviceVector<bst_node_t>* p_out_position) {
auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id);
auto d_gpair = this->gpair;
row_partitioner->FinalisePosition(
ctx_, task, p_out_position,
[=] __device__(size_t row_id, int position) {
// What happens if user prune the tree?
if (!d_matrix.IsInRange(row_id)) {
return RowPartitioner::kIgnoredTreePosition;
}
auto node = d_nodes[position];
update_predictions.resize(row_partitioner->GetRows().size());
auto d_update_predictions = dh::ToSpan(update_predictions);
p_out_position->SetDevice(ctx_->gpu_id);
p_out_position->Resize(row_partitioner->GetRows().size());
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
bool go_left = true;
if (common::IsCat(d_feature_types, position)) {
auto node_cats =
categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
if (go_left) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
auto new_position_op = [=] __device__(size_t row_id, int position) {
// What happens if user prune the tree?
if (!d_matrix.IsInRange(row_id)) {
return RowPartitioner::kIgnoredTreePosition;
}
auto node = d_nodes[position];
return position;
},
[d_gpair] __device__(size_t ridx) {
// FIXME(jiamingy): Doesn't work when sampling is used with external memory as
// the sampler compacts the gradient vector.
return d_gpair[ridx].GetHess() - .0f == 0.f;
});
while (!node.IsLeaf()) {
bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
// Missing value
if (isnan(element)) {
position = node.DefaultChild();
} else {
bool go_left = true;
if (common::IsCat(d_feature_types, position)) {
auto node_cats = categories.subspan(categories_segments[position].beg,
categories_segments[position].size);
go_left = common::Decision<false>(node_cats, element, node.DefaultLeft());
} else {
go_left = element <= node.SplitCond();
}
if (go_left) {
position = node.LeftChild();
} else {
position = node.RightChild();
}
}
node = d_nodes[position];
}
d_update_predictions[row_id] = node.LeafValue();
return position;
}; // NOLINT
auto d_out_position = p_out_position->DeviceSpan();
row_partitioner->FinalisePosition(d_out_position, new_position_op);
dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) {
bst_node_t position = d_out_position[idx];
d_update_predictions[idx] = d_nodes[position].LeafValue();
bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f;
d_out_position[idx] = is_row_sampled ? ~position : position;
});
}
void UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
bool UpdatePredictionCache(linalg::VectorView<float> out_preds_d, RegTree const* p_tree) {
if (update_predictions.empty()) {
return false;
}
CHECK(p_tree);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id);
auto d_ridx = row_partitioner->GetRows();
GPUTrainingParam param_d(param);
dh::TemporaryArray<GradientPairPrecise> device_node_sum_gradients(node_sum_gradients.size());
dh::safe_cuda(cudaMemcpyAsync(device_node_sum_gradients.data().get(), node_sum_gradients.data(),
sizeof(GradientPairPrecise) * node_sum_gradients.size(),
cudaMemcpyHostToDevice));
auto d_position = row_partitioner->GetPosition();
auto d_node_sum_gradients = device_node_sum_gradients.data().get();
auto tree_evaluator = evaluator_.GetEvaluator();
auto const& h_nodes = p_tree->GetNodes();
dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size());
dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(),
h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice));
auto d_nodes = dh::ToSpan(nodes);
dh::LaunchN(d_ridx.size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
bst_node_t nidx = d_position[idx];
auto weight = d_nodes[nidx].LeafValue();
out_preds_d(d_ridx[idx]) += weight;
auto d_update_predictions = dh::ToSpan(update_predictions);
CHECK_EQ(out_preds_d.Size(), d_update_predictions.size());
dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable {
out_preds_d(idx) += d_update_predictions[idx];
});
row_partitioner.reset();
return true;
}
// num histograms is the number of contiguous histograms in memory to reduce over
@ -684,14 +686,12 @@ struct GPUHistMakerDevice {
auto new_candidates =
pinned.GetSpan<GPUExpandEntry>(filtered_expand_set.size() * 2, GPUExpandEntry());
for (const auto& e : filtered_expand_set) {
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(e, p_tree);
monitor.Stop("UpdatePosition");
}
monitor.Start("UpdatePosition");
// Update position is only run when child is valid, instead of right after apply
// split (as in approx tree method). Hense we have the finalise position call
// in GPU Hist.
this->UpdatePosition(filtered_expand_set, p_tree);
monitor.Stop("UpdatePosition");
monitor.Start("BuildHist");
this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
@ -844,9 +844,9 @@ class GPUHistMaker : public TreeUpdater {
return false;
}
monitor_.Start("UpdatePredictionCache");
maker->UpdatePredictionCache(p_out_preds, p_last_tree_);
bool result = maker->UpdatePredictionCache(p_out_preds, p_last_tree_);
monitor_.Stop("UpdatePredictionCache");
return true;
return result;
}
TrainParam param_; // NOLINT

View File

@ -19,49 +19,7 @@
namespace xgboost {
namespace tree {
void TestSortPosition(const std::vector<int>& position_in, int left_idx,
int right_idx) {
dh::safe_cuda(cudaSetDevice(0));
std::vector<int64_t> left_count = {
std::count(position_in.begin(), position_in.end(), left_idx)};
dh::caching_device_vector<int64_t> d_left_count = left_count;
dh::caching_device_vector<int> position = position_in;
dh::caching_device_vector<int> position_out(position.size());
dh::caching_device_vector<RowPartitioner::RowIndexT> ridx(position.size());
thrust::sequence(ridx.begin(), ridx.end());
dh::caching_device_vector<RowPartitioner::RowIndexT> ridx_out(ridx.size());
RowPartitioner rp(0,10);
rp.SortPosition(
common::Span<int>(position.data().get(), position.size()),
common::Span<int>(position_out.data().get(), position_out.size()),
common::Span<RowPartitioner::RowIndexT>(ridx.data().get(), ridx.size()),
common::Span<RowPartitioner::RowIndexT>(ridx_out.data().get(), ridx_out.size()), left_idx,
right_idx, d_left_count.data().get(), nullptr);
thrust::host_vector<int> position_result = position_out;
thrust::host_vector<int> ridx_result = ridx_out;
// Check position is sorted
EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end()));
// Check row indices are sorted inside left and right segment
EXPECT_TRUE(
std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0]));
EXPECT_TRUE(
std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end()));
// Check key value pairs are the same
for (auto i = 0ull; i < ridx_result.size(); i++) {
EXPECT_EQ(position_result[i], position_in[ridx_result[i]]);
}
}
TEST(GpuHist, SortPosition) {
TestSortPosition({1, 2, 1, 2, 1}, 1, 2);
TestSortPosition({1, 1, 1, 1}, 1, 2);
TestSortPosition({2, 2, 2, 2}, 1, 2);
TestSortPosition({1, 2, 1, 2, 3}, 1, 2);
}
void TestUpdatePosition() {
void TestUpdatePositionBatch() {
const int kNumRows = 10;
RowPartitioner rp(0, kNumRows);
auto rows = rp.GetRowsHost(0);
@ -69,16 +27,11 @@ void TestUpdatePosition() {
for (auto i = 0ull; i < kNumRows; i++) {
EXPECT_EQ(rows[i], i);
}
std::vector<int> extra_data = {0};
// Send the first five training instances to the right node
// and the second 5 to the left node
rp.UpdatePosition(0, 1, 2,
[=] __device__(RowPartitioner::RowIndexT ridx) {
if (ridx > 4) {
return 1;
}
else {
return 2;
}
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
return ridx > 4;
});
rows = rp.GetRowsHost(1);
for (auto r : rows) {
@ -90,88 +43,58 @@ void TestUpdatePosition() {
}
// Split the left node again
rp.UpdatePosition(1, 3, 4, [=]__device__(RowPartitioner::RowIndexT ridx)
{
if (ridx < 7) {
return 3
;
}
return 4;
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int) {
return ridx < 7;
});
EXPECT_EQ(rp.GetRows(3).size(), 2);
EXPECT_EQ(rp.GetRows(4).size(), 3);
// Check position is as expected
EXPECT_EQ(rp.GetPositionHost(), std::vector<bst_node_t>({3,3,4,4,4,2,2,2,2,2}));
}
TEST(RowPartitioner, Basic) { TestUpdatePosition(); }
TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); }
void TestFinalise() {
const int kNumRows = 10;
void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Segment>& segments) {
thrust::device_vector<uint32_t> ridx = ridx_in;
thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
thrust::device_vector<bst_uint> counts(segments.size());
ObjInfo task{ObjInfo::kRegression, false, false};
HostDeviceVector<bst_node_t> position;
Context ctx;
ctx.gpu_id = 0;
auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; };
std::vector<int> op_data(segments.size());
std::vector<PerNodeData<int>> h_batch_info(segments.size());
dh::TemporaryArray<PerNodeData<int>> d_batch_info(segments.size());
{
RowPartitioner rp(0, kNumRows);
rp.FinalisePosition(
&ctx, task, &position,
[=] __device__(RowPartitioner::RowIndexT ridx, int position) { return 7; },
[] XGBOOST_DEVICE(size_t) { return false; });
auto position = rp.GetPositionHost();
for (auto p : position) {
EXPECT_EQ(p, 7);
}
std::size_t total_rows = 0;
for (int i = 0; i < segments.size(); i++) {
h_batch_info[i] = {segments.at(i), 0};
total_rows += segments.at(i).Size();
}
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
nullptr));
dh::device_vector<int8_t> tmp;
SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
total_rows, op, &tmp, nullptr);
/**
* Test for sampling.
*/
dh::device_vector<float> hess(kNumRows);
for (size_t i = 0; i < hess.size(); ++i) {
// removed rows, 0, 3, 6, 9
if (i % 3 == 0) {
hess[i] = 0;
} else {
hess[i] = i;
}
}
auto d_hess = dh::ToSpan(hess);
RowPartitioner rp(0, kNumRows);
rp.FinalisePosition(
&ctx, task, &position,
[] __device__(RowPartitioner::RowIndexT ridx, bst_node_t position) {
return ridx % 2 == 0 ? 1 : 2;
},
[d_hess] __device__(size_t ridx) { return d_hess[ridx] - 0.f == 0.f; });
auto const& h_position = position.ConstHostVector();
for (size_t ridx = 0; ridx < h_position.size(); ++ridx) {
if (ridx % 3 == 0) {
ASSERT_LT(h_position[ridx], 0);
} else {
ASSERT_EQ(h_position[ridx], ridx % 2 == 0 ? 1 : 2);
}
auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
for (int i = 0; i < segments.size(); i++) {
auto begin = ridx.begin() + segments[i].begin;
auto end = ridx.begin() + segments[i].end;
bst_uint count = counts[i];
auto left_partition_count =
thrust::count_if(thrust::device, begin, begin + count, op_without_data);
EXPECT_EQ(left_partition_count, count);
auto right_partition_count =
thrust::count_if(thrust::device, begin + count, end, op_without_data);
EXPECT_EQ(right_partition_count, 0);
}
}
TEST(RowPartitioner, Finalise) { TestFinalise(); }
void TestIncorrectRow() {
RowPartitioner rp(0, 1);
rp.UpdatePosition(0, 1, 2, [=]__device__ (RowPartitioner::RowIndexT ridx)
{
return 4; // This is not the left branch or the right branch
});
TEST(GpuHist, SortPositionBatch) {
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 3}, {3, 6}});
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 1}, {3, 6}});
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 6}});
TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{3, 6}, {0, 2}});
}
TEST(RowPartitionerDeathTest, IncorrectRow) {
ASSERT_DEATH({ TestIncorrectRow(); },".*");
}
} // namespace tree
} // namespace xgboost