Handle OMP_THREAD_LIMIT. (#7390)
This commit is contained in:
@@ -7,11 +7,28 @@
|
||||
#define XGBOOST_COMMON_THREADING_UTILS_H_
|
||||
|
||||
#include <dmlc/common.h>
|
||||
#include <vector>
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <type_traits> // std::is_signed
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
|
||||
#if !defined(_OPENMP)
|
||||
extern "C" {
|
||||
inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; } // NOLINT
|
||||
}
|
||||
#endif // !defined(_OPENMP)
|
||||
|
||||
// MSVC doesn't implement the thread limit.
|
||||
#if defined(_OPENMP) && defined(_MSC_VER)
|
||||
extern "C" {
|
||||
inline int32_t omp_get_thread_limit() { return std::numeric_limits<int32_t>::max(); } // NOLINT
|
||||
}
|
||||
#endif // defined(_MSC_VER)
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
@@ -153,7 +170,7 @@ struct Sched {
|
||||
};
|
||||
|
||||
template <typename Index, typename Func>
|
||||
void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) {
|
||||
void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
|
||||
#if defined(_MSC_VER)
|
||||
// msvc doesn't support unsigned integer as openmp index.
|
||||
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
|
||||
@@ -220,6 +237,13 @@ void ParallelFor(Index size, size_t n_threads, Func fn) {
|
||||
template <typename Index, typename Func>
|
||||
void ParallelFor(Index size, Func fn) {
|
||||
ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
|
||||
} // !defined(_OPENMP)
|
||||
|
||||
|
||||
inline int32_t OmpGetThreadLimit() {
|
||||
int32_t limit = omp_get_thread_limit();
|
||||
CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
|
||||
return limit;
|
||||
}
|
||||
|
||||
/* \brief Configure parallel threads.
|
||||
@@ -235,15 +259,18 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
|
||||
if (threads <= 0) {
|
||||
threads = omp_get_num_procs();
|
||||
}
|
||||
threads = std::min(threads, OmpGetThreadLimit());
|
||||
omp_set_num_threads(threads);
|
||||
return nthread_original;
|
||||
}
|
||||
|
||||
inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
|
||||
auto& threads = *p_threads;
|
||||
int32_t nthread_original = omp_get_max_threads();
|
||||
if (threads <= 0) {
|
||||
threads = nthread_original;
|
||||
}
|
||||
threads = std::min(threads, OmpGetThreadLimit());
|
||||
omp_set_num_threads(threads);
|
||||
return nthread_original;
|
||||
}
|
||||
@@ -252,6 +279,7 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
|
||||
if (n_threads <= 0) {
|
||||
n_threads = omp_get_num_procs();
|
||||
}
|
||||
n_threads = std::min(n_threads, OmpGetThreadLimit());
|
||||
return n_threads;
|
||||
}
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user