Fix missing data warning. (#5969)
* Fix data warning. * Add numpy/scipy test.
This commit is contained in:
@@ -3,7 +3,8 @@ import numpy as np
|
||||
import xgboost as xgb
|
||||
import unittest
|
||||
import scipy.sparse
|
||||
from scipy.sparse import rand
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
@@ -12,6 +13,29 @@ rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
class TestDMatrix(unittest.TestCase):
|
||||
def test_warn_missing(self):
|
||||
from xgboost import data
|
||||
with pytest.warns(UserWarning):
|
||||
data._warn_unused_missing('uri', 4)
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
data._warn_unused_missing('uri', None)
|
||||
data._warn_unused_missing('uri', np.nan)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(None) as record:
|
||||
x = rng.randn(10, 10)
|
||||
y = rng.randn(10)
|
||||
|
||||
xgb.DMatrix(x, y, missing=4)
|
||||
|
||||
assert len(record) == 0
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
csr = csr_matrix(x)
|
||||
xgb.DMatrix(csr, y, missing=4)
|
||||
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
dm = xgb.DMatrix(data)
|
||||
|
||||
Reference in New Issue
Block a user