Support categorical data with pandas Dataframe in inplace prediction (#7322)
This commit is contained in:
parent
8e619010d0
commit
f56e2e9a66
@ -1973,13 +1973,6 @@ class Booster(object):
|
||||
preds = ctypes.POINTER(ctypes.c_float)()
|
||||
|
||||
# once caching is supported, we can pass id(data) as cache id.
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = data.values
|
||||
except ImportError:
|
||||
pass
|
||||
args = {
|
||||
"type": 0,
|
||||
"training": False,
|
||||
@ -2014,7 +2007,15 @@ class Booster(object):
|
||||
f"got {data.shape[1]}"
|
||||
)
|
||||
|
||||
from .data import _is_pandas_df, _transform_pandas_df
|
||||
from .data import _array_interface
|
||||
if _is_pandas_df(data):
|
||||
ft = self.feature_types
|
||||
if ft is None:
|
||||
enable_categorical = False
|
||||
else:
|
||||
enable_categorical = any(f == "c" for f in ft)
|
||||
data, _, _ = _transform_pandas_df(data, enable_categorical)
|
||||
if isinstance(data, np.ndarray):
|
||||
from .data import _ensure_np_dtype
|
||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||
@ -2068,7 +2069,6 @@ class Booster(object):
|
||||
return _prediction_output(shape, dims, preds, True)
|
||||
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
|
||||
from .data import _cudf_array_interfaces
|
||||
|
||||
_, interfaces_str = _cudf_array_interfaces(data)
|
||||
_check_call(
|
||||
_LIB.XGBoosterPredictFromCudaColumnar(
|
||||
|
||||
@ -289,16 +289,15 @@ def _transform_pandas_df(
|
||||
def _from_pandas_df(
|
||||
data,
|
||||
enable_categorical: bool,
|
||||
missing,
|
||||
nthread,
|
||||
missing: float,
|
||||
nthread: int,
|
||||
feature_names: Optional[List[str]],
|
||||
feature_types: Optional[List[str]],
|
||||
):
|
||||
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
|
||||
data, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types)
|
||||
return _from_numpy_array(data, missing, nthread, feature_names,
|
||||
feature_types)
|
||||
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
return _from_numpy_array(data, missing, nthread, feature_names, feature_types)
|
||||
|
||||
def _is_pandas_series(data):
|
||||
try:
|
||||
|
||||
@ -809,11 +809,7 @@ class XGBModel(XGBModelBase):
|
||||
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
||||
# sufficient for dask interface where input is simpiler.
|
||||
predictor = self.get_params().get("predictor", None)
|
||||
if (
|
||||
not self.enable_categorical
|
||||
and predictor in ("auto", None)
|
||||
and self.booster != "gblinear"
|
||||
):
|
||||
if predictor in ("auto", None) and self.booster != "gblinear":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@ -44,9 +44,12 @@ def test_num_parallel_tree():
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_categorical():
|
||||
import pandas as pd
|
||||
import cudf
|
||||
import cupy as cp
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
data_dir = os.path.join(tm.PROJECT_ROOT, "demo", "data")
|
||||
@ -59,7 +62,6 @@ def test_categorical():
|
||||
)
|
||||
X = pd.DataFrame(X.todense()).astype("category")
|
||||
clf.fit(X, y)
|
||||
assert not clf._can_use_inplace_predict()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model = os.path.join(tempdir, "categorial.json")
|
||||
@ -74,3 +76,25 @@ def test_categorical():
|
||||
)
|
||||
assert categories_sizes.shape[0] != 0
|
||||
np.testing.assert_allclose(categories_sizes, 1)
|
||||
|
||||
def check_predt(X, y):
|
||||
reg = xgb.XGBRegressor(
|
||||
tree_method="gpu_hist", enable_categorical=True, n_estimators=64
|
||||
)
|
||||
reg.fit(X, y)
|
||||
predts = reg.predict(X)
|
||||
booster = reg.get_booster()
|
||||
assert "c" in booster.feature_types
|
||||
assert len(booster.feature_types) == 1
|
||||
inp_predts = booster.inplace_predict(X)
|
||||
if isinstance(inp_predts, cp.ndarray):
|
||||
inp_predts = cp.asnumpy(inp_predts)
|
||||
np.testing.assert_allclose(predts, inp_predts)
|
||||
|
||||
y = [1, 2, 3]
|
||||
X = pd.DataFrame({"f0": ["a", "b", "c"]})
|
||||
X["f0"] = X["f0"].astype("category")
|
||||
check_predt(X, y)
|
||||
|
||||
X = cudf.DataFrame(X)
|
||||
check_predt(X, y)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user