Re-introduce double buffer in UpdatePosition, to fix perf regression in gpu_hist (#6757)
* Revert "gpu_hist performance tweaks (#5707)" This reverts commit f779980f7ea7f6f07e86229b8e78144e8a74e6b3. * Address reviewer's comment * Fix build error
This commit is contained in:
parent
e2d8a99413
commit
4230dcb614
@ -549,6 +549,36 @@ class TemporaryArray {
|
||||
size_t size_;
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief A double buffer, useful for algorithms like sort.
|
||||
*/
|
||||
template <typename T>
|
||||
class DoubleBuffer {
|
||||
public:
|
||||
cub::DoubleBuffer<T> buff;
|
||||
xgboost::common::Span<T> a, b;
|
||||
DoubleBuffer() = default;
|
||||
template <typename VectorT>
|
||||
DoubleBuffer(VectorT *v1, VectorT *v2) {
|
||||
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
|
||||
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
|
||||
buff = cub::DoubleBuffer<T>(a.data(), b.data());
|
||||
}
|
||||
|
||||
size_t Size() const {
|
||||
CHECK_EQ(a.size(), b.size());
|
||||
return a.size();
|
||||
}
|
||||
cub::DoubleBuffer<T> &CubBuffer() { return buff; }
|
||||
|
||||
T *Current() { return buff.Current(); }
|
||||
xgboost::common::Span<T> CurrentSpan() {
|
||||
return xgboost::common::Span<T>{buff.Current(), Size()};
|
||||
}
|
||||
|
||||
T *Other() { return buff.Alternate(); }
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Copies device span to std::vector.
|
||||
*
|
||||
|
||||
@ -103,13 +103,17 @@ void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
|
||||
}
|
||||
|
||||
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
|
||||
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) {
|
||||
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows),
|
||||
ridx_b_(num_rows), position_b_(num_rows) {
|
||||
dh::safe_cuda(cudaSetDevice(device_idx_));
|
||||
Reset(device_idx, dh::ToSpan(ridx_a_), dh::ToSpan(position_a_));
|
||||
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
|
||||
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
|
||||
ridx_segments_.emplace_back(Segment(0, num_rows));
|
||||
|
||||
Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan());
|
||||
left_counts_.resize(256);
|
||||
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
|
||||
streams_.resize(2);
|
||||
ridx_segments_.emplace_back(Segment(0, num_rows));
|
||||
for (auto& stream : streams_) {
|
||||
dh::safe_cuda(cudaStreamCreate(&stream));
|
||||
}
|
||||
@ -129,15 +133,15 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
|
||||
if (segment.Size() == 0) {
|
||||
return common::Span<const RowPartitioner::RowIndexT>();
|
||||
}
|
||||
return dh::ToSpan(ridx_a_).subspan(segment.begin, segment.Size());
|
||||
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
|
||||
}
|
||||
|
||||
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
|
||||
return dh::ToSpan(ridx_a_);
|
||||
return ridx_.CurrentSpan();
|
||||
}
|
||||
|
||||
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
|
||||
return dh::ToSpan(position_a_);
|
||||
return position_.CurrentSpan();
|
||||
}
|
||||
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
|
||||
bst_node_t nidx) {
|
||||
@ -159,25 +163,23 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
|
||||
bst_node_t right_nidx,
|
||||
int64_t* d_left_count,
|
||||
cudaStream_t stream) {
|
||||
dh::TemporaryArray<bst_node_t> position_temp(position_a_.size());
|
||||
dh::TemporaryArray<RowIndexT> ridx_temp(ridx_a_.size());
|
||||
SortPosition(
|
||||
// position_in
|
||||
common::Span<bst_node_t>(position_a_.data().get() + segment.begin,
|
||||
common::Span<bst_node_t>(position_.Current() + segment.begin,
|
||||
segment.Size()),
|
||||
// position_out
|
||||
common::Span<bst_node_t>(position_temp.data().get() + segment.begin,
|
||||
common::Span<bst_node_t>(position_.Other() + segment.begin,
|
||||
segment.Size()),
|
||||
// row index in
|
||||
common::Span<RowIndexT>(ridx_a_.data().get() + segment.begin, segment.Size()),
|
||||
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
|
||||
// row index out
|
||||
common::Span<RowIndexT>(ridx_temp.data().get() + segment.begin, segment.Size()),
|
||||
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
|
||||
left_nidx, right_nidx, d_left_count, stream);
|
||||
// Copy back key/value
|
||||
const auto d_position_current = position_a_.data().get() + segment.begin;
|
||||
const auto d_position_other = position_temp.data().get() + segment.begin;
|
||||
const auto d_ridx_current = ridx_a_.data().get() + segment.begin;
|
||||
const auto d_ridx_other = ridx_temp.data().get() + segment.begin;
|
||||
const auto d_position_current = position_.Current() + segment.begin;
|
||||
const auto d_position_other = position_.Other() + segment.begin;
|
||||
const auto d_ridx_current = ridx_.Current() + segment.begin;
|
||||
const auto d_ridx_other = ridx_.Other() + segment.begin;
|
||||
dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) {
|
||||
d_position_current[idx] = d_position_other[idx];
|
||||
d_ridx_current[idx] = d_ridx_other[idx];
|
||||
|
||||
@ -47,7 +47,17 @@ class RowPartitioner {
|
||||
/*! \brief Range of row index for each node, pointers into ridx below. */
|
||||
std::vector<Segment> ridx_segments_;
|
||||
dh::TemporaryArray<RowIndexT> ridx_a_;
|
||||
dh::TemporaryArray<RowIndexT> ridx_b_;
|
||||
dh::TemporaryArray<bst_node_t> position_a_;
|
||||
dh::TemporaryArray<bst_node_t> position_b_;
|
||||
/*! \brief mapping for node id -> rows.
|
||||
* This looks like:
|
||||
* node id | 1 | 2 |
|
||||
* rows idx | 3, 5, 1 | 13, 31 |
|
||||
*/
|
||||
dh::DoubleBuffer<RowIndexT> ridx_;
|
||||
/*! \brief mapping for row -> node id. */
|
||||
dh::DoubleBuffer<bst_node_t> position_;
|
||||
dh::caching_device_vector<int64_t>
|
||||
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
|
||||
std::vector<cudaStream_t> streams_;
|
||||
@ -100,8 +110,8 @@ class RowPartitioner {
|
||||
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
|
||||
bst_node_t right_nidx, UpdatePositionOpT op) {
|
||||
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
|
||||
auto d_ridx = dh::ToSpan(ridx_a_);
|
||||
auto d_position = dh::ToSpan(position_a_);
|
||||
auto d_ridx = ridx_.CurrentSpan();
|
||||
auto d_position = position_.CurrentSpan();
|
||||
if (left_counts_.size() <= nidx) {
|
||||
left_counts_.resize((nidx * 2) + 1);
|
||||
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
|
||||
@ -148,9 +158,9 @@ class RowPartitioner {
|
||||
*/
|
||||
template <typename FinalisePositionOpT>
|
||||
void FinalisePosition(FinalisePositionOpT op) {
|
||||
auto d_position = position_a_.data().get();
|
||||
const auto d_ridx = ridx_a_.data().get();
|
||||
dh::LaunchN(device_idx_, position_a_.size(), [=] __device__(size_t idx) {
|
||||
auto d_position = position_.Current();
|
||||
const auto d_ridx = ridx_.Current();
|
||||
dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) {
|
||||
auto position = d_position[idx];
|
||||
RowIndexT ridx = d_ridx[idx];
|
||||
bst_node_t new_position = op(ridx, position);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user