Turn warning messages into Python warnings. (#9387)
This commit is contained in:
parent
04aff3af8e
commit
9da5050643
@ -153,7 +153,11 @@ def _expect(expectations: Sequence[Type], got: Type) -> str:
|
|||||||
|
|
||||||
def _log_callback(msg: bytes) -> None:
|
def _log_callback(msg: bytes) -> None:
|
||||||
"""Redirect logs from native library into Python console"""
|
"""Redirect logs from native library into Python console"""
|
||||||
print(py_str(msg))
|
smsg = py_str(msg)
|
||||||
|
if smsg.find("WARNING:") != -1:
|
||||||
|
warnings.warn(smsg, UserWarning)
|
||||||
|
return
|
||||||
|
print(smsg)
|
||||||
|
|
||||||
|
|
||||||
def _get_log_callback_func() -> Callable:
|
def _get_log_callback_func() -> Callable:
|
||||||
|
|||||||
@ -61,9 +61,7 @@ class TestLoadPickle:
|
|||||||
rng = np.random.RandomState(1994)
|
rng = np.random.RandomState(1994)
|
||||||
X = rng.randn(10, 10)
|
X = rng.randn(10, 10)
|
||||||
y = rng.randn(10)
|
y = rng.randn(10)
|
||||||
with tm.captured_output() as (out, err):
|
with pytest.warns(UserWarning, match="No visible GPU is found"):
|
||||||
# Test no thrust exception is thrown
|
# Test no thrust exception is thrown
|
||||||
with pytest.raises(xgb.core.XGBoostError):
|
with pytest.raises(xgb.core.XGBoostError, match="have at least one device"):
|
||||||
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))
|
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))
|
||||||
|
|
||||||
assert out.getvalue().find("No visible GPU is found") != -1
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import warnings
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -1091,25 +1092,22 @@ def test_constraint_parameters():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("error")
|
||||||
def test_parameter_validation():
|
def test_parameter_validation():
|
||||||
reg = xgb.XGBRegressor(foo='bar', verbosity=1)
|
reg = xgb.XGBRegressor(foo="bar", verbosity=1)
|
||||||
X = np.random.randn(10, 10)
|
X = np.random.randn(10, 10)
|
||||||
y = np.random.randn(10)
|
y = np.random.randn(10)
|
||||||
with tm.captured_output() as (out, err):
|
with pytest.warns(Warning, match="foo"):
|
||||||
reg.fit(X, y)
|
reg.fit(X, y)
|
||||||
output = out.getvalue().strip()
|
|
||||||
|
|
||||||
assert output.find('foo') != -1
|
reg = xgb.XGBRegressor(
|
||||||
|
n_estimators=2, missing=3, importance_type="gain", verbosity=1
|
||||||
reg = xgb.XGBRegressor(n_estimators=2, missing=3,
|
)
|
||||||
importance_type='gain', verbosity=1)
|
|
||||||
X = np.random.randn(10, 10)
|
X = np.random.randn(10, 10)
|
||||||
y = np.random.randn(10)
|
y = np.random.randn(10)
|
||||||
with tm.captured_output() as (out, err):
|
|
||||||
reg.fit(X, y)
|
|
||||||
output = out.getvalue().strip()
|
|
||||||
|
|
||||||
assert len(output) == 0
|
with warnings.catch_warnings():
|
||||||
|
reg.fit(X, y)
|
||||||
|
|
||||||
|
|
||||||
def test_deprecate_position_arg():
|
def test_deprecate_position_arg():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user