Obtain CSR matrix from DMatrix. (#8269)
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import scipy.sparse
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import scipy.sparse
|
||||
import testing as tm
|
||||
from hypothesis import given, settings, strategies
|
||||
from scipy.sparse import csr_matrix, rand
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
@@ -433,3 +434,22 @@ class TestDMatrix:
|
||||
|
||||
def test_base_margin(self):
|
||||
set_base_margin_info(np.asarray, xgb.DMatrix, "hist")
|
||||
|
||||
@given(
|
||||
strategies.integers(0, 1000),
|
||||
strategies.integers(0, 100),
|
||||
strategies.fractions(0, 1),
|
||||
)
|
||||
@settings(deadline=None, print_blob=True)
|
||||
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
|
||||
if n_samples == 0 or n_features == 0 or sparsity == 1.0:
|
||||
csr = scipy.sparse.csr_matrix(np.empty((0, 0)))
|
||||
else:
|
||||
csr = tm.make_sparse_regression(n_samples, n_features, sparsity, False)[
|
||||
0
|
||||
].astype(np.float32)
|
||||
m = xgb.DMatrix(data=csr)
|
||||
ret = m.get_data()
|
||||
np.testing.assert_equal(csr.indptr, ret.indptr)
|
||||
np.testing.assert_equal(csr.data, ret.data)
|
||||
np.testing.assert_equal(csr.indices, ret.indices)
|
||||
|
||||
Reference in New Issue
Block a user