Fix missing data warning. (#5969)

* Fix data warning.

* Add numpy/scipy test.
This commit is contained in:
Jiaming Yuan 2020-08-05 16:19:12 +08:00 committed by GitHub
parent 8599f87597
commit dde9c5aaff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 3 deletions

View File

@ -15,10 +15,10 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
def _warn_unused_missing(data, missing): def _warn_unused_missing(data, missing):
if (not np.isnan(missing)) or (missing is None): if (missing is not None) and (not np.isnan(missing)):
warnings.warn( warnings.warn(
'`missing` is not used for current input data type:' + '`missing` is not used for current input data type:' +
str(type(data))) str(type(data)), UserWarning)
def _check_complex(data): def _check_complex(data):

View File

@ -3,7 +3,8 @@ import numpy as np
import xgboost as xgb import xgboost as xgb
import unittest import unittest
import scipy.sparse import scipy.sparse
from scipy.sparse import rand import pytest
from scipy.sparse import rand, csr_matrix
rng = np.random.RandomState(1) rng = np.random.RandomState(1)
@ -12,6 +13,29 @@ rng = np.random.RandomState(1994)
class TestDMatrix(unittest.TestCase): class TestDMatrix(unittest.TestCase):
def test_warn_missing(self):
from xgboost import data
with pytest.warns(UserWarning):
data._warn_unused_missing('uri', 4)
with pytest.warns(None) as record:
data._warn_unused_missing('uri', None)
data._warn_unused_missing('uri', np.nan)
assert len(record) == 0
with pytest.warns(None) as record:
x = rng.randn(10, 10)
y = rng.randn(10)
xgb.DMatrix(x, y, missing=4)
assert len(record) == 0
with pytest.warns(UserWarning):
csr = csr_matrix(x)
xgb.DMatrix(csr, y, missing=4)
def test_dmatrix_numpy_init(self): def test_dmatrix_numpy_init(self):
data = np.random.randn(5, 5) data = np.random.randn(5, 5)
dm = xgb.DMatrix(data) dm = xgb.DMatrix(data)