Support Series and Python primitives in inplace_predict and QDM (#8547)
This commit is contained in:
@@ -5,7 +5,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from scipy import sparse
|
||||
from xgboost.testing.data import np_dtypes
|
||||
from xgboost.testing.data import np_dtypes, pd_dtypes
|
||||
from xgboost.testing.shared import validate_leaf_output
|
||||
|
||||
import xgboost as xgb
|
||||
@@ -231,6 +231,7 @@ class TestInplacePredict:
|
||||
from_dmatrix = booster.predict(dtrain)
|
||||
np.testing.assert_allclose(from_dmatrix, from_inplace)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_dtypes(self) -> None:
|
||||
for orig, x in np_dtypes(self.rows, self.cols):
|
||||
predt_orig = self.booster.inplace_predict(orig)
|
||||
@@ -246,3 +247,17 @@ class TestInplacePredict:
|
||||
X: np.ndarray = np.array(orig, dtype=dtype)
|
||||
with pytest.raises(ValueError):
|
||||
self.booster.inplace_predict(X)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_pd_dtypes(self) -> None:
|
||||
from pandas.api.types import is_bool_dtype
|
||||
for orig, x in pd_dtypes():
|
||||
dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes]
|
||||
if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]):
|
||||
continue
|
||||
y = np.arange(x.shape[0])
|
||||
Xy = xgb.DMatrix(orig, y, enable_categorical=True)
|
||||
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1)
|
||||
predt_orig = booster.inplace_predict(orig)
|
||||
predt = booster.inplace_predict(x)
|
||||
np.testing.assert_allclose(predt, predt_orig)
|
||||
|
||||
@@ -298,22 +298,29 @@ class TestPandas:
|
||||
assert 'auc' not in cv.columns[0]
|
||||
assert 'error' in cv.columns[0]
|
||||
|
||||
def test_nullable_type(self) -> None:
|
||||
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
|
||||
def test_nullable_type(self, DMatrixT) -> None:
|
||||
from pandas.api.types import is_categorical
|
||||
|
||||
for DMatrixT in (xgb.DMatrix, xgb.QuantileDMatrix):
|
||||
for orig, df in pd_dtypes():
|
||||
for orig, df in pd_dtypes():
|
||||
if hasattr(df.dtypes, "__iter__"):
|
||||
enable_categorical = any(is_categorical for dtype in df.dtypes)
|
||||
else:
|
||||
# series
|
||||
enable_categorical = is_categorical(df.dtype)
|
||||
|
||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
||||
# extension types
|
||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||
if any(dtype == "bool" for dtype in orig.dtypes):
|
||||
assert not tm.predictor_equal(m_orig, m_etype)
|
||||
else:
|
||||
assert tm.predictor_equal(m_orig, m_etype)
|
||||
m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
|
||||
# extension types
|
||||
m_etype = DMatrixT(df, enable_categorical=enable_categorical)
|
||||
# different from pd.BooleanDtype(), None is converted to False with bool
|
||||
if hasattr(orig.dtypes, "__iter__") and any(
|
||||
dtype == "bool" for dtype in orig.dtypes
|
||||
):
|
||||
assert not tm.predictor_equal(m_orig, m_etype)
|
||||
else:
|
||||
assert tm.predictor_equal(m_orig, m_etype)
|
||||
|
||||
if isinstance(df, pd.DataFrame):
|
||||
f0 = df["f0"]
|
||||
with pytest.raises(ValueError, match="Label contains NaN"):
|
||||
xgb.DMatrix(df, f0, enable_categorical=enable_categorical)
|
||||
|
||||
Reference in New Issue
Block a user