Support Series and Python primitives in inplace_predict and QDM (#8547)

This commit is contained in:
Jiaming Yuan
2022-12-17 00:15:15 +08:00
committed by GitHub
parent a10e4cba4e
commit f6effa1734
5 changed files with 84 additions and 46 deletions

View File

@@ -298,22 +298,29 @@ class TestPandas:
assert 'auc' not in cv.columns[0]
assert 'error' in cv.columns[0]
@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_nullable_type(self, DMatrixT) -> None:
    """Check that pandas extension (nullable) dtypes match their plain originals.

    For every ``(orig, df)`` pair produced by ``pd_dtypes()`` a matrix is
    built from each with ``DMatrixT`` and the two are compared using
    ``tm.predictor_equal``.  The inputs may be DataFrames or Series, so
    every dtype inspection is guarded accordingly.

    Parameters
    ----------
    DMatrixT :
        Matrix class under test, supplied by ``pytest.mark.parametrize``
        (``xgb.DMatrix`` or ``xgb.QuantileDMatrix``).
    """
    from pandas.api.types import is_categorical

    for orig, df in pd_dtypes():
        if hasattr(df.dtypes, "__iter__"):
            # DataFrame: enable categorical support only when a column
            # really is categorical.  The predicate has to be *called* on
            # each dtype; the bare function object is always truthy.
            enable_categorical = any(is_categorical(dtype) for dtype in df.dtypes)
        else:
            # series
            enable_categorical = is_categorical(df.dtype)

        m_orig = DMatrixT(orig, enable_categorical=enable_categorical)
        # extension types
        m_etype = DMatrixT(df, enable_categorical=enable_categorical)
        # different from pd.BooleanDtype(), None is converted to False with bool
        if hasattr(orig.dtypes, "__iter__") and any(
            dtype == "bool" for dtype in orig.dtypes
        ):
            assert not tm.predictor_equal(m_orig, m_etype)
        else:
            assert tm.predictor_equal(m_orig, m_etype)

        if isinstance(df, pd.DataFrame):
            # Passing the "f0" column as the second positional argument
            # (presumably the label -- see the matched error message) must
            # be rejected.
            f0 = df["f0"]
            with pytest.raises(ValueError, match="Label contains NaN"):
                xgb.DMatrix(df, f0, enable_categorical=enable_categorical)