Handle OMP_THREAD_LIMIT. (#7390) (#7391)

This commit is contained in:
Jiaming Yuan 2021-11-03 20:25:51 +08:00 committed by GitHub
parent fab3c05ced
commit a3d195e73e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 93 additions and 5 deletions

View File

@ -9,8 +9,8 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
early_stopping_rounds = NULL, maximize = NULL, early_stopping_rounds = NULL, maximize = NULL,
save_period = NULL, save_name = "xgboost.model", save_period = NULL, save_name = "xgboost.model",
xgb_model = NULL, callbacks = list(), ...) { xgb_model = NULL, callbacks = list(), ...) {
merged <- check.booster.params(params, ...)
dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = params$nthread) dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = merged$nthread)
watchlist <- list(train = dtrain) watchlist <- list(train = dtrain)

View File

@ -7,11 +7,28 @@
#define XGBOOST_COMMON_THREADING_UTILS_H_ #define XGBOOST_COMMON_THREADING_UTILS_H_
#include <dmlc/common.h> #include <dmlc/common.h>
#include <vector> #include <dmlc/omp.h>
#include <algorithm> #include <algorithm>
#include <limits>
#include <type_traits> // std::is_signed #include <type_traits> // std::is_signed
#include <vector>
#include "xgboost/logging.h" #include "xgboost/logging.h"
// Fallback definitions of omp_get_thread_limit() for builds where the
// toolchain does not supply one. At most one branch below is compiled.
#if !defined(_OPENMP)
// _OPENMP is not defined: report a limit of a single thread.
extern "C" {
inline int32_t omp_get_thread_limit() __GOMP_NOTHROW {  // NOLINT
  return 1;
}
}  // extern "C"
#elif defined(_MSC_VER)
// MSVC doesn't implement the thread limit, so report the largest value an
// int32_t can hold.
extern "C" {
inline int32_t omp_get_thread_limit() {  // NOLINT
  return std::numeric_limits<int32_t>::max();
}
}  // extern "C"
#endif  // !defined(_OPENMP) / defined(_MSC_VER)
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -153,7 +170,7 @@ struct Sched {
}; };
template <typename Index, typename Func> template <typename Index, typename Func>
void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) { void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
#if defined(_MSC_VER) #if defined(_MSC_VER)
// msvc doesn't support unsigned integer as openmp index. // msvc doesn't support unsigned integer as openmp index.
using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>; using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
@ -220,6 +237,13 @@ void ParallelFor(Index size, size_t n_threads, Func fn) {
/* \brief Convenience overload of ParallelFor.
 *
 * Forwards `size` and `fn` to the overload that takes a thread count and a
 * schedule, passing omp_get_max_threads() as the thread count and
 * Sched::Static() as the schedule.
 *
 * \param size Forwarded unchanged as the first argument.
 * \param fn   Forwarded unchanged as the last argument.
 */
template <typename Index, typename Func> template <typename Index, typename Func>
void ParallelFor(Index size, Func fn) { void ParallelFor(Index size, Func fn) {
ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn); ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
}
/* \brief Query the OpenMP thread limit via omp_get_thread_limit().
 *
 * \return The value reported by omp_get_thread_limit(), after it has been
 *         checked with CHECK_GE to be at least 1.
 */
inline int32_t OmpGetThreadLimit() {
int32_t limit = omp_get_thread_limit();
// A limit below one thread is treated as invalid.
CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
return limit;
} }
/* \brief Configure parallel threads. /* \brief Configure parallel threads.
@ -235,15 +259,18 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
if (threads <= 0) { if (threads <= 0) {
threads = omp_get_num_procs(); threads = omp_get_num_procs();
} }
threads = std::min(threads, OmpGetThreadLimit());
omp_set_num_threads(threads); omp_set_num_threads(threads);
return nthread_original; return nthread_original;
} }
/* \brief Set the OpenMP thread count from the value behind `p_threads`.
 *
 * \param p_threads In/out. The requested thread count; the value actually
 *                  passed to omp_set_num_threads() is written back through
 *                  this pointer.
 *
 * \return omp_get_max_threads() as observed before the thread count was
 *         changed.
 */
inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
auto& threads = *p_threads; auto& threads = *p_threads;
int32_t nthread_original = omp_get_max_threads(); int32_t nthread_original = omp_get_max_threads();
// A non-positive request falls back to the current maximum.
if (threads <= 0) { if (threads <= 0) {
threads = nthread_original; threads = nthread_original;
} }
// Never ask for more threads than OmpGetThreadLimit() reports.
threads = std::min(threads, OmpGetThreadLimit());
omp_set_num_threads(threads); omp_set_num_threads(threads);
return nthread_original; return nthread_original;
} }
@ -252,6 +279,7 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
if (n_threads <= 0) { if (n_threads <= 0) {
n_threads = omp_get_num_procs(); n_threads = omp_get_num_procs();
} }
n_threads = std::min(n_threads, OmpGetThreadLimit());
return n_threads; return n_threads;
} }
} // namespace common } // namespace common

