Support half type for pandas. (#8481)
This commit is contained in:
@@ -11,6 +11,7 @@ from xgboost.testing import (
|
||||
make_categorical,
|
||||
make_sparse_regression,
|
||||
)
|
||||
from xgboost.testing.data import np_dtypes
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
@@ -238,3 +239,25 @@ class TestQuantileDMatrix:
|
||||
np.testing.assert_allclose(
|
||||
booster.predict(qdm), booster.predict(xgb.DMatrix(qdm.get_data()))
|
||||
)
|
||||
|
||||
def test_dtypes(self) -> None:
|
||||
n_samples = 128
|
||||
n_features = 16
|
||||
for orig, x in np_dtypes(n_samples, n_features):
|
||||
m0 = xgb.QuantileDMatrix(orig)
|
||||
m1 = xgb.QuantileDMatrix(x)
|
||||
csr0 = m0.get_data()
|
||||
csr1 = m1.get_data()
|
||||
np.testing.assert_allclose(csr0.data, csr1.data)
|
||||
np.testing.assert_allclose(csr0.indptr, csr1.indptr)
|
||||
np.testing.assert_allclose(csr0.indices, csr1.indices)
|
||||
|
||||
# unsupported types
|
||||
for dtype in [
|
||||
np.string_,
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
]:
|
||||
X: np.ndarray = np.array(orig, dtype=dtype)
|
||||
with pytest.raises(ValueError):
|
||||
xgb.QuantileDMatrix(X)
|
||||
|
||||
Reference in New Issue
Block a user