Unify thread configuration. (#6186)

This commit is contained in:
Jiaming Yuan 2020-10-19 16:05:42 +08:00 committed by GitHub
parent 7f6ed5780c
commit ddf37cca30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 39 additions and 12 deletions

View File

@ -144,6 +144,22 @@ void ParallelFor(size_t size, size_t nthreads, Func fn) {
omp_exc.Rethrow();
}
/* \brief Configure parallel threads.
*
* \param p_threads Number of threads, when it's less than or equal to 0, this function
* will change it to number of process on system.
*
* \return Global openmp max threads before configuration.
*/
inline int32_t OmpSetNumThreads(int32_t* p_threads) {
auto& threads = *p_threads;
int32_t nthread_original = omp_get_max_threads();
if (threads <= 0) {
threads = omp_get_num_procs();
}
omp_set_num_threads(threads);
return nthread_original;
}
} // namespace common
} // namespace xgboost

View File

@ -19,6 +19,7 @@
#include "../common/math.h"
#include "../common/version.h"
#include "../common/group_data.h"
#include "../common/threading_utils.h"
#include "../data/adapter.h"
#include "../data/iterative_device_dmatrix.h"
@ -843,10 +844,7 @@ void SparsePage::Push(const SparsePage &batch) {
template <typename AdapterBatchT>
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
const int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);
int nthread_original = common::OmpSetNumThreads(&nthread);
auto& offset_vec = offset.HostVector();
auto& data_vec = data.HostVector();
@ -865,7 +863,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
}
size_t batch_size = batch.Size();
const size_t thread_size = batch_size/nthread;
const size_t thread_size = batch_size / nthread;
builder.InitBudget(expected_rows+1, nthread);
uint64_t max_columns = 0;
if (batch_size == 0) {

View File

@ -15,6 +15,7 @@
#include "simple_dmatrix.h"
#include "./simple_batch_iterator.h"
#include "../common/random.h"
#include "../common/threading_utils.h"
#include "adapter.h"
namespace xgboost {
@ -92,10 +93,7 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);
int nthread_original = common::OmpSetNumThreads(&nthread);
std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max();

View File

@ -43,6 +43,7 @@
#include "common/timer.h"
#include "common/charconv.h"
#include "common/version.h"
#include "common/threading_utils.h"
namespace {
@ -287,9 +288,7 @@ class LearnerConfiguration : public Learner {
generic_parameters_.CheckDeprecated();
ConsoleLogger::Configure(args);
if (generic_parameters_.nthread != 0) {
omp_set_num_threads(generic_parameters_.nthread);
}
common::OmpSetNumThreads(&generic_parameters_.nthread);
// add additional parameters
// These are constraints that need to be satisfied.

View File

@ -88,6 +88,22 @@ TEST(ParallelFor2dNonUniform, Test) {
omp_set_num_threads(old);
}
#if defined(_OPENMP)
TEST(OmpSetNumThreads, Basic) {
  // An explicit positive request is applied verbatim.
  auto n_threads = 2;
  auto n_before = OmpSetNumThreads(&n_threads);
  ASSERT_EQ(omp_get_max_threads(), 2);
  // A non-positive request falls back to the processor count.
  n_threads = 0;
  OmpSetNumThreads(&n_threads);
  ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
  // The fallback also applies after an earlier explicit configuration.
  n_threads = 1;
  OmpSetNumThreads(&n_threads);
  n_threads = 0;
  OmpSetNumThreads(&n_threads);
  ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
  // Restore the global setting so later tests are unaffected.
  omp_set_num_threads(n_before);
}
#endif // defined(_OPENMP)
} // namespace common
} // namespace xgboost