Turn warning messages into Python warnings. (#9387)

This commit is contained in:
Jiaming Yuan 2023-07-15 07:46:43 +08:00 committed by GitHub
parent 04aff3af8e
commit 9da5050643
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 16 deletions

View File

@ -153,7 +153,11 @@ def _expect(expectations: Sequence[Type], got: Type) -> str:
def _log_callback(msg: bytes) -> None: def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python console""" """Redirect logs from native library into Python console"""
print(py_str(msg)) smsg = py_str(msg)
if smsg.find("WARNING:") != -1:
warnings.warn(smsg, UserWarning)
return
print(smsg)
def _get_log_callback_func() -> Callable: def _get_log_callback_func() -> Callable:

View File

@ -61,9 +61,7 @@ class TestLoadPickle:
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
X = rng.randn(10, 10) X = rng.randn(10, 10)
y = rng.randn(10) y = rng.randn(10)
with tm.captured_output() as (out, err): with pytest.warns(UserWarning, match="No visible GPU is found"):
# Test no thrust exception is thrown # Test no thrust exception is thrown
with pytest.raises(xgb.core.XGBoostError): with pytest.raises(xgb.core.XGBoostError, match="have at least one device"):
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y)) xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))
assert out.getvalue().find("No visible GPU is found") != -1

View File

@ -3,6 +3,7 @@ import os
import pickle import pickle
import random import random
import tempfile import tempfile
import warnings
from typing import Callable, Optional from typing import Callable, Optional
import numpy as np import numpy as np
@ -1091,25 +1092,22 @@ def test_constraint_parameters():
) )
@pytest.mark.filterwarnings("error")
def test_parameter_validation(): def test_parameter_validation():
reg = xgb.XGBRegressor(foo='bar', verbosity=1) reg = xgb.XGBRegressor(foo="bar", verbosity=1)
X = np.random.randn(10, 10) X = np.random.randn(10, 10)
y = np.random.randn(10) y = np.random.randn(10)
with tm.captured_output() as (out, err): with pytest.warns(Warning, match="foo"):
reg.fit(X, y) reg.fit(X, y)
output = out.getvalue().strip()
assert output.find('foo') != -1 reg = xgb.XGBRegressor(
n_estimators=2, missing=3, importance_type="gain", verbosity=1
reg = xgb.XGBRegressor(n_estimators=2, missing=3, )
importance_type='gain', verbosity=1)
X = np.random.randn(10, 10) X = np.random.randn(10, 10)
y = np.random.randn(10) y = np.random.randn(10)
with tm.captured_output() as (out, err):
reg.fit(X, y)
output = out.getvalue().strip()
assert len(output) == 0 with warnings.catch_warnings():
reg.fit(X, y)
def test_deprecate_position_arg(): def test_deprecate_position_arg():