Support half type for pandas. (#8481)

This commit is contained in:
Jiaming Yuan
2022-11-24 12:47:40 +08:00
committed by GitHub
parent e07245f110
commit 8f97c92541
5 changed files with 109 additions and 53 deletions

View File

@@ -11,6 +11,7 @@ from xgboost.testing import (
make_categorical,
make_sparse_regression,
)
from xgboost.testing.data import np_dtypes
import xgboost as xgb
@@ -238,3 +239,25 @@ class TestQuantileDMatrix:
np.testing.assert_allclose(
booster.predict(qdm), booster.predict(xgb.DMatrix(qdm.get_data()))
)
def test_dtypes(self) -> None:
n_samples = 128
n_features = 16
for orig, x in np_dtypes(n_samples, n_features):
m0 = xgb.QuantileDMatrix(orig)
m1 = xgb.QuantileDMatrix(x)
csr0 = m0.get_data()
csr1 = m1.get_data()
np.testing.assert_allclose(csr0.data, csr1.data)
np.testing.assert_allclose(csr0.indptr, csr1.indptr)
np.testing.assert_allclose(csr0.indices, csr1.indices)
# unsupported types
for dtype in [
np.string_,
np.complex64,
np.complex128,
]:
X: np.ndarray = np.array(orig, dtype=dtype)
with pytest.raises(ValueError):
xgb.QuantileDMatrix(X)