View File

@ -1,6 +1,12 @@
# -*- coding: utf-8 -*- import os
import tempfile
import subprocess
import xgboost as xgb import xgboost as xgb
import numpy as np import numpy as np
import pytest
import testing as tm
class TestOMP: class TestOMP:
@ -71,3 +77,31 @@ class TestOMP:
assert auc_1 == auc_2 == auc_3 assert auc_1 == auc_2 == auc_3
assert np.array_equal(auc_1, auc_2) assert np.array_equal(auc_1, auc_2)
assert np.array_equal(auc_1, auc_3) assert np.array_equal(auc_1, auc_3)
@pytest.mark.skipif(**tm.no_sklearn())
def test_with_omp_thread_limit(self):
    """Check that the AUC does not depend on ``OMP_THREAD_LIMIT``.

    ``tests/python/with_omp_limit.py`` is run in a child process once for
    each limit in ``(1, 2, 16)``. Each run writes a single float to a
    temporary file; every collected value must match the first one.
    """
    # Imported locally: only needed to locate the running interpreter.
    import sys

    script = os.path.join(
        tm.PROJECT_ROOT, "tests", "python", "with_omp_limit.py"
    )
    results = []
    with tempfile.TemporaryDirectory() as tmpdir:
        for limit in (1, 2, 16):
            path = os.path.join(tmpdir, str(limit))
            # Pre-fill the file with a non-numeric placeholder: if the child
            # exits without writing a score, float() below raises.
            with open(path, "w") as fd:
                fd.write("\n")

            env = os.environ.copy()
            env["OMP_THREAD_LIMIT"] = str(limit)
            # Launch the child with the interpreter that is running the
            # tests; a bare "python" may resolve to a different installation
            # or not exist on PATH at all.
            status = subprocess.call([sys.executable, script, path], env=env)
            assert status == 0

            with open(path, "r") as fd:
                results.append(float(fd.read()))

    for auc in results:
        np.testing.assert_allclose(auc, results[0])

View File

@ -0,0 +1,26 @@
import os
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
import sys
def run_omp(output_path: str) -> None:
    """Train a small multi-class booster and write its AUC to ``output_path``.

    A 3-class dataset with 200 samples and 32 features is generated, a
    ``multi:softprob`` booster is trained on it for 8 rounds with 16 threads
    requested, and the weighted one-vs-rest ROC AUC of its predictions on the
    same data is written to ``output_path`` as text.

    Parameters
    ----------
    output_path :
        File that receives ``str(auc)``; it is opened for writing, so any
        existing content is replaced.
    """
    # Seed the generator: the OMP thread-limit test runs this script in
    # several processes and compares the scores, which is only meaningful
    # when every process trains on the same data.
    X, y = make_classification(
        n_samples=200,
        n_features=32,
        n_classes=3,
        n_informative=8,
        random_state=1994,
    )
    Xy = xgb.DMatrix(X, y, nthread=16)
    booster = xgb.train(
        {"num_class": 3, "objective": "multi:softprob", "n_jobs": 16},
        Xy,
        num_boost_round=8,
    )
    score = booster.predict(Xy)
    auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
    with open(output_path, "w") as fd:
        fd.write(str(auc))
if __name__ == "__main__":
    # The first command-line argument is the file that receives the score.
    run_omp(sys.argv[1])