gpu_hist performance tweaks (#5707)

* Remove device vectors

* Remove allreduce synchronize

* Remove double buffer
This commit is contained in:
Rory Mitchell 2020-05-29 16:48:53 +12:00 committed by GitHub
parent ca0d605b34
commit f779980f7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 75 deletions

View File

@ -427,36 +427,6 @@ class TemporaryArray {
size_t size_;
};
/**
* \brief A double buffer, useful for algorithms like sort.
*/
template <typename T>
class DoubleBuffer {
public:
cub::DoubleBuffer<T> buff;
xgboost::common::Span<T> a, b;
DoubleBuffer() = default;
template <typename VectorT>
DoubleBuffer(VectorT *v1, VectorT *v2) {
a = xgboost::common::Span<T>(v1->data().get(), v1->size());
b = xgboost::common::Span<T>(v2->data().get(), v2->size());
buff = cub::DoubleBuffer<T>(a.data(), b.data());
}
size_t Size() const {
CHECK_EQ(a.size(), b.size());
return a.size();
}
cub::DoubleBuffer<T> &CubBuffer() { return buff; }
T *Current() { return buff.Current(); }
xgboost::common::Span<T> CurrentSpan() {
return xgboost::common::Span<T>{buff.Current(), Size()};
}
T *Other() { return buff.Alternate(); }
};
/**
* \brief Copies device span to std::vector.
*

View File

@ -93,26 +93,23 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
position.size(), stream);
}
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_a_.resize(num_rows);
ridx_b_.resize(num_rows);
position_a_.resize(num_rows);
position_b_.resize(num_rows);
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
ridx_segments_.emplace_back(Segment(0, num_rows));
void Reset(int device_idx, common::Span<RowPartitioner::RowIndexT> ridx,
common::Span<bst_node_t> position) {
CHECK_EQ(ridx.size(), position.size());
dh::LaunchN(device_idx, ridx.size(), [=] __device__(size_t idx) {
ridx[idx] = idx;
position[idx] = 0;
});
}
thrust::sequence(
thrust::device_pointer_cast(ridx_.CurrentSpan().data()),
thrust::device_pointer_cast(ridx_.CurrentSpan().data() + ridx_.Size()));
thrust::fill(
thrust::device_pointer_cast(position_.Current()),
thrust::device_pointer_cast(position_.Current() + position_.Size()), 0);
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows) {
dh::safe_cuda(cudaSetDevice(device_idx_));
Reset(device_idx, dh::ToSpan(ridx_a_), dh::ToSpan(position_a_));
left_counts_.resize(256);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
streams_.resize(2);
ridx_segments_.emplace_back(Segment(0, num_rows));
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
@ -132,15 +129,15 @@ common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
if (segment.Size() == 0) {
return common::Span<const RowPartitioner::RowIndexT>();
}
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
return dh::ToSpan(ridx_a_).subspan(segment.begin, segment.Size());
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return ridx_.CurrentSpan();
return dh::ToSpan(ridx_a_);
}
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
return position_.CurrentSpan();
return dh::ToSpan(position_a_);
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
bst_node_t nidx) {
@ -162,23 +159,25 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
bst_node_t right_nidx,
int64_t* d_left_count,
cudaStream_t stream) {
dh::TemporaryArray<bst_node_t> position_temp(position_a_.size());
dh::TemporaryArray<RowIndexT> ridx_temp(ridx_a_.size());
SortPosition(
// position_in
common::Span<bst_node_t>(position_.Current() + segment.begin,
common::Span<bst_node_t>(position_a_.data().get() + segment.begin,
segment.Size()),
// position_out
common::Span<bst_node_t>(position_.Other() + segment.begin,
common::Span<bst_node_t>(position_temp.data().get() + segment.begin,
segment.Size()),
// row index in
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
common::Span<RowIndexT>(ridx_a_.data().get() + segment.begin, segment.Size()),
// row index out
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
common::Span<RowIndexT>(ridx_temp.data().get() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position_.Current() + segment.begin;
const auto d_position_other = position_.Other() + segment.begin;
const auto d_ridx_current = ridx_.Current() + segment.begin;
const auto d_ridx_other = ridx_.Other() + segment.begin;
const auto d_position_current = position_a_.data().get() + segment.begin;
const auto d_position_other = position_temp.data().get() + segment.begin;
const auto d_ridx_current = ridx_a_.data().get() + segment.begin;
const auto d_ridx_other = ridx_temp.data().get() + segment.begin;
dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];

View File

@ -46,18 +46,8 @@ class RowPartitioner {
*/
/*! \brief Range of row index for each node, pointers into ridx below. */
std::vector<Segment> ridx_segments_;
dh::caching_device_vector<RowIndexT> ridx_a_;
dh::caching_device_vector<RowIndexT> ridx_b_;
dh::caching_device_vector<bst_node_t> position_a_;
dh::caching_device_vector<bst_node_t> position_b_;
/*! \brief mapping for node id -> rows.
* This looks like:
* node id | 1 | 2 |
* rows idx | 3, 5, 1 | 13, 31 |
*/
dh::DoubleBuffer<RowIndexT> ridx_;
/*! \brief mapping for row -> node id. */
dh::DoubleBuffer<bst_node_t> position_;
dh::TemporaryArray<RowIndexT> ridx_a_;
dh::TemporaryArray<bst_node_t> position_a_;
dh::caching_device_vector<int64_t>
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams_;
@ -110,8 +100,8 @@ class RowPartitioner {
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
bst_node_t right_nidx, UpdatePositionOpT op) {
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx_.CurrentSpan();
auto d_position = position_.CurrentSpan();
auto d_ridx = dh::ToSpan(ridx_a_);
auto d_position = dh::ToSpan(position_a_);
if (left_counts_.size() <= nidx) {
left_counts_.resize((nidx * 2) + 1);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
@ -159,9 +149,9 @@ class RowPartitioner {
*/
template <typename FinalisePositionOpT>
void FinalisePosition(FinalisePositionOpT op) {
auto d_position = position_.Current();
const auto d_ridx = ridx_.Current();
dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) {
auto d_position = position_a_.data().get();
const auto d_ridx = ridx_a_.data().get();
dh::LaunchN(device_idx_, position_a_.size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx, position);

View File

@ -511,7 +511,6 @@ struct GPUHistMakerDevice {
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
reinterpret_cast<typename GradientSumT::ValueT*>(d_node_hist),
page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
reducer->Synchronize();
monitor.Stop("AllReduce");
}