Expose feature_types to sklearn interface. (#7821)

This commit is contained in:
Jiaming Yuan 2022-04-21 20:23:35 +08:00 committed by GitHub
parent 401d451569
commit c70fa502a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 131 additions and 48 deletions

View File

@ -1,7 +1,7 @@
"""Shared typing definition."""
import ctypes
import os
from typing import Optional, List, Any, TypeVar, Union
from typing import Optional, Any, TypeVar, Union, Sequence
# os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
# cudf.DataFrame/cupy.array/dlpack
@ -9,7 +9,8 @@ DataType = Any
# xgboost accepts some other possible types in practice due to historical reason, which is
# lesser tested. For now we encourage users to pass a simple list of string.
FeatureNames = Optional[List[str]]
FeatureNames = Optional[Sequence[str]]
FeatureTypes = Optional[Sequence[str]]
ArrayLike = Any
PathLike = Union[str, os.PathLike]

View File

@ -31,6 +31,7 @@ from ._typing import (
CFloatPtr,
NumpyOrCupy,
FeatureNames,
FeatureTypes,
_T,
CupyT,
)
@ -553,7 +554,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
missing: Optional[float] = None,
silent: bool = False,
feature_names: FeatureNames = None,
feature_types: Optional[List[str]] = None,
feature_types: FeatureTypes = None,
nthread: Optional[int] = None,
group: Optional[ArrayLike] = None,
qid: Optional[ArrayLike] = None,
@ -594,10 +595,15 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
Whether print messages during construction
feature_names : list, optional
Set names for features.
feature_types :
feature_types : FeatureTypes
Set types for features. When `enable_categorical` is set to `True`, string
"c" represents categorical data type.
"c" represents categorical data type while "q" represents numerical feature
type. For categorical features, the input is assumed to be preprocessed and
encoded by the users. The encoding can be done via
:py:class:`sklearn.preprocessing.OrdinalEncoder` or pandas dataframe
`.cat.codes` method. This is useful when users want to specify categorical
features without having to construct a dataframe as input.
nthread : integer, optional
Number of threads to use for loading data when parallelization is
@ -1062,12 +1068,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
@property
def feature_types(self) -> Optional[List[str]]:
"""Get feature types (column types).
Returns
-------
feature_types : list or None
"""
"""Get feature types. See :py:class:`DMatrix` for details."""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
@ -1083,8 +1084,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:
"""Set feature types (column types).
This is for displaying the results and categorical data support. See doc string
of :py:obj:`xgboost.DMatrix` for details.
This is for displaying the results and categorical data support. See
:py:class:`DMatrix` for details.
Parameters
----------
@ -1647,7 +1648,7 @@ class Booster:
feature_info = from_cstr_to_pystr(sarr, length)
return feature_info if feature_info else None
def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
def _set_feature_info(self, features: Optional[Sequence[str]], field: str) -> None:
if features is not None:
assert isinstance(features, list)
feature_info_bytes = [bytes(f, encoding="utf-8") for f in features]
@ -1667,7 +1668,7 @@ class Booster:
@property
def feature_types(self) -> Optional[List[str]]:
"""Feature types for this booster. Can be directly set by input data or by
assignment.
assignment. See :py:class:`DMatrix` for details.
"""
return self._get_feature_info("feature_type")

View File

@ -54,10 +54,11 @@ from .compat import scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import lazy_isinstance
from ._typing import FeatureNames, FeatureTypes
from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
from .core import Objective, Metric
from .core import _deprecate_positional_args, _has_categorical
from .data import FeatureNames
from .training import train as worker_train
from .tracker import RabitTracker, get_host_ip
from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
@ -327,7 +328,7 @@ class DaskDMatrix:
missing: float = None,
silent: bool = False, # pylint: disable=unused-argument
feature_names: FeatureNames = None,
feature_types: Optional[List[str]] = None,
feature_types: FeatureTypes = None,
group: Optional[_DaskCollection] = None,
qid: Optional[_DaskCollection] = None,
label_lower_bound: Optional[_DaskCollection] = None,
@ -1601,7 +1602,11 @@ class DaskScikitLearnBase(XGBModel):
predts = predts.to_dask_array()
else:
test_dmatrix = await DaskDMatrix(
self.client, data=data, base_margin=base_margin, missing=self.missing
self.client,
data=data,
base_margin=base_margin,
missing=self.missing,
feature_types=self.feature_types
)
predts = await predict(
self.client,
@ -1640,7 +1645,9 @@ class DaskScikitLearnBase(XGBModel):
iteration_range: Optional[Tuple[int, int]] = None,
) -> Any:
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
test_dmatrix = await DaskDMatrix(
self.client, data=X, missing=self.missing, feature_types=self.feature_types,
)
predts = await predict(
self.client,
model=self.get_booster(),
@ -1755,6 +1762,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
eval_qid=None,
missing=self.missing,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
if callable(self.objective):
@ -1849,6 +1857,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
eval_qid=None,
missing=self.missing,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
# pylint: disable=attribute-defined-outside-init
@ -2054,6 +2063,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
eval_qid=eval_qid,
missing=self.missing,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
if eval_metric is not None:
if callable(eval_metric):

View File

@ -13,6 +13,7 @@ import numpy as np
from .core import c_array, _LIB, _check_call, c_str
from .core import _cuda_array_interface
from .core import DataIter, _ProxyDMatrix, DMatrix, FeatureNames
from ._typing import FeatureTypes
from .compat import lazy_isinstance, DataFrame
c_bst_ulong = ctypes.c_uint64 # pylint: disable=invalid-name
@ -70,7 +71,7 @@ def _from_scipy_csr(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
"""Initialize data from a CSR matrix."""
if len(data.indices) != len(data.data):
@ -109,7 +110,7 @@ def _from_scipy_csc(
data,
missing,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
if len(data.indices) != len(data.data):
raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}")
@ -165,7 +166,7 @@ def _from_numpy_array(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
"""Initialize data from a 2-D numpy matrix.
@ -228,6 +229,12 @@ _pandas_dtype_mapper = {
}
_ENABLE_CAT_ERR = (
"When categorical type is supplied, DMatrix parameter `enable_categorical` must "
"be set to `True`."
)
def _invalid_dataframe_dtype(data: Any) -> None:
# pandas series has `dtypes` but it's just a single object
# cudf series doesn't have `dtypes`.
@ -241,9 +248,8 @@ def _invalid_dataframe_dtype(data: Any) -> None:
else:
err = ""
msg = """DataFrame.dtypes for data must be int, float, bool or category. When
categorical type is supplied, DMatrix parameter `enable_categorical` must
be set to `True`.""" + err
type_err = "DataFrame.dtypes for data must be int, float, bool or category."
msg = f"""{type_err} {_ENABLE_CAT_ERR} {err}"""
raise ValueError(msg)
@ -340,8 +346,8 @@ def _from_pandas_df(
missing: float,
nthread: int,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
feature_types: FeatureTypes,
) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
data, feature_names, feature_types = _transform_pandas_df(
data, enable_categorical, feature_names, feature_types
)
@ -382,7 +388,7 @@ def _from_pandas_series(
nthread: int,
enable_categorical: bool,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
from pandas.api.types import is_categorical_dtype
@ -413,7 +419,7 @@ _dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
def _transform_dt_df(
data,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
meta=None,
meta_type=None,
):
@ -454,9 +460,9 @@ def _from_dt_df(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
if enable_categorical:
raise ValueError("categorical data in datatable is not supported yet.")
data, feature_names, feature_types = _transform_dt_df(
@ -542,10 +548,10 @@ def _from_arrow(
data,
missing: float,
nthread: int,
feature_names: Optional[List[str]],
feature_types: Optional[List[str]],
feature_names: FeatureNames,
feature_types: FeatureTypes,
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
import pyarrow as pa
if not all(
@ -621,7 +627,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes:
def _transform_cudf_df(
data,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
enable_categorical: bool,
):
try:
@ -687,7 +693,7 @@ def _from_cudf_df(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
enable_categorical: bool,
) -> Tuple[ctypes.c_void_p, Any, Any]:
data, cat_codes, feature_names, feature_types = _transform_cudf_df(
@ -735,7 +741,7 @@ def _from_cupy_array(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
"""Initialize DMatrix from cupy ndarray."""
data = _transform_cupy_array(data)
@ -782,7 +788,7 @@ def _from_dlpack(
missing,
nthread,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
data = _transform_dlpack(data)
return _from_cupy_array(data, missing, nthread, feature_names,
@ -797,7 +803,7 @@ def _from_uri(
data,
missing,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
_warn_unused_missing(data, missing)
handle = ctypes.c_void_p()
@ -817,7 +823,7 @@ def _from_list(
missing,
n_threads,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
array = np.array(data)
_check_data_shape(data)
@ -833,7 +839,7 @@ def _from_tuple(
missing,
n_threads,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
):
return _from_list(data, missing, n_threads, feature_names, feature_types)
@ -869,7 +875,7 @@ def dispatch_data_backend(
missing,
threads,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
enable_categorical: bool = False,
):
'''Dispatch data for DMatrix.'''
@ -884,8 +890,7 @@ def dispatch_data_backend(
data.tocsr(), missing, threads, feature_names, feature_types
)
if _is_numpy_array(data):
return _from_numpy_array(data, missing, threads, feature_names,
feature_types)
return _from_numpy_array(data, missing, threads, feature_names, feature_types)
if _is_uri(data):
return _from_uri(data, missing, feature_names, feature_types)
if _is_list(data):
@ -1101,7 +1106,7 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
def _proxy_transform(
data,
feature_names: FeatureNames,
feature_types: Optional[List[str]],
feature_types: FeatureTypes,
enable_categorical: bool,
):
if _is_cudf_df(data) or _is_cudf_ser(data):

View File

@ -14,7 +14,7 @@ from .core import Metric
from .training import train
from .callback import TrainingCallback
from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
from ._typing import ArrayLike
from ._typing import ArrayLike, FeatureTypes
# Do not use class names on scikit-learn directly. Re-define the classes on
# .compat to guarantee the behavior without scikit-learn
@ -211,6 +211,13 @@ __model_doc = f'''
should be used to specify categorical data type. Also, JSON/UBJSON
serialization format is required.
feature_types : FeatureTypes
.. versionadded:: 2.0.0
Used for specifying feature types without constructing a dataframe. See
:py:class:`DMatrix` for details.
max_cat_to_onehot : Optional[int]
.. versionadded:: 1.6.0
@ -394,6 +401,7 @@ def _wrap_evaluation_matrices(
eval_qid: Optional[Sequence[Any]],
create_dmatrix: Callable,
enable_categorical: bool,
feature_types: FeatureTypes,
) -> Tuple[Any, List[Tuple[Any, str]]]:
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
@ -408,6 +416,7 @@ def _wrap_evaluation_matrices(
feature_weights=feature_weights,
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
)
n_validation = 0 if eval_set is None else len(eval_set)
@ -455,6 +464,7 @@ def _wrap_evaluation_matrices(
base_margin=base_margin_eval_set[i],
missing=missing,
enable_categorical=enable_categorical,
feature_types=feature_types,
)
evals.append(m)
nevals = len(evals)
@ -518,6 +528,7 @@ class XGBModel(XGBModelBase):
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
feature_types: FeatureTypes = None,
max_cat_to_onehot: Optional[int] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
@ -562,6 +573,7 @@ class XGBModel(XGBModelBase):
self.validate_parameters = validate_parameters
self.predictor = predictor
self.enable_categorical = enable_categorical
self.feature_types = feature_types
self.max_cat_to_onehot = max_cat_to_onehot
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
@ -684,6 +696,7 @@ class XGBModel(XGBModelBase):
"enable_categorical",
"early_stopping_rounds",
"callbacks",
"feature_types",
}
filtered = {}
for k, v in params.items():
@ -715,6 +728,10 @@ class XGBModel(XGBModelBase):
# numpy array is not JSON serializable
meta['classes_'] = self.classes_.tolist()
continue
if k == "feature_types":
# Use the `feature_types` attribute from booster instead.
meta["feature_types"] = None
continue
try:
json.dumps({k: v})
meta[k] = v
@ -754,6 +771,9 @@ class XGBModel(XGBModelBase):
if k == 'classes_':
self.classes_ = np.array(v)
continue
if k == "feature_types":
self.feature_types = self.get_booster().feature_types
continue
if k == "_estimator_type":
if self._get_type() != v:
raise TypeError(
@ -944,6 +964,7 @@ class XGBModel(XGBModelBase):
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
enable_categorical=self.enable_categorical,
feature_types=self.feature_types
)
params = self.get_xgb_params()
@ -1063,9 +1084,11 @@ class XGBModel(XGBModelBase):
pass
test = DMatrix(
X, base_margin=base_margin,
X,
base_margin=base_margin,
missing=self.missing,
nthread=self.n_jobs,
feature_types=self.feature_types,
enable_categorical=self.enable_categorical
)
return self.get_booster().predict(
@ -1106,7 +1129,9 @@ class XGBModel(XGBModelBase):
self.get_booster(), ntree_limit, iteration_range
)
iteration_range = self._get_iteration_range(iteration_range)
test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
test_dmatrix = DMatrix(
X, missing=self.missing, feature_types=self.feature_types, nthread=self.n_jobs
)
return self.get_booster().predict(
test_dmatrix,
pred_leaf=True,
@ -1397,6 +1422,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
eval_qid=None,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
self._Booster = train(
@ -1828,6 +1854,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
eval_qid=eval_qid,
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
evals_result: TrainingCallback.EvalsLog = {}

View File

@ -306,6 +306,13 @@ def test_categorical(client: "Client") -> None:
run_categorical(client, "approx", X, X_onehot, y)
run_categorical(client, "hist", X, X_onehot, y)
ft = ["c"] * X.shape[1]
reg = xgb.dask.DaskXGBRegressor(
tree_method="hist", feature_types=ft, enable_categorical=True
)
reg.fit(X, y)
assert reg.get_booster().feature_types == ft
def test_dask_predict_shape_infer(client: "Client") -> None:
X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)

View File

@ -1273,6 +1273,38 @@ def test_estimator_reg(estimator, check):
check(estimator)
def test_categorical():
X, y = tm.make_categorical(n_samples=32, n_features=2, n_categories=3, onehot=False)
ft = ["c"] * X.shape[1]
reg = xgb.XGBRegressor(
tree_method="hist",
feature_types=ft,
max_cat_to_onehot=1,
enable_categorical=True,
)
reg.fit(X.values, y, eval_set=[(X.values, y)])
from_cat = reg.evals_result()["validation_0"]["rmse"]
predt_cat = reg.predict(X.values)
assert reg.get_booster().feature_types == ft
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "model.json")
reg.save_model(path)
reg = xgb.XGBRegressor()
reg.load_model(path)
assert reg.feature_types == ft
onehot, y = tm.make_categorical(
n_samples=32, n_features=2, n_categories=3, onehot=True
)
reg = xgb.XGBRegressor(tree_method="hist")
reg.fit(onehot, y, eval_set=[(onehot, y)])
from_enc = reg.evals_result()["validation_0"]["rmse"]
predt_enc = reg.predict(onehot)
np.testing.assert_allclose(from_cat, from_enc)
np.testing.assert_allclose(predt_cat, predt_enc)
def test_prediction_config():
reg = xgb.XGBRegressor()
assert reg._can_use_inplace_predict() is True