Expose feature_types to sklearn interface. (#7821)
parent 401d451569
commit c70fa502a5
@@ -1,7 +1,7 @@
 """Shared typing definition."""
 import ctypes
 import os
-from typing import Optional, List, Any, TypeVar, Union
+from typing import Optional, Any, TypeVar, Union, Sequence
 
 # os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame/
 # cudf.DataFrame/cupy.array/dlpack
@@ -9,7 +9,8 @@ DataType = Any
 
 # xgboost accepts some other possible types in practice for historical reasons, which
 # are less tested. For now we encourage users to pass a simple list of strings.
-FeatureNames = Optional[List[str]]
+FeatureNames = Optional[Sequence[str]]
+FeatureTypes = Optional[Sequence[str]]
 
 ArrayLike = Any
 PathLike = Union[str, os.PathLike]
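For orientation, a minimal sketch (not part of the diff; names are illustrative) of how these aliases are meant to be used — plain sequences of strings, with "q" marking a numerical feature and "c" a categorical one:

```python
from typing import Optional, Sequence

# Mirrors the aliases defined above.
FeatureNames = Optional[Sequence[str]]
FeatureTypes = Optional[Sequence[str]]

# "q" = numerical (quantitative) feature, "c" = categorical feature.
names: FeatureNames = ["age", "city"]
types: FeatureTypes = ["q", "c"]
```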
@@ -31,6 +31,7 @@ from ._typing import (
     CFloatPtr,
     NumpyOrCupy,
     FeatureNames,
+    FeatureTypes,
     _T,
     CupyT,
 )
@@ -553,7 +554,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
         missing: Optional[float] = None,
         silent: bool = False,
         feature_names: FeatureNames = None,
-        feature_types: Optional[List[str]] = None,
+        feature_types: FeatureTypes = None,
         nthread: Optional[int] = None,
         group: Optional[ArrayLike] = None,
         qid: Optional[ArrayLike] = None,
@@ -594,10 +595,15 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
             Whether to print messages during construction
         feature_names : list, optional
             Set names for features.
-        feature_types :
+        feature_types : FeatureTypes
 
             Set types for features. When `enable_categorical` is set to `True`, string
-            "c" represents categorical data type.
+            "c" represents categorical data type while "q" represents numerical feature
+            type. For categorical features, the input is assumed to be preprocessed and
+            encoded by the users. The encoding can be done via
+            :py:class:`sklearn.preprocessing.OrdinalEncoder` or pandas dataframe
+            `.cat.codes` method. This is useful when users want to specify categorical
+            features without having to construct a dataframe as input.
 
         nthread : integer, optional
             Number of threads to use for loading data when parallelization is
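The docstring above describes the intended workflow. A hedged sketch (synthetic data, illustrative names) of pre-encoding a categorical column and passing `feature_types` directly, with no dataframe handed to the DMatrix itself:

```python
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import OrdinalEncoder

rng = np.random.default_rng(0)
num = rng.normal(size=(8, 1))                # a numerical column
cat = np.array([["a"], ["b"], ["a"], ["c"], ["b"], ["a"], ["c"], ["b"]])
codes = OrdinalEncoder().fit_transform(cat)  # categorical -> integer codes

X = np.hstack([num, codes])
y = rng.normal(size=8)

# "q" = numerical, "c" = categorical; the codes above stand in for the
# preprocessing the docstring asks the user to perform.
dm = xgb.DMatrix(X, label=y, feature_types=["q", "c"], enable_categorical=True)
```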
@@ -1062,12 +1068,7 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
 
     @property
     def feature_types(self) -> Optional[List[str]]:
-        """Get feature types (column types).
-
-        Returns
-        -------
-        feature_types : list or None
-        """
+        """Get feature types. See :py:class:`DMatrix` for details."""
         length = c_bst_ulong()
         sarr = ctypes.POINTER(ctypes.c_char_p)()
         _check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
@@ -1083,8 +1084,8 @@ class DMatrix:  # pylint: disable=too-many-instance-attributes
     def feature_types(self, feature_types: Optional[Union[List[str], str]]) -> None:
         """Set feature types (column types).
 
-        This is for displaying the results and categorical data support. See doc string
-        of :py:obj:`xgboost.DMatrix` for details.
+        This is for displaying the results and categorical data support. See
+        :py:class:`DMatrix` for details.
 
         Parameters
         ----------
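For clarity, a small sketch (illustrative, not from the diff) of the round trip this getter/setter pair implements:

```python
import numpy as np
import xgboost as xgb

dm = xgb.DMatrix(np.random.rand(4, 2))
dm.feature_types = ["q", "q"]           # setter stores the strings on the handle
assert dm.feature_types == ["q", "q"]   # getter reads them back as a list of str
dm.feature_types = None                 # assigning None clears the field
```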
@@ -1647,7 +1648,7 @@ class Booster:
         feature_info = from_cstr_to_pystr(sarr, length)
         return feature_info if feature_info else None
 
-    def _set_feature_info(self, features: Optional[List[str]], field: str) -> None:
+    def _set_feature_info(self, features: Optional[Sequence[str]], field: str) -> None:
         if features is not None:
             assert isinstance(features, list)
             feature_info_bytes = [bytes(f, encoding="utf-8") for f in features]
@@ -1667,7 +1668,7 @@ class Booster:
     @property
     def feature_types(self) -> Optional[List[str]]:
         """Feature types for this booster. Can be directly set by input data or by
-        assignment.
+        assignment. See :py:class:`DMatrix` for details.
 
         """
         return self._get_feature_info("feature_type")
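A short illustration (synthetic data) of the "set by input data" path this docstring mentions — after training, the booster carries the types of its training DMatrix:

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(16, 2), np.random.rand(16)
dm = xgb.DMatrix(X, label=y, feature_types=["q", "q"])
booster = xgb.train({"tree_method": "hist"}, dm, num_boost_round=2)

assert booster.feature_types == ["q", "q"]  # inherited from the DMatrix
booster.feature_types = None                # or set directly by assignment
```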
@@ -54,10 +54,11 @@ from .compat import scipy_sparse
 from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
 from .compat import lazy_isinstance
 
+from ._typing import FeatureNames, FeatureTypes
+
 from .core import DMatrix, DeviceQuantileDMatrix, Booster, _expect, DataIter
 from .core import Objective, Metric
 from .core import _deprecate_positional_args, _has_categorical
-from .data import FeatureNames
 from .training import train as worker_train
 from .tracker import RabitTracker, get_host_ip
 from .sklearn import XGBModel, XGBClassifier, XGBRegressorBase, XGBClassifierBase
@@ -327,7 +328,7 @@ class DaskDMatrix:
         missing: float = None,
         silent: bool = False,  # pylint: disable=unused-argument
         feature_names: FeatureNames = None,
-        feature_types: Optional[List[str]] = None,
+        feature_types: FeatureTypes = None,
         group: Optional[_DaskCollection] = None,
         qid: Optional[_DaskCollection] = None,
         label_lower_bound: Optional[_DaskCollection] = None,
@@ -1601,7 +1602,11 @@ class DaskScikitLearnBase(XGBModel):
             predts = predts.to_dask_array()
         else:
             test_dmatrix = await DaskDMatrix(
-                self.client, data=data, base_margin=base_margin, missing=self.missing
+                self.client,
+                data=data,
+                base_margin=base_margin,
+                missing=self.missing,
+                feature_types=self.feature_types
             )
             predts = await predict(
                 self.client,
@@ -1640,7 +1645,9 @@ class DaskScikitLearnBase(XGBModel):
         iteration_range: Optional[Tuple[int, int]] = None,
     ) -> Any:
         iteration_range = self._get_iteration_range(iteration_range)
-        test_dmatrix = await DaskDMatrix(self.client, data=X, missing=self.missing)
+        test_dmatrix = await DaskDMatrix(
+            self.client, data=X, missing=self.missing, feature_types=self.feature_types,
+        )
         predts = await predict(
             self.client,
             model=self.get_booster(),
@@ -1755,6 +1762,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
             eval_qid=None,
             missing=self.missing,
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types,
         )
 
         if callable(self.objective):
@@ -1849,6 +1857,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
             eval_qid=None,
             missing=self.missing,
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types,
         )
 
         # pylint: disable=attribute-defined-outside-init
@@ -2054,6 +2063,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
             eval_qid=eval_qid,
             missing=self.missing,
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types,
         )
         if eval_metric is not None:
             if callable(eval_metric):
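Taken together, the dask changes above make the estimators forward `feature_types` into every `DaskDMatrix` they create during `fit` and `predict`. A usage sketch (assumes a local dask cluster; data is synthetic):

```python
from dask import array as da
from distributed import Client, LocalCluster
import xgboost as xgb

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    X = da.random.random((100, 2), chunks=(50, 2))
    y = da.random.random(100, chunks=50)

    reg = xgb.dask.DaskXGBRegressor(tree_method="hist", feature_types=["q", "q"])
    reg.fit(X, y)
    # predict() now builds its internal DaskDMatrix with the same feature_types.
    reg.predict(X).compute()
```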
@@ -13,6 +13,7 @@ import numpy as np
 from .core import c_array, _LIB, _check_call, c_str
 from .core import _cuda_array_interface
 from .core import DataIter, _ProxyDMatrix, DMatrix, FeatureNames
+from ._typing import FeatureTypes
 from .compat import lazy_isinstance, DataFrame
 
 c_bst_ulong = ctypes.c_uint64  # pylint: disable=invalid-name
@@ -70,7 +71,7 @@ def _from_scipy_csr(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     """Initialize data from a CSR matrix."""
     if len(data.indices) != len(data.data):
@@ -109,7 +110,7 @@ def _from_scipy_csc(
     data,
     missing,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     if len(data.indices) != len(data.data):
         raise ValueError(f"length mismatch: {len(data.indices)} vs {len(data.data)}")
@@ -165,7 +166,7 @@ def _from_numpy_array(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     """Initialize data from a 2-D numpy matrix.
 
@@ -228,6 +229,12 @@ _pandas_dtype_mapper = {
 }
 
 
+_ENABLE_CAT_ERR = (
+    "When categorical type is supplied, DMatrix parameter `enable_categorical` must "
+    "be set to `True`."
+)
+
+
 def _invalid_dataframe_dtype(data: Any) -> None:
     # pandas series has `dtypes` but it's just a single object
     # cudf series doesn't have `dtypes`.
@@ -241,9 +248,8 @@ def _invalid_dataframe_dtype(data: Any) -> None:
     else:
         err = ""
 
-    msg = """DataFrame.dtypes for data must be int, float, bool or category. When
-    categorical type is supplied, DMatrix parameter `enable_categorical` must
-    be set to `True`.""" + err
+    type_err = "DataFrame.dtypes for data must be int, float, bool or category."
+    msg = f"""{type_err} {_ENABLE_CAT_ERR} {err}"""
     raise ValueError(msg)
 
 
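For reference, the kind of input that triggers this message — a categorical dataframe column passed without `enable_categorical` (illustrative sketch):

```python
import pandas as pd
import xgboost as xgb

df = pd.DataFrame({"c": pd.Categorical(["a", "b", "a"])})
try:
    xgb.DMatrix(df)  # enable_categorical defaults to False
except ValueError as e:
    assert "enable_categorical" in str(e)
```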
@@ -340,8 +346,8 @@ def _from_pandas_df(
     missing: float,
     nthread: int,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
-) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
+    feature_types: FeatureTypes,
+) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
     data, feature_names, feature_types = _transform_pandas_df(
         data, enable_categorical, feature_names, feature_types
     )
@@ -382,7 +388,7 @@ def _from_pandas_series(
     nthread: int,
     enable_categorical: bool,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     from pandas.api.types import is_categorical_dtype
 
@@ -413,7 +419,7 @@ _dt_type_mapper2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
 def _transform_dt_df(
     data,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     meta=None,
     meta_type=None,
 ):
@@ -454,9 +460,9 @@ def _from_dt_df(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     enable_categorical: bool,
-) -> Tuple[ctypes.c_void_p, FeatureNames, Optional[List[str]]]:
+) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
     if enable_categorical:
         raise ValueError("categorical data in datatable is not supported yet.")
     data, feature_names, feature_types = _transform_dt_df(
@@ -542,10 +548,10 @@ def _from_arrow(
     data,
     missing: float,
     nthread: int,
-    feature_names: Optional[List[str]],
-    feature_types: Optional[List[str]],
+    feature_names: FeatureNames,
+    feature_types: FeatureTypes,
     enable_categorical: bool,
-) -> Tuple[ctypes.c_void_p, Optional[List[str]], Optional[List[str]]]:
+) -> Tuple[ctypes.c_void_p, FeatureNames, FeatureTypes]:
     import pyarrow as pa
 
     if not all(
@@ -621,7 +627,7 @@ def _cudf_array_interfaces(data, cat_codes: list) -> bytes:
 def _transform_cudf_df(
     data,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     enable_categorical: bool,
 ):
     try:
@@ -687,7 +693,7 @@ def _from_cudf_df(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     enable_categorical: bool,
 ) -> Tuple[ctypes.c_void_p, Any, Any]:
     data, cat_codes, feature_names, feature_types = _transform_cudf_df(
@@ -735,7 +741,7 @@ def _from_cupy_array(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     """Initialize DMatrix from cupy ndarray."""
     data = _transform_cupy_array(data)
@@ -782,7 +788,7 @@ def _from_dlpack(
     missing,
     nthread,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     data = _transform_dlpack(data)
     return _from_cupy_array(data, missing, nthread, feature_names,
@@ -797,7 +803,7 @@ def _from_uri(
     data,
     missing,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     _warn_unused_missing(data, missing)
     handle = ctypes.c_void_p()
@@ -817,7 +823,7 @@ def _from_list(
     missing,
     n_threads,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     array = np.array(data)
     _check_data_shape(data)
@@ -833,7 +839,7 @@ def _from_tuple(
     missing,
     n_threads,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
 ):
     return _from_list(data, missing, n_threads, feature_names, feature_types)
 
@@ -869,7 +875,7 @@ def dispatch_data_backend(
     missing,
     threads,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     enable_categorical: bool = False,
 ):
     '''Dispatch data for DMatrix.'''
@@ -884,8 +890,7 @@ def dispatch_data_backend(
             data.tocsr(), missing, threads, feature_names, feature_types
         )
     if _is_numpy_array(data):
-        return _from_numpy_array(data, missing, threads, feature_names,
-                                 feature_types)
+        return _from_numpy_array(data, missing, threads, feature_names, feature_types)
     if _is_uri(data):
         return _from_uri(data, missing, feature_names, feature_types)
     if _is_list(data):
@@ -1101,7 +1106,7 @@ class SingleBatchInternalIter(DataIter):  # pylint: disable=R0902
 def _proxy_transform(
     data,
     feature_names: FeatureNames,
-    feature_types: Optional[List[str]],
+    feature_types: FeatureTypes,
     enable_categorical: bool,
 ):
     if _is_cudf_df(data) or _is_cudf_ser(data):
@@ -14,7 +14,7 @@ from .core import Metric
 from .training import train
 from .callback import TrainingCallback
 from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
-from ._typing import ArrayLike
+from ._typing import ArrayLike, FeatureTypes
 
 # Do not use class names on scikit-learn directly. Re-define the classes on
 # .compat to guarantee the behavior without scikit-learn
@@ -211,6 +211,13 @@ __model_doc = f'''
         should be used to specify categorical data type. Also, JSON/UBJSON
         serialization format is required.
 
+    feature_types : FeatureTypes
+
+        .. versionadded:: 2.0.0
+
+        Used for specifying feature types without constructing a dataframe. See
+        :py:class:`DMatrix` for details.
+
     max_cat_to_onehot : Optional[int]
 
         .. versionadded:: 1.6.0
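A usage sketch of the new estimator parameter documented above (synthetic data; mirrors the test added at the end of this diff):

```python
import numpy as np
import xgboost as xgb

# One numerical column and one integer-coded categorical column, no pandas needed.
X = np.column_stack([np.random.rand(32), np.random.randint(0, 3, size=32)])
y = np.random.rand(32)

reg = xgb.XGBRegressor(
    tree_method="hist",
    feature_types=["q", "c"],
    enable_categorical=True,
)
reg.fit(X, y)
assert reg.get_booster().feature_types == ["q", "c"]
```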
@@ -394,6 +401,7 @@ def _wrap_evaluation_matrices(
     eval_qid: Optional[Sequence[Any]],
     create_dmatrix: Callable,
     enable_categorical: bool,
+    feature_types: FeatureTypes,
 ) -> Tuple[Any, List[Tuple[Any, str]]]:
     """Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
 
@@ -408,6 +416,7 @@ def _wrap_evaluation_matrices(
         feature_weights=feature_weights,
         missing=missing,
         enable_categorical=enable_categorical,
+        feature_types=feature_types,
     )
 
     n_validation = 0 if eval_set is None else len(eval_set)
@@ -455,6 +464,7 @@ def _wrap_evaluation_matrices(
                 base_margin=base_margin_eval_set[i],
                 missing=missing,
                 enable_categorical=enable_categorical,
+                feature_types=feature_types,
             )
             evals.append(m)
     nevals = len(evals)
@@ -518,6 +528,7 @@ class XGBModel(XGBModelBase):
         validate_parameters: Optional[bool] = None,
         predictor: Optional[str] = None,
         enable_categorical: bool = False,
+        feature_types: FeatureTypes = None,
         max_cat_to_onehot: Optional[int] = None,
         eval_metric: Optional[Union[str, List[str], Callable]] = None,
         early_stopping_rounds: Optional[int] = None,
@@ -562,6 +573,7 @@ class XGBModel(XGBModelBase):
         self.validate_parameters = validate_parameters
         self.predictor = predictor
         self.enable_categorical = enable_categorical
+        self.feature_types = feature_types
         self.max_cat_to_onehot = max_cat_to_onehot
         self.eval_metric = eval_metric
         self.early_stopping_rounds = early_stopping_rounds
@@ -684,6 +696,7 @@ class XGBModel(XGBModelBase):
             "enable_categorical",
             "early_stopping_rounds",
             "callbacks",
+            "feature_types",
         }
         filtered = {}
         for k, v in params.items():
@@ -715,6 +728,10 @@ class XGBModel(XGBModelBase):
                 # numpy array is not JSON serializable
                 meta['classes_'] = self.classes_.tolist()
                 continue
+            if k == "feature_types":
+                # Use the `feature_types` attribute from booster instead.
+                meta["feature_types"] = None
+                continue
             try:
                 json.dumps({k: v})
                 meta[k] = v
@@ -754,6 +771,9 @@ class XGBModel(XGBModelBase):
             if k == 'classes_':
                 self.classes_ = np.array(v)
                 continue
+            if k == "feature_types":
+                self.feature_types = self.get_booster().feature_types
+                continue
             if k == "_estimator_type":
                 if self._get_type() != v:
                     raise TypeError(
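These two branches mean `feature_types` round-trips through the booster rather than the sklearn meta: saving writes `None` into the meta, and loading re-reads the types from the booster, which stores them natively. An illustrative sketch (file name and data are made up):

```python
import numpy as np
import xgboost as xgb

X, y = np.random.rand(32, 2), np.random.rand(32)
reg = xgb.XGBRegressor(n_estimators=2, feature_types=["q", "q"])
reg.fit(X, y)
reg.save_model("model.json")

loaded = xgb.XGBRegressor()
loaded.load_model("model.json")
assert loaded.feature_types == ["q", "q"]  # restored from the booster
```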
@@ -944,6 +964,7 @@ class XGBModel(XGBModelBase):
             eval_qid=None,
             create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types
         )
         params = self.get_xgb_params()
 
@@ -1063,9 +1084,11 @@ class XGBModel(XGBModelBase):
                 pass
 
         test = DMatrix(
-            X, base_margin=base_margin,
+            X,
+            base_margin=base_margin,
             missing=self.missing,
             nthread=self.n_jobs,
+            feature_types=self.feature_types,
             enable_categorical=self.enable_categorical
         )
         return self.get_booster().predict(
@@ -1106,7 +1129,9 @@ class XGBModel(XGBModelBase):
             self.get_booster(), ntree_limit, iteration_range
         )
         iteration_range = self._get_iteration_range(iteration_range)
-        test_dmatrix = DMatrix(X, missing=self.missing, nthread=self.n_jobs)
+        test_dmatrix = DMatrix(
+            X, missing=self.missing, feature_types=self.feature_types, nthread=self.n_jobs
+        )
         return self.get_booster().predict(
             test_dmatrix,
             pred_leaf=True,
@@ -1397,6 +1422,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             eval_qid=None,
             create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types,
         )
 
         self._Booster = train(
@@ -1828,6 +1854,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             eval_qid=eval_qid,
             create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
             enable_categorical=self.enable_categorical,
+            feature_types=self.feature_types,
         )
 
         evals_result: TrainingCallback.EvalsLog = {}
@@ -306,6 +306,13 @@ def test_categorical(client: "Client") -> None:
     run_categorical(client, "approx", X, X_onehot, y)
     run_categorical(client, "hist", X, X_onehot, y)
 
+    ft = ["c"] * X.shape[1]
+    reg = xgb.dask.DaskXGBRegressor(
+        tree_method="hist", feature_types=ft, enable_categorical=True
+    )
+    reg.fit(X, y)
+    assert reg.get_booster().feature_types == ft
+
 
 def test_dask_predict_shape_infer(client: "Client") -> None:
     X, y = make_classification(n_samples=1000, n_informative=5, n_classes=3)
@@ -1273,6 +1273,38 @@ def test_estimator_reg(estimator, check):
     check(estimator)
 
 
+def test_categorical():
+    X, y = tm.make_categorical(n_samples=32, n_features=2, n_categories=3, onehot=False)
+    ft = ["c"] * X.shape[1]
+    reg = xgb.XGBRegressor(
+        tree_method="hist",
+        feature_types=ft,
+        max_cat_to_onehot=1,
+        enable_categorical=True,
+    )
+    reg.fit(X.values, y, eval_set=[(X.values, y)])
+    from_cat = reg.evals_result()["validation_0"]["rmse"]
+    predt_cat = reg.predict(X.values)
+    assert reg.get_booster().feature_types == ft
+    with tempfile.TemporaryDirectory() as tmpdir:
+        path = os.path.join(tmpdir, "model.json")
+        reg.save_model(path)
+        reg = xgb.XGBRegressor()
+        reg.load_model(path)
+        assert reg.feature_types == ft
+
+    onehot, y = tm.make_categorical(
+        n_samples=32, n_features=2, n_categories=3, onehot=True
+    )
+    reg = xgb.XGBRegressor(tree_method="hist")
+    reg.fit(onehot, y, eval_set=[(onehot, y)])
+    from_enc = reg.evals_result()["validation_0"]["rmse"]
+    predt_enc = reg.predict(onehot)
+
+    np.testing.assert_allclose(from_cat, from_enc)
+    np.testing.assert_allclose(predt_cat, predt_enc)
+
+
 def test_prediction_config():
     reg = xgb.XGBRegressor()
     assert reg._can_use_inplace_predict() is True