Fix missing data warning. (#5969)
* Fix data warning. * Add numpy/scipy test.
This commit is contained in:
parent
8599f87597
commit
dde9c5aaff
@ -15,10 +15,10 @@ c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _warn_unused_missing(data, missing):
|
||||
if (not np.isnan(missing)) or (missing is None):
|
||||
if (missing is not None) and (not np.isnan(missing)):
|
||||
warnings.warn(
|
||||
'`missing` is not used for current input data type:' +
|
||||
str(type(data)))
|
||||
str(type(data)), UserWarning)
|
||||
|
||||
|
||||
def _check_complex(data):
|
||||
|
||||
@ -3,7 +3,8 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import scipy.sparse
|
||||
from scipy.sparse import rand
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
@ -12,6 +13,29 @@ rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
class TestDMatrix(unittest.TestCase):
|
||||
def test_warn_missing(self):
|
||||
from xgboost import data
|
||||
with pytest.warns(UserWarning):
|
||||
data._warn_unused_missing('uri', 4)
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
data._warn_unused_missing('uri', None)
|
||||
data._warn_unused_missing('uri', np.nan)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
x = rng.randn(10, 10)
|
||||
y = rng.randn(10)
|
||||
|
||||
xgb.DMatrix(x, y, missing=4)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
csr = csr_matrix(x)
|
||||
xgb.DMatrix(csr, y, missing=4)
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user