Support half type for pandas. (#8481)

This commit is contained in:
Jiaming Yuan
2022-11-24 12:47:40 +08:00
committed by GitHub
parent e07245f110
commit 8f97c92541
5 changed files with 109 additions and 53 deletions

View File

@@ -6,6 +6,7 @@ import pytest
import scipy.sparse
from hypothesis import given, settings, strategies
from scipy.sparse import csr_matrix, rand
from xgboost.testing.data import np_dtypes
import xgboost as xgb
from xgboost import testing as tm
@@ -453,3 +454,15 @@ class TestDMatrix:
np.testing.assert_equal(csr.indptr, ret.indptr)
np.testing.assert_equal(csr.data, ret.data)
np.testing.assert_equal(csr.indices, ret.indices)
def test_dtypes(self) -> None:
n_samples = 128
n_features = 16
for orig, x in np_dtypes(n_samples, n_features):
m0 = xgb.DMatrix(orig)
m1 = xgb.DMatrix(x)
csr0 = m0.get_data()
csr1 = m1.get_data()
np.testing.assert_allclose(csr0.data, csr1.data)
np.testing.assert_allclose(csr0.indptr, csr1.indptr)
np.testing.assert_allclose(csr0.indices, csr1.indices)