Add back support for scipy.sparse.coo_matrix (#6162)
This commit is contained in:
parent
72ef553550
commit
bd2b1eabd0
@ -82,6 +82,15 @@ def _from_scipy_csc(data, missing, feature_names, feature_types):
|
|||||||
return handle, feature_names, feature_types
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
|
def _is_scipy_coo(data):
|
||||||
|
try:
|
||||||
|
import scipy
|
||||||
|
except ImportError:
|
||||||
|
scipy = None
|
||||||
|
return False
|
||||||
|
return isinstance(data, scipy.sparse.coo_matrix)
|
||||||
|
|
||||||
|
|
||||||
def _is_numpy_array(data):
|
def _is_numpy_array(data):
|
||||||
return isinstance(data, (np.ndarray, np.matrix))
|
return isinstance(data, (np.ndarray, np.matrix))
|
||||||
|
|
||||||
@ -504,6 +513,8 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||||
if _is_scipy_csc(data):
|
if _is_scipy_csc(data):
|
||||||
return _from_scipy_csc(data, missing, feature_names, feature_types)
|
return _from_scipy_csc(data, missing, feature_names, feature_types)
|
||||||
|
if _is_scipy_coo(data):
|
||||||
|
return _from_scipy_csr(data.tocsr(), missing, feature_names, feature_types)
|
||||||
if _is_numpy_array(data):
|
if _is_numpy_array(data):
|
||||||
return _from_numpy_array(data, missing, threads, feature_names,
|
return _from_numpy_array(data, missing, threads, feature_names,
|
||||||
feature_types)
|
feature_types)
|
||||||
|
|||||||
@ -76,6 +76,15 @@ class TestDMatrix(unittest.TestCase):
|
|||||||
assert dtrain.num_row() == 3
|
assert dtrain.num_row() == 3
|
||||||
assert dtrain.num_col() == 3
|
assert dtrain.num_col() == 3
|
||||||
|
|
||||||
|
def test_coo(self):
|
||||||
|
row = np.array([0, 2, 2, 0, 1, 2])
|
||||||
|
col = np.array([0, 0, 1, 2, 2, 2])
|
||||||
|
data = np.array([1, 2, 3, 4, 5, 6])
|
||||||
|
X = scipy.sparse.coo_matrix((data, (row, col)), shape=(3, 3))
|
||||||
|
dtrain = xgb.DMatrix(X)
|
||||||
|
assert dtrain.num_row() == 3
|
||||||
|
assert dtrain.num_col() == 3
|
||||||
|
|
||||||
def test_np_view(self):
|
def test_np_view(self):
|
||||||
# Sliced Float32 array
|
# Sliced Float32 array
|
||||||
y = np.array([12, 34, 56], np.float32)[::2]
|
y = np.array([12, 34, 56], np.float32)[::2]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user