Unify thread configuration. (#6186)
This commit is contained in:
parent
7f6ed5780c
commit
ddf37cca30
@ -144,6 +144,22 @@ void ParallelFor(size_t size, size_t nthreads, Func fn) {
|
||||
omp_exc.Rethrow();
|
||||
}
|
||||
|
||||
/* \brief Configure parallel threads.
|
||||
*
|
||||
* \param p_threads Number of threads, when it's less than or equal to 0, this function
|
||||
* will change it to number of process on system.
|
||||
*
|
||||
* \return Global openmp max threads before configuration.
|
||||
*/
|
||||
inline int32_t OmpSetNumThreads(int32_t* p_threads) {
|
||||
auto& threads = *p_threads;
|
||||
int32_t nthread_original = omp_get_max_threads();
|
||||
if (threads <= 0) {
|
||||
threads = omp_get_num_procs();
|
||||
}
|
||||
omp_set_num_threads(threads);
|
||||
return nthread_original;
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "../common/math.h"
|
||||
#include "../common/version.h"
|
||||
#include "../common/group_data.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../data/adapter.h"
|
||||
#include "../data/iterative_device_dmatrix.h"
|
||||
|
||||
@ -843,10 +844,7 @@ void SparsePage::Push(const SparsePage &batch) {
|
||||
template <typename AdapterBatchT>
|
||||
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
const int nthreadmax = omp_get_max_threads();
|
||||
if (nthread <= 0) nthread = nthreadmax;
|
||||
const int nthread_original = omp_get_max_threads();
|
||||
omp_set_num_threads(nthread);
|
||||
int nthread_original = common::OmpSetNumThreads(&nthread);
|
||||
auto& offset_vec = offset.HostVector();
|
||||
auto& data_vec = data.HostVector();
|
||||
|
||||
@ -865,7 +863,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
}
|
||||
}
|
||||
size_t batch_size = batch.Size();
|
||||
const size_t thread_size = batch_size/nthread;
|
||||
const size_t thread_size = batch_size / nthread;
|
||||
builder.InitBudget(expected_rows+1, nthread);
|
||||
uint64_t max_columns = 0;
|
||||
if (batch_size == 0) {
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
#include "simple_dmatrix.h"
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "../common/random.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "adapter.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -92,10 +93,7 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
|
||||
template <typename AdapterT>
|
||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
const int nthreadmax = omp_get_max_threads();
|
||||
if (nthread <= 0) nthread = nthreadmax;
|
||||
int nthread_original = omp_get_max_threads();
|
||||
omp_set_num_threads(nthread);
|
||||
int nthread_original = common::OmpSetNumThreads(&nthread);
|
||||
|
||||
std::vector<uint64_t> qids;
|
||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
|
||||
@ -43,6 +43,7 @@
|
||||
#include "common/timer.h"
|
||||
#include "common/charconv.h"
|
||||
#include "common/version.h"
|
||||
#include "common/threading_utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@ -287,9 +288,7 @@ class LearnerConfiguration : public Learner {
|
||||
generic_parameters_.CheckDeprecated();
|
||||
|
||||
ConsoleLogger::Configure(args);
|
||||
if (generic_parameters_.nthread != 0) {
|
||||
omp_set_num_threads(generic_parameters_.nthread);
|
||||
}
|
||||
common::OmpSetNumThreads(&generic_parameters_.nthread);
|
||||
|
||||
// add additional parameters
|
||||
// These are cosntraints that need to be satisfied.
|
||||
|
||||
@ -88,6 +88,22 @@ TEST(ParallelFor2dNonUniform, Test) {
|
||||
|
||||
omp_set_num_threads(old);
|
||||
}
|
||||
#if defined(_OPENMP)
|
||||
TEST(OmpSetNumThreads, Basic) {
|
||||
auto nthreads = 2;
|
||||
auto orgi = OmpSetNumThreads(&nthreads);
|
||||
ASSERT_EQ(omp_get_max_threads(), 2);
|
||||
nthreads = 0;
|
||||
OmpSetNumThreads(&nthreads);
|
||||
ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
|
||||
nthreads = 1;
|
||||
OmpSetNumThreads(&nthreads);
|
||||
nthreads = 0;
|
||||
OmpSetNumThreads(&nthreads);
|
||||
ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
|
||||
|
||||
omp_set_num_threads(orgi);
|
||||
}
|
||||
#endif // defined(_OPENMP)
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user