Add enable_categorical to sklearn. (#7011)
This commit is contained in:
parent
655e6992f6
commit
c4b9f4f622
@ -1642,6 +1642,7 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
missing=self.missing,
|
||||
enable_categorical=self.enable_categorical,
|
||||
)
|
||||
|
||||
if callable(self.objective):
|
||||
@ -1730,6 +1731,7 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
missing=self.missing,
|
||||
enable_categorical=self.enable_categorical,
|
||||
)
|
||||
|
||||
# pylint: disable=attribute-defined-outside-init
|
||||
@ -1927,6 +1929,7 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
eval_group=None,
|
||||
eval_qid=eval_qid,
|
||||
missing=self.missing,
|
||||
enable_categorical=self.enable_categorical,
|
||||
)
|
||||
if eval_metric is not None:
|
||||
if callable(eval_metric):
|
||||
|
||||
@ -164,6 +164,14 @@ __model_doc = f'''
|
||||
validate_parameters : Optional[bool]
|
||||
Give warnings for unknown parameter.
|
||||
|
||||
enable_categorical : bool
|
||||
|
||||
.. versionadded:: 1.5.0
|
||||
|
||||
Experimental support for categorical data. Do not set to true unless you are
|
||||
interested in development. Only valid when `gpu_hist` and pandas dataframe are
|
||||
used.
|
||||
|
||||
kwargs : dict, optional
|
||||
Keyword arguments for XGBoost Booster object. Full documentation of
|
||||
parameters can be found here:
|
||||
@ -257,6 +265,7 @@ def _wrap_evaluation_matrices(
|
||||
eval_group: Optional[List[Any]],
|
||||
eval_qid: Optional[List[Any]],
|
||||
create_dmatrix: Callable,
|
||||
enable_categorical: bool,
|
||||
label_transform: Callable = lambda x: x,
|
||||
) -> Tuple[Any, Optional[List[Tuple[Any, str]]]]:
|
||||
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way.
|
||||
@ -271,6 +280,7 @@ def _wrap_evaluation_matrices(
|
||||
base_margin=base_margin,
|
||||
feature_weights=feature_weights,
|
||||
missing=missing,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
|
||||
n_validation = 0 if eval_set is None else len(eval_set)
|
||||
@ -317,6 +327,7 @@ def _wrap_evaluation_matrices(
|
||||
qid=eval_qid[i],
|
||||
base_margin=base_margin_eval_set[i],
|
||||
missing=missing,
|
||||
enable_categorical=enable_categorical,
|
||||
)
|
||||
evals.append(m)
|
||||
nevals = len(evals)
|
||||
@ -375,6 +386,7 @@ class XGBModel(XGBModelBase):
|
||||
gpu_id: Optional[int] = None,
|
||||
validate_parameters: Optional[bool] = None,
|
||||
predictor: Optional[str] = None,
|
||||
enable_categorical: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
if not SKLEARN_INSTALLED:
|
||||
@ -411,6 +423,7 @@ class XGBModel(XGBModelBase):
|
||||
self.gpu_id = gpu_id
|
||||
self.validate_parameters = validate_parameters
|
||||
self.predictor = predictor
|
||||
self.enable_categorical = enable_categorical
|
||||
|
||||
def _more_tags(self) -> Dict[str, bool]:
|
||||
'''Tags used for scikit-learn data validation.'''
|
||||
@ -514,7 +527,9 @@ class XGBModel(XGBModelBase):
|
||||
params = self.get_params()
|
||||
# Parameters that should not go into native learner.
|
||||
wrapper_specific = {
|
||||
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder'}
|
||||
'importance_type', 'kwargs', 'missing', 'n_estimators', 'use_label_encoder',
|
||||
"enable_categorical"
|
||||
}
|
||||
filtered = dict()
|
||||
for k, v in params.items():
|
||||
if k not in wrapper_specific and not callable(v):
|
||||
@ -735,6 +750,7 @@ class XGBModel(XGBModelBase):
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
enable_categorical=self.enable_categorical,
|
||||
)
|
||||
params = self.get_xgb_params()
|
||||
|
||||
@ -1202,6 +1218,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
enable_categorical=self.enable_categorical,
|
||||
label_transform=label_transform,
|
||||
)
|
||||
|
||||
@ -1628,6 +1645,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
eval_group=eval_group,
|
||||
eval_qid=eval_qid,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
enable_categorical=self.enable_categorical,
|
||||
)
|
||||
|
||||
evals_result: TrainingCallback.EvalsLog = {}
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import json
|
||||
import xgboost as xgb
|
||||
import pytest
|
||||
import tempfile
|
||||
import sys
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm # noqa
|
||||
@ -38,3 +41,36 @@ def test_boost_from_prediction_gpu_hist():
|
||||
|
||||
def test_num_parallel_tree():
|
||||
twskl.run_boston_housing_rf_regression("gpu_hist")
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_categorical():
|
||||
import pandas as pd
|
||||
from sklearn.datasets import load_svmlight_file
|
||||
|
||||
data_dir = os.path.join(tm.PROJECT_ROOT, "demo", "data")
|
||||
X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train"))
|
||||
clf = xgb.XGBClassifier(
|
||||
tree_method="gpu_hist",
|
||||
use_label_encoder=False,
|
||||
enable_categorical=True,
|
||||
predictor="gpu_predictor",
|
||||
n_estimators=10,
|
||||
)
|
||||
X = pd.DataFrame(X.todense()).astype("category")
|
||||
clf.fit(X, y)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
model = os.path.join(tempdir, "categorial.json")
|
||||
clf.save_model(model)
|
||||
|
||||
with open(model) as fd:
|
||||
categorical = json.load(fd)
|
||||
categories_sizes = np.array(
|
||||
categorical["learner"]["gradient_booster"]["model"]["trees"][0][
|
||||
"categories_sizes"
|
||||
]
|
||||
)
|
||||
assert categories_sizes.shape[0] != 0
|
||||
np.testing.assert_allclose(categories_sizes, 1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user