Support categorical data with pandas DataFrame in inplace prediction (#7322)

Authored by Jiaming Yuan on 2021-10-17 14:32:06 +08:00, committed by GitHub
parent 8e619010d0
commit f56e2e9a66
4 changed files with 40 additions and 21 deletions
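In short: Booster.inplace_predict now accepts a pandas DataFrame directly instead of flattening it through .values, so category dtypes survive and boosters trained on categorical features can use the in-place prediction path. A minimal usage sketch, mirroring the test added at the end of this commit (assumes a GPU build so that "gpu_hist" is available; names are illustrative):

    import pandas as pd
    import xgboost as xgb

    X = pd.DataFrame({"f0": ["a", "b", "c"]}).astype("category")
    y = [1, 2, 3]

    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True, n_estimators=8)
    reg.fit(X, y)

    booster = reg.get_booster()
    preds = booster.inplace_predict(X)  # the DataFrame is no longer converted via .values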


@@ -1973,13 +1973,6 @@ class Booster(object):
         preds = ctypes.POINTER(ctypes.c_float)()
         # once caching is supported, we can pass id(data) as cache id.
-        try:
-            import pandas as pd
-            if isinstance(data, pd.DataFrame):
-                data = data.values
-        except ImportError:
-            pass
         args = {
             "type": 0,
             "training": False,
@@ -2014,7 +2007,15 @@ class Booster(object):
                     f"got {data.shape[1]}"
                 )
+        from .data import _is_pandas_df, _transform_pandas_df
         from .data import _array_interface
+        if _is_pandas_df(data):
+            ft = self.feature_types
+            if ft is None:
+                enable_categorical = False
+            else:
+                enable_categorical = any(f == "c" for f in ft)
+            data, _, _ = _transform_pandas_df(data, enable_categorical)
         if isinstance(data, np.ndarray):
             from .data import _ensure_np_dtype
             data, _ = _ensure_np_dtype(data, data.dtype)
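
The new branch derives enable_categorical from the booster's stored feature types ("c" marks a categorical feature) and routes the DataFrame through the same _transform_pandas_df helper that DMatrix construction uses. Conceptually, a category column ends up represented by its integer category codes before being handed to the native predictor; a rough standalone sketch of that idea, not the library implementation:

    import numpy as np
    import pandas as pd

    col = pd.Series(["a", "b", None, "a"], dtype="category")
    codes = col.cat.codes.astype(np.float32).replace(-1.0, np.nan)  # code -1 means missing
    print(codes.tolist())  # [0.0, 1.0, nan, 0.0]
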
@@ -2068,7 +2069,6 @@ class Booster(object):
             return _prediction_output(shape, dims, preds, True)
         if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
             from .data import _cudf_array_interfaces
             _, interfaces_str = _cudf_array_interfaces(data)
             _check_call(
                 _LIB.XGBoosterPredictFromCudaColumnar(


@@ -289,16 +289,15 @@ def _transform_pandas_df(
 def _from_pandas_df(
     data,
     enable_categorical: bool,
-    missing,
-    nthread,
+    missing: float,
+    nthread: int,
     feature_names: Optional[List[str]],
     feature_types: Optional[List[str]],
-):
+) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
     data, feature_names, feature_types = _transform_pandas_df(
-        data, enable_categorical, feature_names, feature_types)
-    return _from_numpy_array(data, missing, nthread, feature_names,
-                             feature_types)
+        data, enable_categorical, feature_names, feature_types
+    )
+    return _from_numpy_array(data, missing, nthread, feature_names, feature_types)

 def _is_pandas_series(data):
     try:
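
For comparison, the DMatrix route through _from_pandas_df already carried the enable_categorical flag; the change above only adds type annotations and reformats the calls. A short sketch of that pre-existing path (categorical support was still flagged experimental at the time of this commit):

    import pandas as pd
    import xgboost as xgb

    X = pd.DataFrame({"f0": pd.Series(["a", "b", "c"], dtype="category")})
    dtrain = xgb.DMatrix(X, label=[1, 2, 3], enable_categorical=True)
    print(dtrain.feature_types)  # expected to report "c" for the categorical column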


@@ -809,11 +809,7 @@ class XGBModel(XGBModelBase):
         # Inplace predict doesn't handle as many data types as DMatrix, but it's
         # sufficient for dask interface where input is simpler.
         predictor = self.get_params().get("predictor", None)
-        if (
-            not self.enable_categorical
-            and predictor in ("auto", None)
-            and self.booster != "gblinear"
-        ):
+        if predictor in ("auto", None) and self.booster != "gblinear":
             return True
         return False
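
With the enable_categorical restriction dropped, the sklearn wrapper can now take the in-place prediction path for categorical models whenever the predictor is left as "auto"/unset and the booster is not "gblinear". A standalone restatement of the relaxed gate (hypothetical helper mirroring the condition above):

    def can_use_inplace_predict(predictor, booster) -> bool:
        return predictor in ("auto", None) and booster != "gblinear"

    assert can_use_inplace_predict(None, "gbtree")
    assert not can_use_inplace_predict("gpu_predictor", "gbtree")
    assert not can_use_inplace_predict(None, "gblinear")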


@@ -44,9 +44,12 @@ def test_num_parallel_tree():
 @pytest.mark.skipif(**tm.no_pandas())
+@pytest.mark.skipif(**tm.no_cudf())
 @pytest.mark.skipif(**tm.no_sklearn())
 def test_categorical():
     import pandas as pd
+    import cudf
+    import cupy as cp
     from sklearn.datasets import load_svmlight_file

     data_dir = os.path.join(tm.PROJECT_ROOT, "demo", "data")
@@ -59,7 +62,6 @@ def test_categorical():
     )
     X = pd.DataFrame(X.todense()).astype("category")
     clf.fit(X, y)
-    assert not clf._can_use_inplace_predict()

     with tempfile.TemporaryDirectory() as tempdir:
         model = os.path.join(tempdir, "categorical.json")
@@ -74,3 +76,25 @@ def test_categorical():
             )
             assert categories_sizes.shape[0] != 0
             np.testing.assert_allclose(categories_sizes, 1)
+
+    def check_predt(X, y):
+        reg = xgb.XGBRegressor(
+            tree_method="gpu_hist", enable_categorical=True, n_estimators=64
+        )
+        reg.fit(X, y)
+        predts = reg.predict(X)
+        booster = reg.get_booster()
+        assert "c" in booster.feature_types
+        assert len(booster.feature_types) == 1
+        inp_predts = booster.inplace_predict(X)
+        if isinstance(inp_predts, cp.ndarray):
+            inp_predts = cp.asnumpy(inp_predts)
+        np.testing.assert_allclose(predts, inp_predts)
+
+    y = [1, 2, 3]
+    X = pd.DataFrame({"f0": ["a", "b", "c"]})
+    X["f0"] = X["f0"].astype("category")
+    check_predt(X, y)
+    X = cudf.DataFrame(X)
+    check_predt(X, y)
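
The cp.asnumpy call in check_predt is there because in-place prediction mirrors the input's memory: the cudf.DataFrame input yields a cupy array on the device, while the pandas input yields a NumPy array on the host. A minimal sketch of that conversion step in isolation (assumes cupy is installed):

    import cupy as cp
    import numpy as np

    device_preds = cp.asarray([0.5, 1.5, 2.5], dtype=cp.float32)  # stand-in for device prediction output
    host_preds = cp.asnumpy(device_preds)
    np.testing.assert_allclose(host_preds, [0.5, 1.5, 2.5])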