Obtain CSR matrix from DMatrix. (#8269)

This commit is contained in:
Jiaming Yuan
2022-09-29 20:41:43 +08:00
committed by GitHub
parent b14c44ee5e
commit 55cf24cc32
22 changed files with 400 additions and 74 deletions

View File

@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np
import xgboost as xgb
import scipy.sparse
import pytest
from scipy.sparse import rand, csr_matrix
import numpy as np
import pytest
import scipy.sparse
import testing as tm
from hypothesis import given, settings, strategies
from scipy.sparse import csr_matrix, rand
import xgboost as xgb
rng = np.random.RandomState(1)
@@ -433,3 +434,22 @@ class TestDMatrix:
def test_base_margin(self):
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
@given(
strategies.integers(0, 1000),
strategies.integers(0, 100),
strategies.fractions(0, 1),
)
@settings(deadline=None, print_blob=True)
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
if n_samples == 0 or n_features == 0 or sparsity == 1.0:
csr = scipy.sparse.csr_matrix(np.empty((0, 0)))
else:
csr = tm.make_sparse_regression(n_samples, n_features, sparsity, False)[
0
].astype(np.float32)
m = xgb.DMatrix(data=csr)
ret = m.get_data()
np.testing.assert_equal(csr.indptr, ret.indptr)
np.testing.assert_equal(csr.data, ret.data)
np.testing.assert_equal(csr.indices, ret.indices)