Set enable_categorical to True in predict. (#8592)

This commit is contained in:
Jiaming Yuan 2022-12-15 05:27:06 +08:00 committed by GitHub
parent 7a07dcf651
commit 001e663d42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 20 deletions

View File

@ -278,22 +278,6 @@ def _check_call(ret: int) -> None:
raise XGBoostError(py_str(_LIB.XGBGetLastError())) raise XGBoostError(py_str(_LIB.XGBGetLastError()))
def _has_categorical(booster: "Booster", data: DataType) -> bool:
"""Check whether the booster and input data for prediction contain categorical data.
"""
from .data import _is_cudf_df, _is_pandas_df
if _is_pandas_df(data) or _is_cudf_df(data):
ft = booster.feature_types
if ft is None:
enable_categorical = False
else:
enable_categorical = any(f == "c" for f in ft)
else:
enable_categorical = False
return enable_categorical
def build_info() -> dict: def build_info() -> dict:
"""Build information of XGBoost. The returned value format is not stable. Also, please """Build information of XGBoost. The returned value format is not stable. Also, please
note that build time dependency is not the same as runtime dependency. For instance, note that build time dependency is not the same as runtime dependency. For instance,
@ -2278,7 +2262,7 @@ class Booster:
_transform_pandas_df, _transform_pandas_df,
) )
enable_categorical = _has_categorical(self, data) enable_categorical = True
if _is_pandas_df(data): if _is_pandas_df(data):
data, fns, _ = _transform_pandas_df(data, enable_categorical) data, fns, _ = _transform_pandas_df(data, enable_categorical)
if validate_features: if validate_features:

View File

@ -72,7 +72,6 @@ from .core import (
QuantileDMatrix, QuantileDMatrix,
_deprecate_positional_args, _deprecate_positional_args,
_expect, _expect,
_has_categorical,
) )
from .sklearn import ( from .sklearn import (
XGBClassifier, XGBClassifier,
@ -1190,7 +1189,7 @@ def _infer_predict_output(
kwargs = kwargs.copy() kwargs = kwargs.copy()
if kwargs.pop("predict_type") == "margin": if kwargs.pop("predict_type") == "margin":
kwargs["output_margin"] = True kwargs["output_margin"] = True
m = DMatrix(test_sample) m = DMatrix(test_sample, enable_categorical=True)
# generated DMatrix doesn't have feature name, so no validation. # generated DMatrix doesn't have feature name, so no validation.
test_predt = booster.predict(m, validate_features=False, **kwargs) test_predt = booster.predict(m, validate_features=False, **kwargs)
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1 n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
@ -1247,7 +1246,7 @@ async def _predict_async(
m = DMatrix( m = DMatrix(
data=partition, data=partition,
missing=missing, missing=missing,
enable_categorical=_has_categorical(booster, partition), enable_categorical=True,
) )
predt = booster.predict( predt = booster.predict(
data=m, data=m,
@ -1315,6 +1314,7 @@ async def _predict_async(
base_margin=base_margin, base_margin=base_margin,
feature_names=feature_names, feature_names=feature_names,
feature_types=feature_types, feature_types=feature_types,
enable_categorical=True,
) )
predt = booster.predict( predt = booster.predict(
m, m,