Handle OMP_THREAD_LIMIT. (#7390)
This commit is contained in:
@@ -1,6 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import subprocess
|
||||
|
||||
import xgboost as xgb
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import testing as tm
|
||||
|
||||
|
||||
class TestOMP:
|
||||
@@ -71,3 +77,31 @@ class TestOMP:
|
||||
assert auc_1 == auc_2 == auc_3
|
||||
assert np.array_equal(auc_1, auc_2)
|
||||
assert np.array_equal(auc_1, auc_3)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_with_omp_thread_limit(self):
|
||||
args = [
|
||||
"python", os.path.join(
|
||||
tm.PROJECT_ROOT, "tests", "python", "with_omp_limit.py"
|
||||
)
|
||||
]
|
||||
results = []
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
for i in (1, 2, 16):
|
||||
path = os.path.join(tmpdir, str(i))
|
||||
with open(path, "w") as fd:
|
||||
fd.write("\n")
|
||||
cp = args.copy()
|
||||
cp.append(path)
|
||||
|
||||
env = os.environ.copy()
|
||||
env["OMP_THREAD_LIMIT"] = str(i)
|
||||
|
||||
status = subprocess.call(cp, env=env)
|
||||
assert status == 0
|
||||
|
||||
with open(path, "r") as fd:
|
||||
results.append(float(fd.read()))
|
||||
|
||||
for auc in results:
|
||||
np.testing.assert_allclose(auc, results[0])
|
||||
|
||||
Reference in New Issue
Block a user