Turn warning messages into Python warnings. (#9387)

This commit is contained in:
Jiaming Yuan 2023-07-15 07:46:43 +08:00 committed by GitHub
parent 04aff3af8e
commit 9da5050643
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 16 deletions

View File

@ -153,7 +153,11 @@ def _expect(expectations: Sequence[Type], got: Type) -> str:
def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python console"""
print(py_str(msg))
smsg = py_str(msg)
if smsg.find("WARNING:") != -1:
warnings.warn(smsg, UserWarning)
return
print(smsg)
def _get_log_callback_func() -> Callable:

View File

@ -61,9 +61,7 @@ class TestLoadPickle:
rng = np.random.RandomState(1994)
X = rng.randn(10, 10)
y = rng.randn(10)
with tm.captured_output() as (out, err):
with pytest.warns(UserWarning, match="No visible GPU is found"):
# Test no thrust exception is thrown
with pytest.raises(xgb.core.XGBoostError):
with pytest.raises(xgb.core.XGBoostError, match="have at least one device"):
xgb.train({"tree_method": "gpu_hist"}, xgb.DMatrix(X, y))
assert out.getvalue().find("No visible GPU is found") != -1

View File

@ -3,6 +3,7 @@ import os
import pickle
import random
import tempfile
import warnings
from typing import Callable, Optional
import numpy as np
@ -1091,25 +1092,22 @@ def test_constraint_parameters():
)
@pytest.mark.filterwarnings("error")
def test_parameter_validation():
reg = xgb.XGBRegressor(foo='bar', verbosity=1)
reg = xgb.XGBRegressor(foo="bar", verbosity=1)
X = np.random.randn(10, 10)
y = np.random.randn(10)
with tm.captured_output() as (out, err):
with pytest.warns(Warning, match="foo"):
reg.fit(X, y)
output = out.getvalue().strip()
assert output.find('foo') != -1
reg = xgb.XGBRegressor(n_estimators=2, missing=3,
importance_type='gain', verbosity=1)
reg = xgb.XGBRegressor(
n_estimators=2, missing=3, importance_type="gain", verbosity=1
)
X = np.random.randn(10, 10)
y = np.random.randn(10)
with tm.captured_output() as (out, err):
reg.fit(X, y)
output = out.getvalue().strip()
assert len(output) == 0
with warnings.catch_warnings():
reg.fit(X, y)
def test_deprecate_position_arg():