Set enable_categorical to True in predict. (#8592)
This commit is contained in:
parent
7a07dcf651
commit
001e663d42
@ -278,22 +278,6 @@ def _check_call(ret: int) -> None:
|
|||||||
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
|
raise XGBoostError(py_str(_LIB.XGBGetLastError()))
|
||||||
|
|
||||||
|
|
||||||
def _has_categorical(booster: "Booster", data: DataType) -> bool:
|
|
||||||
"""Check whether the booster and input data for prediction contain categorical data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
from .data import _is_cudf_df, _is_pandas_df
|
|
||||||
if _is_pandas_df(data) or _is_cudf_df(data):
|
|
||||||
ft = booster.feature_types
|
|
||||||
if ft is None:
|
|
||||||
enable_categorical = False
|
|
||||||
else:
|
|
||||||
enable_categorical = any(f == "c" for f in ft)
|
|
||||||
else:
|
|
||||||
enable_categorical = False
|
|
||||||
return enable_categorical
|
|
||||||
|
|
||||||
|
|
||||||
def build_info() -> dict:
|
def build_info() -> dict:
|
||||||
"""Build information of XGBoost. The returned value format is not stable. Also, please
|
"""Build information of XGBoost. The returned value format is not stable. Also, please
|
||||||
note that build time dependency is not the same as runtime dependency. For instance,
|
note that build time dependency is not the same as runtime dependency. For instance,
|
||||||
@ -2278,7 +2262,7 @@ class Booster:
|
|||||||
_transform_pandas_df,
|
_transform_pandas_df,
|
||||||
)
|
)
|
||||||
|
|
||||||
enable_categorical = _has_categorical(self, data)
|
enable_categorical = True
|
||||||
if _is_pandas_df(data):
|
if _is_pandas_df(data):
|
||||||
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
data, fns, _ = _transform_pandas_df(data, enable_categorical)
|
||||||
if validate_features:
|
if validate_features:
|
||||||
|
|||||||
@ -72,7 +72,6 @@ from .core import (
|
|||||||
QuantileDMatrix,
|
QuantileDMatrix,
|
||||||
_deprecate_positional_args,
|
_deprecate_positional_args,
|
||||||
_expect,
|
_expect,
|
||||||
_has_categorical,
|
|
||||||
)
|
)
|
||||||
from .sklearn import (
|
from .sklearn import (
|
||||||
XGBClassifier,
|
XGBClassifier,
|
||||||
@ -1190,7 +1189,7 @@ def _infer_predict_output(
|
|||||||
kwargs = kwargs.copy()
|
kwargs = kwargs.copy()
|
||||||
if kwargs.pop("predict_type") == "margin":
|
if kwargs.pop("predict_type") == "margin":
|
||||||
kwargs["output_margin"] = True
|
kwargs["output_margin"] = True
|
||||||
m = DMatrix(test_sample)
|
m = DMatrix(test_sample, enable_categorical=True)
|
||||||
# generated DMatrix doesn't have feature name, so no validation.
|
# generated DMatrix doesn't have feature name, so no validation.
|
||||||
test_predt = booster.predict(m, validate_features=False, **kwargs)
|
test_predt = booster.predict(m, validate_features=False, **kwargs)
|
||||||
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
|
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
|
||||||
@ -1247,7 +1246,7 @@ async def _predict_async(
|
|||||||
m = DMatrix(
|
m = DMatrix(
|
||||||
data=partition,
|
data=partition,
|
||||||
missing=missing,
|
missing=missing,
|
||||||
enable_categorical=_has_categorical(booster, partition),
|
enable_categorical=True,
|
||||||
)
|
)
|
||||||
predt = booster.predict(
|
predt = booster.predict(
|
||||||
data=m,
|
data=m,
|
||||||
@ -1315,6 +1314,7 @@ async def _predict_async(
|
|||||||
base_margin=base_margin,
|
base_margin=base_margin,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
|
enable_categorical=True,
|
||||||
)
|
)
|
||||||
predt = booster.predict(
|
predt = booster.predict(
|
||||||
m,
|
m,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user