Unify thread configuration. (#6186)
This commit is contained in:
parent
7f6ed5780c
commit
ddf37cca30
@ -144,6 +144,22 @@ void ParallelFor(size_t size, size_t nthreads, Func fn) {
|
|||||||
omp_exc.Rethrow();
|
omp_exc.Rethrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* \brief Configure parallel threads.
|
||||||
|
*
|
||||||
|
* \param p_threads Number of threads, when it's less than or equal to 0, this function
|
||||||
|
* will change it to number of process on system.
|
||||||
|
*
|
||||||
|
* \return Global openmp max threads before configuration.
|
||||||
|
*/
|
||||||
|
inline int32_t OmpSetNumThreads(int32_t* p_threads) {
|
||||||
|
auto& threads = *p_threads;
|
||||||
|
int32_t nthread_original = omp_get_max_threads();
|
||||||
|
if (threads <= 0) {
|
||||||
|
threads = omp_get_num_procs();
|
||||||
|
}
|
||||||
|
omp_set_num_threads(threads);
|
||||||
|
return nthread_original;
|
||||||
|
}
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|||||||
@ -19,6 +19,7 @@
|
|||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
#include "../common/version.h"
|
#include "../common/version.h"
|
||||||
#include "../common/group_data.h"
|
#include "../common/group_data.h"
|
||||||
|
#include "../common/threading_utils.h"
|
||||||
#include "../data/adapter.h"
|
#include "../data/adapter.h"
|
||||||
#include "../data/iterative_device_dmatrix.h"
|
#include "../data/iterative_device_dmatrix.h"
|
||||||
|
|
||||||
@ -843,10 +844,7 @@ void SparsePage::Push(const SparsePage &batch) {
|
|||||||
template <typename AdapterBatchT>
|
template <typename AdapterBatchT>
|
||||||
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
|
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
|
||||||
// Set number of threads but keep old value so we can reset it after
|
// Set number of threads but keep old value so we can reset it after
|
||||||
const int nthreadmax = omp_get_max_threads();
|
int nthread_original = common::OmpSetNumThreads(&nthread);
|
||||||
if (nthread <= 0) nthread = nthreadmax;
|
|
||||||
const int nthread_original = omp_get_max_threads();
|
|
||||||
omp_set_num_threads(nthread);
|
|
||||||
auto& offset_vec = offset.HostVector();
|
auto& offset_vec = offset.HostVector();
|
||||||
auto& data_vec = data.HostVector();
|
auto& data_vec = data.HostVector();
|
||||||
|
|
||||||
@ -865,7 +863,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
size_t batch_size = batch.Size();
|
size_t batch_size = batch.Size();
|
||||||
const size_t thread_size = batch_size/nthread;
|
const size_t thread_size = batch_size / nthread;
|
||||||
builder.InitBudget(expected_rows+1, nthread);
|
builder.InitBudget(expected_rows+1, nthread);
|
||||||
uint64_t max_columns = 0;
|
uint64_t max_columns = 0;
|
||||||
if (batch_size == 0) {
|
if (batch_size == 0) {
|
||||||
|
|||||||
@ -15,6 +15,7 @@
|
|||||||
#include "simple_dmatrix.h"
|
#include "simple_dmatrix.h"
|
||||||
#include "./simple_batch_iterator.h"
|
#include "./simple_batch_iterator.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
|
#include "../common/threading_utils.h"
|
||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -92,10 +93,7 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
|
|||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
|
||||||
// Set number of threads but keep old value so we can reset it after
|
// Set number of threads but keep old value so we can reset it after
|
||||||
const int nthreadmax = omp_get_max_threads();
|
int nthread_original = common::OmpSetNumThreads(&nthread);
|
||||||
if (nthread <= 0) nthread = nthreadmax;
|
|
||||||
int nthread_original = omp_get_max_threads();
|
|
||||||
omp_set_num_threads(nthread);
|
|
||||||
|
|
||||||
std::vector<uint64_t> qids;
|
std::vector<uint64_t> qids;
|
||||||
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||||
|
|||||||
@ -43,6 +43,7 @@
|
|||||||
#include "common/timer.h"
|
#include "common/timer.h"
|
||||||
#include "common/charconv.h"
|
#include "common/charconv.h"
|
||||||
#include "common/version.h"
|
#include "common/version.h"
|
||||||
|
#include "common/threading_utils.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -287,9 +288,7 @@ class LearnerConfiguration : public Learner {
|
|||||||
generic_parameters_.CheckDeprecated();
|
generic_parameters_.CheckDeprecated();
|
||||||
|
|
||||||
ConsoleLogger::Configure(args);
|
ConsoleLogger::Configure(args);
|
||||||
if (generic_parameters_.nthread != 0) {
|
common::OmpSetNumThreads(&generic_parameters_.nthread);
|
||||||
omp_set_num_threads(generic_parameters_.nthread);
|
|
||||||
}
|
|
||||||
|
|
||||||
// add additional parameters
|
// add additional parameters
|
||||||
// These are cosntraints that need to be satisfied.
|
// These are cosntraints that need to be satisfied.
|
||||||
|
|||||||
@ -88,6 +88,22 @@ TEST(ParallelFor2dNonUniform, Test) {
|
|||||||
|
|
||||||
omp_set_num_threads(old);
|
omp_set_num_threads(old);
|
||||||
}
|
}
|
||||||
|
#if defined(_OPENMP)
|
||||||
|
TEST(OmpSetNumThreads, Basic) {
|
||||||
|
auto nthreads = 2;
|
||||||
|
auto orgi = OmpSetNumThreads(&nthreads);
|
||||||
|
ASSERT_EQ(omp_get_max_threads(), 2);
|
||||||
|
nthreads = 0;
|
||||||
|
OmpSetNumThreads(&nthreads);
|
||||||
|
ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
|
||||||
|
nthreads = 1;
|
||||||
|
OmpSetNumThreads(&nthreads);
|
||||||
|
nthreads = 0;
|
||||||
|
OmpSetNumThreads(&nthreads);
|
||||||
|
ASSERT_EQ(omp_get_max_threads(), omp_get_num_procs());
|
||||||
|
|
||||||
|
omp_set_num_threads(orgi);
|
||||||
|
}
|
||||||
|
#endif // defined(_OPENMP)
|
||||||
} // namespace common
|
} // namespace common
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user