Support half type for pandas. (#8481)

This commit is contained in:
Jiaming Yuan
2022-11-24 12:47:40 +08:00
committed by GitHub
parent e07245f110
commit 8f97c92541
5 changed files with 109 additions and 53 deletions

View File

@@ -5,6 +5,7 @@ import numpy as np
import pandas as pd
import pytest
from scipy import sparse
from xgboost.testing.data import np_dtypes
from xgboost.testing.shared import validate_leaf_output
import xgboost as xgb
@@ -230,46 +231,10 @@ class TestInplacePredict:
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_dmatrix, from_inplace)
def test_dtypes(self):
orig = self.rng.randint(low=0, high=127, size=self.rows * self.cols).reshape(
self.rows, self.cols
)
predt_orig = self.booster.inplace_predict(orig)
# all primitive types in numpy
for dtype in [
np.int32,
np.int64,
np.byte,
np.short,
np.intc,
np.int_,
np.longlong,
np.uint32,
np.uint64,
np.ubyte,
np.ushort,
np.uintc,
np.uint,
np.ulonglong,
np.float16,
np.float32,
np.float64,
np.half,
np.single,
np.double,
]:
X = np.array(orig, dtype=dtype)
predt = self.booster.inplace_predict(X)
np.testing.assert_allclose(predt, predt_orig)
# boolean
orig = self.rng.binomial(1, 0.5, size=self.rows * self.cols).reshape(
self.rows, self.cols
)
predt_orig = self.booster.inplace_predict(orig)
for dtype in [np.bool8, np.bool_]:
X = np.array(orig, dtype=dtype)
predt = self.booster.inplace_predict(X)
def test_dtypes(self) -> None:
for orig, x in np_dtypes(self.rows, self.cols):
predt_orig = self.booster.inplace_predict(orig)
predt = self.booster.inplace_predict(x)
np.testing.assert_allclose(predt, predt_orig)
# unsupported types
@@ -278,6 +243,6 @@ class TestInplacePredict:
np.complex64,
np.complex128,
]:
X = np.array(orig, dtype=dtype)
X: np.ndarray = np.array(orig, dtype=dtype)
with pytest.raises(ValueError):
self.booster.inplace_predict(X)