Handle OMP_THREAD_LIMIT. (#7390)

This commit is contained in:
Jiaming Yuan
2021-11-03 15:44:38 +08:00
committed by GitHub
parent e6ab594e14
commit 57a4b4ff64
4 changed files with 93 additions and 5 deletions

View File

@@ -1,6 +1,12 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import subprocess
import xgboost as xgb
import numpy as np
import pytest
import testing as tm
class TestOMP:
@@ -71,3 +77,31 @@ class TestOMP:
assert auc_1 == auc_2 == auc_3
assert np.array_equal(auc_1, auc_2)
assert np.array_equal(auc_1, auc_3)
@pytest.mark.skipif(**tm.no_sklearn())
def test_with_omp_thread_limit(self):
args = [
"python", os.path.join(
tm.PROJECT_ROOT, "tests", "python", "with_omp_limit.py"
)
]
results = []
with tempfile.TemporaryDirectory() as tmpdir:
for i in (1, 2, 16):
path = os.path.join(tmpdir, str(i))
with open(path, "w") as fd:
fd.write("\n")
cp = args.copy()
cp.append(path)
env = os.environ.copy()
env["OMP_THREAD_LIMIT"] = str(i)
status = subprocess.call(cp, env=env)
assert status == 0
with open(path, "r") as fd:
results.append(float(fd.read()))
for auc in results:
np.testing.assert_allclose(auc, results[0])