Remove omp_get_max_threads (#7608)

This is the one last PR for removing omp global variable.

* Add context object to the `DMatrix`.  This bridges `DMatrix` with https://github.com/dmlc/xgboost/issues/7308 .
* Require context to be available at the construction time of booster.
* Add `n_threads` support for R csc DMatrix constructor.
* Remove `omp_get_max_threads` in R glue code.
* Remove threading utilities that rely on omp global variable.
This commit is contained in:
Jiaming Yuan
2022-01-28 16:09:22 +08:00
committed by GitHub
parent 028bdc1740
commit 81210420c6
31 changed files with 195 additions and 211 deletions

View File

@@ -177,6 +177,7 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
using OmpInd = Index;
#endif
OmpInd length = static_cast<OmpInd>(size);
CHECK_GE(n_threads, 1);
dmlc::OMPException exc;
switch (sched.sched) {
@@ -227,42 +228,16 @@ void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
}
template <typename Index, typename Func>
void ParallelFor(Index size, size_t n_threads, Func fn) {
void ParallelFor(Index size, int32_t n_threads, Func fn) {
ParallelFor(size, n_threads, Sched::Static(), fn);
}
// FIXME(jiamingy): Remove this function to get rid of `omp_set_num_threads`, which sets a
// global variable in runtime and affects other programs in the same process.
template <typename Index, typename Func>
void ParallelFor(Index size, Func fn) {
ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
} // !defined(_OPENMP)
inline int32_t OmpGetThreadLimit() {
int32_t limit = omp_get_thread_limit();
CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
return limit;
}
/* \brief Configure parallel threads.
*
* \param p_threads Number of threads, when it's less than or equal to 0, this function
* will change it to number of process on system.
*
* \return Global openmp max threads before configuration.
*/
inline int32_t OmpSetNumThreads(int32_t* p_threads) {
auto& threads = *p_threads;
int32_t nthread_original = omp_get_max_threads();
if (threads <= 0) {
threads = omp_get_num_procs();
}
threads = std::min(threads, OmpGetThreadLimit());
omp_set_num_threads(threads);
return nthread_original;
}
inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) {
n_threads = std::min(omp_get_num_procs(), omp_get_max_threads());