parent fab3c05ced
commit a3d195e73e
@@ -9,8 +9,8 @@ xgboost <- function(data = NULL, label = NULL, missing = NA, weight = NULL,
                     early_stopping_rounds = NULL, maximize = NULL,
                     save_period = NULL, save_name = "xgboost.model",
                     xgb_model = NULL, callbacks = list(), ...) {
   merged <- check.booster.params(params, ...)
-  dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = params$nthread)
+  dtrain <- xgb.get.DMatrix(data, label, missing, weight, nthread = merged$nthread)
 
   watchlist <- list(train = dtrain)
 
@@ -7,11 +7,28 @@
 #define XGBOOST_COMMON_THREADING_UTILS_H_
 
 #include <dmlc/common.h>
-#include <vector>
+#include <dmlc/omp.h>
+
 #include <algorithm>
+#include <limits>
 #include <type_traits>  // std::is_signed
+#include <vector>
 
 #include "xgboost/logging.h"
 
+#if !defined(_OPENMP)
+extern "C" {
+inline int32_t omp_get_thread_limit() __GOMP_NOTHROW { return 1; }  // NOLINT
+}
+#endif  // !defined(_OPENMP)
+
+// MSVC doesn't implement the thread limit.
+#if defined(_OPENMP) && defined(_MSC_VER)
+extern "C" {
+inline int32_t omp_get_thread_limit() { return std::numeric_limits<int32_t>::max(); }  // NOLINT
+}
+#endif  // defined(_MSC_VER)
+
 namespace xgboost {
 namespace common {
 
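The two extern "C" shims above let the header call omp_get_thread_limit() unconditionally: a build without OpenMP reports a limit of one thread, and an MSVC OpenMP build, whose runtime predates omp_get_thread_limit(), reports an effectively unbounded limit. A minimal standalone sketch of the same fallback idea (illustrative only, not part of the patch; the MSVC branch is omitted):

    #include <cstdint>
    #include <iostream>

    #if defined(_OPENMP)
    #include <omp.h>  // real runtime: omp_get_thread_limit() honours OMP_THREAD_LIMIT
    #else
    // Serial build: behave as if only a single thread were allowed.
    inline int32_t omp_get_thread_limit() { return 1; }
    inline int32_t omp_get_max_threads() { return 1; }
    #endif

    int main() {
      std::cout << "thread limit: " << omp_get_thread_limit() << '\n'
                << "max threads:  " << omp_get_max_threads() << '\n';
    }

Built with -fopenmp and run under OMP_THREAD_LIMIT=2 the first line prints 2; built without OpenMP both stubs report 1.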
@@ -153,7 +170,7 @@ struct Sched {
 };
 
 template <typename Index, typename Func>
-void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) {
+void ParallelFor(Index size, int32_t n_threads, Sched sched, Func fn) {
 #if defined(_MSC_VER)
   // msvc doesn't support unsigned integer as openmp index.
   using OmpInd = std::conditional_t<std::is_signed<Index>::value, Index, omp_ulong>;
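This hunk narrows the n_threads parameter from size_t to int32_t. OpenMP expresses thread counts as plain int, so keeping the value signed and 32-bit lets it feed a num_threads clause and be compared against omp_get_thread_limit() without narrowing. A stripped-down sketch of such a signature (SimpleParallelFor is a hypothetical name, not xgboost's actual implementation, which also handles MSVC index types and scheduling):

    #include <cstdint>
    #include <omp.h>

    // Hypothetical analogue of ParallelFor: the thread count is a signed
    // 32-bit value so it can go straight into the num_threads clause.
    template <typename Index, typename Func>
    void SimpleParallelFor(Index size, int32_t n_threads, Func fn) {
      int64_t n = static_cast<int64_t>(size);
    #pragma omp parallel for num_threads(n_threads) schedule(static)
      for (int64_t i = 0; i < n; ++i) {
        fn(static_cast<Index>(i));
      }
    }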
@@ -220,6 +237,13 @@ void ParallelFor(Index size, size_t n_threads, Func fn) {
 template <typename Index, typename Func>
 void ParallelFor(Index size, Func fn) {
   ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn);
+}  // !defined(_OPENMP)
+
+
+inline int32_t OmpGetThreadLimit() {
+  int32_t limit = omp_get_thread_limit();
+  CHECK_GE(limit, 1) << "Invalid thread limit for OpenMP.";
+  return limit;
 }
 
 /* \brief Configure parallel threads.
@@ -235,15 +259,18 @@ inline int32_t OmpSetNumThreads(int32_t* p_threads) {
   if (threads <= 0) {
     threads = omp_get_num_procs();
   }
+  threads = std::min(threads, OmpGetThreadLimit());
   omp_set_num_threads(threads);
   return nthread_original;
 }
 
 inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) {
   auto& threads = *p_threads;
   int32_t nthread_original = omp_get_max_threads();
   if (threads <= 0) {
     threads = nthread_original;
   }
+  threads = std::min(threads, OmpGetThreadLimit());
   omp_set_num_threads(threads);
   return nthread_original;
 }
@@ -252,6 +279,7 @@ inline int32_t OmpGetNumThreads(int32_t n_threads) {
   if (n_threads <= 0) {
     n_threads = omp_get_num_procs();
   }
+  n_threads = std::min(n_threads, OmpGetThreadLimit());
   return n_threads;
 }
 }  // namespace common
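With OmpGetThreadLimit() in place, the three helpers above (OmpSetNumThreads, OmpSetNumThreadsWithoutHT, OmpGetNumThreads) all clamp the requested count with std::min, so neither a user-supplied nthread nor the omp_get_num_procs() default can exceed OMP_THREAD_LIMIT. A minimal sketch of that clamp-then-set pattern (ClampAndSetThreads is an illustrative name; unlike the patch's helpers it returns the clamped value rather than the previous setting):

    #include <algorithm>
    #include <cstdint>
    #include <omp.h>

    // Illustrative only: request <= 0 means "use every core", and the result
    // is always capped by the OpenMP thread limit before being applied.
    inline int32_t ClampAndSetThreads(int32_t requested) {
      if (requested <= 0) {
        requested = omp_get_num_procs();
      }
      requested = std::min<int32_t>(requested, omp_get_thread_limit());
      omp_set_num_threads(requested);
      return requested;
    }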
@@ -1,6 +1,12 @@
-# -*- coding: utf-8 -*-
+import os
+import tempfile
+import subprocess
+
 import xgboost as xgb
 import numpy as np
+import pytest
+
+import testing as tm
 
 
 class TestOMP:
@@ -71,3 +77,31 @@ class TestOMP:
         assert auc_1 == auc_2 == auc_3
         assert np.array_equal(auc_1, auc_2)
         assert np.array_equal(auc_1, auc_3)
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_with_omp_thread_limit(self):
+        args = [
+            "python", os.path.join(
+                tm.PROJECT_ROOT, "tests", "python", "with_omp_limit.py"
+            )
+        ]
+        results = []
+        with tempfile.TemporaryDirectory() as tmpdir:
+            for i in (1, 2, 16):
+                path = os.path.join(tmpdir, str(i))
+                with open(path, "w") as fd:
+                    fd.write("\n")
+                cp = args.copy()
+                cp.append(path)
+
+                env = os.environ.copy()
+                env["OMP_THREAD_LIMIT"] = str(i)
+
+                status = subprocess.call(cp, env=env)
+                assert status == 0
+
+                with open(path, "r") as fd:
+                    results.append(float(fd.read()))
+
+        for auc in results:
+            np.testing.assert_allclose(auc, results[0])
tests/python/with_omp_limit.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import os
+import xgboost as xgb
+from sklearn.datasets import make_classification
+from sklearn.metrics import roc_auc_score
+import sys
+
+
+def run_omp(output_path: str):
+    X, y = make_classification(
+        n_samples=200, n_features=32, n_classes=3, n_informative=8
+    )
+    Xy = xgb.DMatrix(X, y, nthread=16)
+    booster = xgb.train(
+        {"num_class": 3, "objective": "multi:softprob", "n_jobs": 16},
+        Xy,
+        num_boost_round=8,
+    )
+    score = booster.predict(Xy)
+    auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
+    with open(output_path, "w") as fd:
+        fd.write(str(auc))
+
+
+if __name__ == "__main__":
+    out = sys.argv[1]
+    run_omp(out)
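with_omp_limit.py trains a small multi-class model with nthread/n_jobs set to 16 and writes the resulting AUC to a file; the test_with_omp_thread_limit test above runs it under OMP_THREAD_LIMIT values of 1, 2, and 16 and checks with np.testing.assert_allclose that the scores agree, i.e. the limit only throttles parallelism. The same environment variable can be observed directly from an OpenMP program; a tiny probe (not part of the patch):

    #include <iostream>
    #include <omp.h>

    int main() {
      // With OMP_THREAD_LIMIT=2 exported, the limit below reads 2 and every
      // parallel region is capped at a team of 2, whatever OMP_NUM_THREADS says.
      std::cout << "omp_get_thread_limit() = " << omp_get_thread_limit() << '\n';
    #pragma omp parallel
      {
    #pragma omp single
        std::cout << "team size = " << omp_get_num_threads() << '\n';
      }
    }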