Support Series and Python primitives in inplace_predict and QDM (#8547)

This commit is contained in:
Jiaming Yuan
2022-12-17 00:15:15 +08:00
committed by GitHub
parent a10e4cba4e
commit f6effa1734
5 changed files with 84 additions and 46 deletions

View File

@@ -5,7 +5,7 @@ import numpy as np
import pandas as pd
import pytest
from scipy import sparse
from xgboost.testing.data import np_dtypes
from xgboost.testing.data import np_dtypes, pd_dtypes
from xgboost.testing.shared import validate_leaf_output
import xgboost as xgb
@@ -231,6 +231,7 @@ class TestInplacePredict:
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_dmatrix, from_inplace)
@pytest.mark.skipif(**tm.no_pandas())
def test_dtypes(self) -> None:
for orig, x in np_dtypes(self.rows, self.cols):
predt_orig = self.booster.inplace_predict(orig)
@@ -246,3 +247,17 @@ class TestInplacePredict:
X: np.ndarray = np.array(orig, dtype=dtype)
with pytest.raises(ValueError):
self.booster.inplace_predict(X)
@pytest.mark.skipif(**tm.no_pandas())
def test_pd_dtypes(self) -> None:
from pandas.api.types import is_bool_dtype
for orig, x in pd_dtypes():
dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes]
if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]):
continue
y = np.arange(x.shape[0])
Xy = xgb.DMatrix(orig, y, enable_categorical=True)
booster = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=1)
predt_orig = booster.inplace_predict(orig)
predt = booster.inplace_predict(x)
np.testing.assert_allclose(predt, predt_orig)