Update parameter for categorical feature. (#8285)

This commit is contained in:
Jiaming Yuan 2022-10-10 19:48:29 +08:00 committed by GitHub
parent 5545c49cfc
commit c68684ff4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 18 additions and 4 deletions

View File

@ -235,7 +235,7 @@ These parameters are only used for training with categorical data. See
* ``max_cat_to_onehot`` * ``max_cat_to_onehot``
.. versionadded:: 1.6 .. versionadded:: 1.6.0
.. note:: This parameter is experimental. ``exact`` tree method is not yet supported. .. note:: This parameter is experimental. ``exact`` tree method is not yet supported.

View File

@ -84,7 +84,7 @@ values are categories, and the measure is the output leaf value. Intuitively, w
group the categories that output similar leaf values. During split finding, we first sort group the categories that output similar leaf values. During split finding, we first sort
the gradient histogram to prepare the contiguous partitions then enumerate the splits the gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be ``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be
used for each feature, see :ref:`cat-param` for details. used for each feature, see :ref:`cat-param` for details.

View File

@ -249,8 +249,20 @@ __model_doc = f"""
A threshold for deciding whether XGBoost should use one-hot encoding based split A threshold for deciding whether XGBoost should use one-hot encoding based split
for categorical data. When number of categories is lesser than the threshold for categorical data. When number of categories is lesser than the threshold
then one-hot encoding is chosen, otherwise the categories will be partitioned then one-hot encoding is chosen, otherwise the categories will be partitioned
into children nodes. Only relevant for regression and binary classification. into children nodes. Also, `enable_categorical` needs to be set to have
See :doc:`Categorical Data </tutorials/categorical>` for details. categorical feature support. See :doc:`Categorical Data
</tutorials/categorical>` and :ref:`cat-param` for details.
max_cat_threshold : Optional[int]
.. versionadded:: 1.7.0
.. note:: This parameter is experimental
Maximum number of categories considered for each split. Used only by
partition-based splits for preventing over-fitting. Also, `enable_categorical`
needs to be set to have categorical feature support. See :doc:`Categorical Data
</tutorials/categorical>` and :ref:`cat-param` for details.
eval_metric : Optional[Union[str, List[str], Callable]] eval_metric : Optional[Union[str, List[str], Callable]]
@ -562,6 +574,7 @@ class XGBModel(XGBModelBase):
enable_categorical: bool = False, enable_categorical: bool = False,
feature_types: FeatureTypes = None, feature_types: FeatureTypes = None,
max_cat_to_onehot: Optional[int] = None, max_cat_to_onehot: Optional[int] = None,
max_cat_threshold: Optional[int] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None, eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None, early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None, callbacks: Optional[List[TrainingCallback]] = None,
@ -607,6 +620,7 @@ class XGBModel(XGBModelBase):
self.enable_categorical = enable_categorical self.enable_categorical = enable_categorical
self.feature_types = feature_types self.feature_types = feature_types
self.max_cat_to_onehot = max_cat_to_onehot self.max_cat_to_onehot = max_cat_to_onehot
self.max_cat_threshold = max_cat_threshold
self.eval_metric = eval_metric self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks self.callbacks = callbacks