diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index 6f2a953a7..896098aac 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -160,6 +160,16 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) { omp_set_num_threads(threads); return nthread_original; } +inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { + auto& threads = *p_threads; + int32_t nthread_original = omp_get_max_threads(); + if (threads <= 0) { + threads = nthread_original; + } + omp_set_num_threads(threads); + return nthread_original; +} + } // namespace common } // namespace xgboost diff --git a/src/data/data.cc b/src/data/data.cc index 2d56a6f29..f203fd3dc 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -844,7 +844,7 @@ void SparsePage::Push(const SparsePage &batch) { template uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) { // Set number of threads but keep old value so we can reset it after - int nthread_original = common::OmpSetNumThreads(&nthread); + int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread); auto& offset_vec = offset.HostVector(); auto& data_vec = data.HostVector(); diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 437da9754..c39909489 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -93,7 +93,7 @@ BatchSet SimpleDMatrix::GetEllpackBatches(const BatchParam& param) template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { // Set number of threads but keep old value so we can reset it after - int nthread_original = common::OmpSetNumThreads(&nthread); + int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread); std::vector qids; uint64_t default_max = std::numeric_limits::max();