Extract partial sum into an independent function. (#7889)
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -28,58 +29,13 @@ void GHistIndexMatrix::PushBatch(SparsePage const &batch,
|
||||
common::Span<FeatureType const> ft,
|
||||
size_t rbegin, size_t prev_sum, uint32_t nbins,
|
||||
int32_t n_threads) {
|
||||
// The number of threads is pegged to the batch size. If the OMP
|
||||
// block is parallelized on anything other than the batch/block size,
|
||||
// it should be reassigned
|
||||
auto page = batch.GetView();
|
||||
auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
|
||||
common::PartialSum(n_threads, it, it + page.Size(), prev_sum, row_ptr.begin() + rbegin);
|
||||
// The number of threads is pegged to the batch size. If the OMP block is parallelized
|
||||
// on anything other than the batch/block size, it should be reassigned
|
||||
const size_t batch_threads =
|
||||
std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
|
||||
auto page = batch.GetView();
|
||||
common::MemStackAllocator<size_t, 128> partial_sums(batch_threads);
|
||||
|
||||
size_t block_size = batch.Size() / batch_threads;
|
||||
|
||||
dmlc::OMPException exc;
|
||||
#pragma omp parallel num_threads(batch_threads)
|
||||
{
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads - 1) ? batch.Size()
|
||||
: (block_size * (tid + 1)));
|
||||
|
||||
size_t running_sum = 0;
|
||||
for (size_t ridx = ibegin; ridx < iend; ++ridx) {
|
||||
running_sum += page[ridx].size();
|
||||
row_ptr[rbegin + 1 + ridx] = running_sum;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp single
|
||||
{
|
||||
exc.Run([&]() {
|
||||
partial_sums[0] = prev_sum;
|
||||
for (size_t i = 1; i < batch_threads; ++i) {
|
||||
partial_sums[i] = partial_sums[i - 1] + row_ptr[rbegin + i * block_size];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#pragma omp for
|
||||
for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
|
||||
exc.Run([&]() {
|
||||
size_t ibegin = block_size * tid;
|
||||
size_t iend = (tid == (batch_threads - 1) ? batch.Size()
|
||||
: (block_size * (tid + 1)));
|
||||
|
||||
for (size_t i = ibegin; i < iend; ++i) {
|
||||
row_ptr[rbegin + 1 + i] += partial_sums[tid];
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
exc.Rethrow();
|
||||
|
||||
const size_t n_index = row_ptr[rbegin + batch.Size()]; // number of entries in this page
|
||||
ResizeIndex(n_index, isDense_);
|
||||
|
||||
Reference in New Issue
Block a user