Set enable_categorical to True in predict. (#8592)

This commit is contained in:
Jiaming Yuan 2022-12-15 05:27:06 +08:00 committed by GitHub
parent 7a07dcf651
commit 001e663d42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 20 deletions

View File

@ -278,22 +278,6 @@ def _check_call(ret: int) -> None:
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
def _has_categorical(booster: "Booster", data: DataType) -> bool:
    """Tell whether prediction should treat the input as holding categorical data.

    The answer is ``True`` only when both conditions hold:

    * ``data`` is accepted by ``_is_pandas_df`` or ``_is_cudf_df``, and
    * ``booster.feature_types`` is set and at least one entry equals ``"c"``.

    In every other case the result is ``False``.
    """
    from .data import _is_cudf_df, _is_pandas_df

    # Inputs that are not dataframes never enable categorical handling.
    if not (_is_pandas_df(data) or _is_cudf_df(data)):
        return False

    feature_types = booster.feature_types
    # A booster without recorded feature types cannot declare categoricals.
    if feature_types is None:
        return False

    return any(kind == "c" for kind in feature_types)
def build_info() -> dict:
"""Build information of XGBoost. The returned value format is not stable. Also, please
note that build time dependency is not the same as runtime dependency. For instance,
@ -2278,7 +2262,7 @@ class Booster:
_transform_pandas_df,
)
enable_categorical = _has_categorical(self, data)
enable_categorical = True
if _is_pandas_df(data):
data, fns, _ = _transform_pandas_df(data, enable_categorical)
if validate_features:

View File

@ -72,7 +72,6 @@ from .core import (
QuantileDMatrix,
_deprecate_positional_args,
_expect,
_has_categorical,
)
from .sklearn import (
XGBClassifier,
@ -1190,7 +1189,7 @@ def _infer_predict_output(
kwargs = kwargs.copy()
if kwargs.pop("predict_type") == "margin":
kwargs["output_margin"] = True
m = DMatrix(test_sample)
m = DMatrix(test_sample, enable_categorical=True)
# generated DMatrix doesn't have feature name, so no validation.
test_predt = booster.predict(m, validate_features=False, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
@ -1247,7 +1246,7 @@ async def _predict_async(
m = DMatrix(
data=partition,
missing=missing,
enable_categorical=_has_categorical(booster, partition),
enable_categorical=True,
)
predt = booster.predict(
data=m,
@ -1315,6 +1314,7 @@ async def _predict_async(
base_margin=base_margin,
feature_names=feature_names,
feature_types=feature_types,
enable_categorical=True,
)
predt = booster.predict(
m,