Remove omp_get_max_threads in data. (#7588)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015-2021 by Contributors
|
||||
* Copyright 2015-2022 by XGBoost Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include <dmlc/registry.h>
|
||||
@@ -1001,15 +1001,14 @@ DMatrix::Create(data::IteratorAdapter<DataIterHandle, XGBCallbackDataIterNext,
|
||||
XGBoostBatchCSR> *adapter,
|
||||
float missing, int nthread, const std::string &cache_prefix);
|
||||
|
||||
SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||
SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
|
||||
SparsePage transpose;
|
||||
common::ParallelGroupBuilder<Entry, bst_row_t> builder(&transpose.offset.HostVector(),
|
||||
&transpose.data.HostVector());
|
||||
const int nthread = omp_get_max_threads();
|
||||
builder.InitBudget(num_columns, nthread);
|
||||
builder.InitBudget(num_columns, n_threads);
|
||||
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
|
||||
auto page = this->GetView();
|
||||
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
|
||||
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = page[i];
|
||||
for (const auto& entry : inst) {
|
||||
@@ -1017,7 +1016,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const {
|
||||
}
|
||||
});
|
||||
builder.InitStorage();
|
||||
common::ParallelFor(batch_size, [&](long i) { // NOLINT(*)
|
||||
common::ParallelFor(batch_size, n_threads, [&](long i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = page[i];
|
||||
for (const auto& entry : inst) {
|
||||
@@ -1059,8 +1058,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
|
||||
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
|
||||
nthread = kIsRowMajor ? nthread : 1;
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
|
||||
if (!kIsRowMajor) {
|
||||
CHECK_EQ(nthread, 1);
|
||||
}
|
||||
@@ -1085,7 +1082,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
expected_rows = kIsRowMajor ? batch_size : expected_rows;
|
||||
uint64_t max_columns = 0;
|
||||
if (batch_size == 0) {
|
||||
omp_set_num_threads(nthread_original);
|
||||
return max_columns;
|
||||
}
|
||||
const size_t thread_size = batch_size / nthread;
|
||||
@@ -1154,7 +1150,6 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
});
|
||||
}
|
||||
exec.Rethrow();
|
||||
omp_set_num_threads(nthread_original);
|
||||
|
||||
return max_columns;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user