Remove omp_get_max_threads in data. (#7588)

This commit is contained in:
Jiaming Yuan
2022-01-24 02:44:07 +08:00
committed by GitHub
parent f84291c1e1
commit 5817840858
18 changed files with 97 additions and 92 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2021 by Contributors
* Copyright 2015-2022 by XGBoost Contributors
* \file data.cc
*/
#include <dmlc/registry.h>
@@ -1001,15 +1001,14 @@ DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
XGBoostBatchCSR> *adapter,
float missing, int nthread, const std::string &cache_prefix);
SparsePage SparsePage::GetTranspose(int num_columns) const {
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
SparsePage transpose;
common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
&transpose.data.HostVector());
const int nthread = omp_get_max_threads();
builder.InitBudget(num_columns, nthread);
builder.InitBudget(num_columns, n_threads);
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
auto page = this->GetView();
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = page[i];
for (const auto& entry : inst) {
@@ -1017,7 +1016,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
}
});
builder.InitStorage();
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = page[i];
for (const auto& entry : inst) {
@@ -1059,8 +1058,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
nthread = kIsRowMajor ? nthread : 1;
// Set number of threads but keep old value so we can reset it after
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
if (!kIsRowMajor) {
CHECK_EQ(nthread, 1);
}
@@ -1085,7 +1082,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
expected_rows = kIsRowMajor ? batch_size : expected_rows;
uint64_t max_columns = 0;
if (batch_size == 0) {
omp_set_num_threads(nthread_original);
return max_columns;
}
const size_t thread_size = batch_size / nthread;
@@ -1154,7 +1150,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
});
}
exec.Rethrow();
omp_set_num_threads(nthread_original);
return max_columns;
}