Support half type for pandas. (#8481)
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from scipy import sparse
|
||||
from xgboost.testing.data import np_dtypes
|
||||
from xgboost.testing.shared import validate_leaf_output
|
||||
|
||||
import xgboost as xgb
|
||||
@@ -230,46 +231,10 @@ class TestInplacePredict:
|
||||
from_dmatrix = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(from_dmatrix, from_inplace)
|
||||
|
||||
def test_dtypes(self):
|
||||
orig = self.rng.randint(low=0, high=127, size=self.rows * self.cols).reshape(
|
||||
self.rows, self.cols
|
||||
)
|
||||
predt_orig = self.booster.inplace_predict(orig)
|
||||
# all primitive types in numpy
|
||||
for dtype in [
|
||||
np.int32,
|
||||
np.int64,
|
||||
np.byte,
|
||||
np.short,
|
||||
np.intc,
|
||||
np.int_,
|
||||
np.longlong,
|
||||
np.uint32,
|
||||
np.uint64,
|
||||
np.ubyte,
|
||||
np.ushort,
|
||||
np.uintc,
|
||||
np.uint,
|
||||
np.ulonglong,
|
||||
np.float16,
|
||||
np.float32,
|
||||
np.float64,
|
||||
np.half,
|
||||
np.single,
|
||||
np.double,
|
||||
]:
|
||||
X = np.array(orig, dtype=dtype)
|
||||
predt = self.booster.inplace_predict(X)
|
||||
np.testing.assert_allclose(predt, predt_orig)
|
||||
|
||||
# boolean
|
||||
orig = self.rng.binomial(1, 0.5, size=self.rows * self.cols).reshape(
|
||||
self.rows, self.cols
|
||||
)
|
||||
predt_orig = self.booster.inplace_predict(orig)
|
||||
for dtype in [np.bool8, np.bool_]:
|
||||
X = np.array(orig, dtype=dtype)
|
||||
predt = self.booster.inplace_predict(X)
|
||||
def test_dtypes(self) -> None:
|
||||
for orig, x in np_dtypes(self.rows, self.cols):
|
||||
predt_orig = self.booster.inplace_predict(orig)
|
||||
predt = self.booster.inplace_predict(x)
|
||||
np.testing.assert_allclose(predt, predt_orig)
|
||||
|
||||
# unsupported types
|
||||
@@ -278,6 +243,6 @@ class TestInplacePredict:
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
]:
|
||||
X = np.array(orig, dtype=dtype)
|
||||
X: np.ndarray = np.array(orig, dtype=dtype)
|
||||
with pytest.raises(ValueError):
|
||||
self.booster.inplace_predict(X)
|
||||
|
||||
Reference in New Issue
Block a user