Handle OMP_THREAD_LIMIT. (#7390)

This commit is contained in:
Jiaming Yuan
2021-11-03 15:44:38 +08:00
committed by GitHub
parent e6ab594e14
commit 57a4b4ff64
4 changed files with 93 additions and 5 deletions

View File

@@ -7,11 +7,28 @@
#define XGBOOST_COMMON_THREADING_UTILS_H_
#include <dmlc/common.h>
#include <vector>
#include <dmlc/omp.h>
#include <algorithm>
#include <limits>
#include <type_traits> // std::is_signed
#include <vector>
#include "xgboost/logging.h"
// Fallback for builds without OpenMP (_OPENMP undefined): define a stub
// omp_get_thread_limit() that always reports a limit of one thread.
// NOTE(review): __GOMP_NOTHROW is not defined in this header; presumably it
// is supplied by <dmlc/omp.h> -- confirm.
#if !defined(_OPENMP)
extern "C" {
inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; } // NOLINT
}
#endif // !defined(_OPENMP)
// MSVC doesn't implement the thread limit, so when building with OpenMP under
// MSVC supply an omp_get_thread_limit() that returns the largest int32_t
// value, i.e. effectively no limit.
#if defined(_OPENMP) && defined(_MSC_VER)
extern "C" {
inline int32_t omp_get_thread_limit() { return std::numeric_limits<int32_t>::max(); } // NOLINT
}
#endif // defined(_OPENMP) && defined(_MSC_VER)
namespace xgboost {
namespace common {
@@ -153,7 +170,7 @@ struct Sched {
};
template <typename Index, typename Func>
void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) {
void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
#if defined(_MSC_VER)
// msvc doesn't support unsigned integer as openmp index.
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
@@ -220,6 +237,13 @@ void ParallelFor(Index size, size_t n_threads, Func fn) {
/*!
 * \brief Overload of ParallelFor taking only the range size and the functor.
 *
 * Forwards to the four-argument ParallelFor, passing the value returned by
 * omp_get_max_threads() as the thread count and Sched::Static() as the schedule.
 *
 * \param size Forwarded unchanged as the first argument.
 * \param fn   Forwarded unchanged as the last argument.
 */
template <typename Index, typename Func>
void ParallelFor(Index size, Func fn) {
  auto const max_threads = omp_get_max_threads();
  ParallelFor(size, max_threads, Sched::Static(), fn);
}
/*!
 * \brief Query the OpenMP thread limit and validate it.
 *
 * The value is read from omp_get_thread_limit() and checked with CHECK_GE to be
 * at least 1; the check carries the message "Invalid thread limit for OpenMP.".
 *
 * \return The thread limit reported by omp_get_thread_limit().
 */
inline int32_t OmpGetThreadLimit() {
  int32_t const thread_limit = omp_get_thread_limit();
  CHECK_GE(thread_limit, 1) << "Invalid thread limit for OpenMP.";
  return thread_limit;
}
/* \brief Configure parallel threads.
@@ -235,15 +259,18 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
if (threads <= 0) {
threads = omp_get_num_procs();
}
threads = std::min(threads, OmpGetThreadLimit());
omp_set_num_threads(threads);
return nthread_original;
}
/*!
 * \brief Resolve the requested thread count in place and apply it to OpenMP.
 *
 * A non-positive request is replaced by the current omp_get_max_threads()
 * value. The result is then capped at OmpGetThreadLimit(), written back through
 * \p p_threads and handed to omp_set_num_threads().
 *
 * \param p_threads In/out pointer to the requested number of threads; updated
 *                  with the value actually used.
 *
 * \return The value of omp_get_max_threads() observed before the change.
 */
inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
  int32_t const previous_max = omp_get_max_threads();
  int32_t& requested = *p_threads;

  // Zero or a negative value falls back to the current maximum.
  if (requested < 1) {
    requested = previous_max;
  }

  // Never exceed the limit reported by the OpenMP runtime.
  requested = std::min(requested, OmpGetThreadLimit());
  omp_set_num_threads(requested);

  return previous_max;
}
@@ -252,6 +279,7 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) {
n_threads = omp_get_num_procs();
}
n_threads = std::min(n_threads, OmpGetThreadLimit());
return n_threads;
}
} // namespace common