Support more scipy types. (#9881)

This commit is contained in:
Jiaming Yuan
2023-12-14 18:28:37 +08:00
committed by GitHub
parent cd473c9da3
commit 1aa8c8d9be
3 changed files with 143 additions and 50 deletions

View File

@@ -112,39 +112,6 @@ class TestDMatrix:
with pytest.raises(ValueError):
xgb.DMatrix(data)
def test_csr(self):
indptr = np.array([0, 2, 3, 6])
indices = np.array([0, 2, 2, 0, 1, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.csr_matrix((data, indices, indptr), shape=(3, 3))
dtrain = xgb.DMatrix(X)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3
def test_csc(self):
row = np.array([0, 2, 2, 0, 1, 2])
col = np.array([0, 0, 1, 2, 2, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.csc_matrix((data, (row, col)), shape=(3, 3))
dtrain = xgb.DMatrix(X)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3
indptr = np.array([0, 3, 5])
data = np.array([0, 1, 2, 3, 4])
row_idx = np.array([0, 1, 2, 0, 2])
X = scipy.sparse.csc_matrix((data, row_idx, indptr), shape=(3, 2))
assert tm.predictor_equal(xgb.DMatrix(X.tocsr()), xgb.DMatrix(X))
def test_coo(self):
row = np.array([0, 2, 2, 0, 1, 2])
col = np.array([0, 0, 1, 2, 2, 2])
data = np.array([1, 2, 3, 4, 5, 6])
X = scipy.sparse.coo_matrix((data, (row, col)), shape=(3, 3))
dtrain = xgb.DMatrix(X)
assert dtrain.num_row() == 3
assert dtrain.num_col() == 3
def test_np_view(self):
# Sliced Float32 array
y = np.array([12, 34, 56], np.float32)[::2]