diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py
index 0caab3ec1..b29eac795 100644
--- a/python-package/xgboost/data.py
+++ b/python-package/xgboost/data.py
@@ -15,10 +15,10 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
 
 
 def _warn_unused_missing(data, missing):
-    if (not np.isnan(missing)) or (missing is None):
+    if (missing is not None) and (not np.isnan(missing)):
         warnings.warn(
             '`missing` is not used for current input data type:' +
-            str(type(data)))
+            str(type(data)), UserWarning)
 
 
 def _check_complex(data):
diff --git a/tests/python/test_dmatrix.py b/tests/python/test_dmatrix.py
index 8daf7f357..ecf5f6041 100644
--- a/tests/python/test_dmatrix.py
+++ b/tests/python/test_dmatrix.py
@@ -3,7 +3,8 @@ import numpy as np
 import xgboost as xgb
 import unittest
 import scipy.sparse
-from scipy.sparse import rand
+import pytest
+from scipy.sparse import rand, csr_matrix
 
 rng = np.random.RandomState(1)
 
@@ -12,6 +13,29 @@ rng = np.random.RandomState(1994)
 
 
 class TestDMatrix(unittest.TestCase):
+    def test_warn_missing(self):
+        from xgboost import data
+        with pytest.warns(UserWarning):
+            data._warn_unused_missing('uri', 4)
+
+        with pytest.warns(None) as record:
+            data._warn_unused_missing('uri', None)
+            data._warn_unused_missing('uri', np.nan)
+
+            assert len(record) == 0
+
+        with pytest.warns(None) as record:
+            x = rng.randn(10, 10)
+            y = rng.randn(10)
+
+            xgb.DMatrix(x, y, missing=4)
+
+            assert len(record) == 0
+
+        with pytest.warns(UserWarning):
+            csr = csr_matrix(x)
+            xgb.DMatrix(csr, y, missing=4)
+
     def test_dmatrix_numpy_init(self):
         data = np.random.randn(5, 5)
         dm = xgb.DMatrix(data)