[doc] Mention data consistency for categorical features. (#9678)

This commit is contained in:
Jiaming Yuan 2023-10-24 10:11:33 +08:00 committed by GitHub
parent 5e6cb63a56
commit 3ca06ac51e
8 changed files with 293 additions and 96 deletions

View File

@ -11,10 +11,13 @@ https://www.kaggle.com/shahules/an-overview-of-encoding-techniques
And the data can be found at:
https://www.kaggle.com/shahules/an-overview-of-encoding-techniques/data
Also, see the tutorial for using XGBoost with categorical data:
:doc:`/tutorials/categorical`.
.. versionadded:: 1.6.0
See Also
--------
- :doc:`Tutorial </tutorials/categorical>`
- :ref:`sphx_glr_python_examples_categorical.py`
- :ref:`sphx_glr_python_examples_cat_pipeline.py`
"""

View File

@ -0,0 +1,145 @@
"""
Feature engineering pipeline for categorical data
=================================================
The script showcases how to keep the categorical data encoding consistent across
training and inference. There are many ways to attain the same goal; this script can be
used as a starting point.
See Also
--------
- :doc:`Tutorial </tutorials/categorical>`
- :ref:`sphx_glr_python_examples_categorical.py`
- :ref:`sphx_glr_python_examples_cat_in_the_dat.py`
"""
from typing import List, Tuple
import numpy as np
import pandas as pd
from sklearn.compose import make_column_selector, make_column_transformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder
import xgboost as xgb
def make_example_data() -> Tuple[pd.DataFrame, pd.Series, List[str]]:
    """Generate data for demo."""
    n_samples = 2048
    rng = np.random.default_rng(1994)

    # We have three categorical features, while the rest are numerical.
    categorical_features = ["brand_id", "retailer_id", "category_id"]
    df = pd.DataFrame(
        # Use the seeded generator so that the demo is reproducible.
        rng.integers(32, 96, size=(n_samples, 3)),
        columns=categorical_features,
    )
    df["price"] = rng.integers(100, 200, size=(n_samples,))
    df["stock_status"] = rng.choice([True, False], n_samples)
    df["on_sale"] = rng.choice([True, False], n_samples)
    df["label"] = rng.normal(loc=0.0, scale=1.0, size=n_samples)

    X = df.drop(["label"], axis=1)
    y = df["label"]
    return X, y, categorical_features
def native() -> None:
    """Using the native XGBoost interface."""
    X, y, cat_feats = make_example_data()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=1994, test_size=0.2
    )

    # Create an encoder based on the training data.
    enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
    enc.set_output(transform="pandas")
    enc = enc.fit(X_train[cat_feats])

    def enc_transform(X: pd.DataFrame) -> pd.DataFrame:
        # Don't modify the input in place, so that we can demonstrate the encoding
        # consistency below.
        X = X.copy()
        cat_cols = enc.transform(X[cat_feats])
        for i, name in enumerate(cat_feats):
            # Create a pd.Series based on the fitted encoder.
            cat_cols[name] = pd.Categorical.from_codes(
                codes=cat_cols[name].astype(np.int32), categories=enc.categories_[i]
            )
        X[cat_feats] = cat_cols
        return X

    # Encode the data using the fitted encoder.
    X_train_enc = enc_transform(X_train)
    X_test_enc = enc_transform(X_test)
    # Train an XGBoost model using the native interface.
    Xy_train = xgb.QuantileDMatrix(X_train_enc, y_train, enable_categorical=True)
    Xy_test = xgb.QuantileDMatrix(
        X_test_enc, y_test, enable_categorical=True, ref=Xy_train
    )
    booster = xgb.train({}, Xy_train)
    booster.predict(Xy_test)

    # The following shows that the data is encoded consistently.
    # First, obtain the result from newly encoded data.
    predt0 = booster.inplace_predict(enc_transform(X_train.head(16)))
    # Then, obtain the result from data that was already encoded during training.
    predt1 = booster.inplace_predict(X_train_enc.head(16))
    np.testing.assert_allclose(predt0, predt1)
def pipeline() -> None:
    """Using the sklearn pipeline."""
    X, y, cat_feats = make_example_data()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=3, test_size=0.2
    )

    enc = make_column_transformer(
        (
            OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan),
            # All categorical feature names end with "_id".
            make_column_selector(pattern=".*_id"),
        ),
        remainder="passthrough",
        verbose_feature_names_out=False,
    )
    # No need to set pandas output; we use `feature_types` to indicate the type of
    # each feature instead.
    # enc.set_output(transform="pandas")

    feature_types = ["c" if fn in cat_feats else "q" for fn in X_train.columns]
    reg = xgb.XGBRegressor(
        feature_types=feature_types, enable_categorical=True, n_estimators=10
    )

    p = make_pipeline(enc, reg)
    p.fit(X_train, y_train)

    # Check that XGBoost is using the feature types correctly.
    model_types = reg.get_booster().feature_types
    assert model_types is not None
    for a, b in zip(model_types, feature_types):
        assert a == b

    # The following shows that the data is encoded consistently.
    # First, predict on a slice of data that doesn't contain all the categories.
    predt0 = p.predict(X_train.iloc[:16, :])
    # Then, use the dataframe that contains all the categories.
    predt1 = p.predict(X_train)[:16]
    # The resulting encoding, and hence the predictions, are the same.
    np.testing.assert_allclose(predt0, predt1)
if __name__ == "__main__":
    pipeline()
    native()

View File

@ -8,10 +8,13 @@ Previously, users needed to run an encoder themselves before passing the data into
XGBoost, which creates a sparse matrix and potentially increases memory usage. This demo
showcases the experimental categorical data support; more advanced features are planned.
Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with
categorical data.
.. versionadded:: 1.5.0
See Also
--------
- :doc:`Tutorial </tutorials/categorical>`
- :ref:`sphx_glr_python_examples_cat_in_the_dat.py`
- :ref:`sphx_glr_python_examples_cat_pipeline.py`
"""
from typing import Tuple
@ -52,11 +55,13 @@ def make_categorical(
def main() -> None:
    # Use the builtin categorical data support.
    # For the scikit-learn interface, the input data should be a pandas DataFrame or
    # cudf DataFrame with categorical features. If a numpy/cupy array is used instead,
    # the `feature_types` for `XGBRegressor` should be set accordingly.
    X, y = make_categorical(100, 10, 4, False)
    # Specify `enable_categorical` as True; also, we use onehot-encoding-based splits
    # here for demonstration. For details, see the document of `max_cat_to_onehot`.
    reg = xgb.XGBRegressor(
        tree_method="hist", enable_categorical=True, max_cat_to_onehot=5, device="cuda"
    )

View File

@ -137,7 +137,7 @@ To build and run C++ unit tests enable tests while running CMake:
./testxgboost
Flags like ``USE_CUDA``, ``USE_DMLC_GTEST`` are optional. For more info about how to build
XGBoost from source, see :doc:`/build`. One can also run all unit tests using the ctest
tool, which provides higher flexibility. For example:
.. code-block:: bash
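
    # A sketch of the elided example: run ctest from the CMake build directory.
    cd build
    # Run every test; use -R <regex> to select tests by name.
    ctest --verbose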

View File

@ -94,11 +94,11 @@ Using native interface
**********************
The ``scikit-learn`` interface is user friendly, but lacks some features that are only
available in the native interface. For instance, users cannot compute SHAP values
directly. Also, the native interface supports more data types. To use the native
interface with categorical data, we need to pass similar parameters to
:class:`~xgboost.DMatrix` or :py:class:`~xgboost.QuantileDMatrix` and the
:func:`train <xgboost.train>` function. For dataframe input:
.. code:: python
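
    # A sketch of the elided example, assuming X is a dataframe with categorical
    # columns (dtype "category") and y is the target.
    Xy = xgb.DMatrix(X, y, enable_categorical=True)
    booster = xgb.train({"tree_method": "hist"}, Xy)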
@ -117,7 +117,6 @@ SHAP value computation:
# categorical features are listed as "c"
print(booster.feature_types)
For other types of input, like a ``numpy array``, we can tell XGBoost about the feature
types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatrix>`:
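
.. code:: python

    # A sketch of the elided example, assuming X is a numpy array with three
    # features, where the first two are categorical ("c") and the last is
    # numerical ("q").
    ft = ["c", "c", "q"]
    Xy = xgb.DMatrix(X, y, feature_types=ft, enable_categorical=True)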
@ -131,7 +130,31 @@ types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatr
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for
categorical features it's specified as ``"c"``. The Dask module in XGBoost has the same
interface, so :class:`dask.Array <dask.Array>` can also be used for categorical data.
Lastly, the sklearn interface :py:class:`~xgboost.XGBRegressor` has the same parameter;
a sketch of the Dask variant follows.
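
As a minimal sketch of the Dask variant (assuming a running Dask ``client`` and dask
arrays ``X`` and ``y``, with the first two of three features categorical):

.. code-block:: python

    import xgboost as xgb

    Xy = xgb.dask.DaskDMatrix(
        client, X, y, feature_types=["c", "c", "q"], enable_categorical=True
    )
    output = xgb.dask.train(client, {"tree_method": "hist"}, Xy)
    booster = output["booster"]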
****************
Data Consistency
****************
XGBoost accepts parameters to indicate which features are considered categorical, either
through the ``dtypes`` of a dataframe or through the ``feature_types`` parameter.
However, XGBoost by itself doesn't store information on how categories are encoded in
the first place. For instance, given an encoding schema that maps music genres to
integer codes:
.. code-block:: python

    {"acoustic": 0, "indie": 1, "blues": 2, "country": 3}
XGBoost doesn't know this mapping from the input and hence cannot store it in the model.
The mapping usually happens in the user's data engineering pipeline, with column
transformers like :py:class:`sklearn.preprocessing.OrdinalEncoder`. To obtain correct
results from XGBoost, users need to keep the pipeline for transforming data consistent
across training and testing data. One should watch out for errors like:
.. code-block:: python

    X_train["genre"] = X_train["genre"].astype("category")
    reg = xgb.XGBRegressor(enable_categorical=True).fit(X_train, y_train)

    # invalid encoding
    X_test["genre"] = X_test["genre"].astype("category")
    reg.predict(X_test)
In the above snippet, the training data and test data are encoded separately, resulting
in two different encoding schemas and invalid prediction results. See
:ref:`sphx_glr_python_examples_cat_pipeline.py` for a worked example using the ordinal
encoder.
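
A minimal sketch of one consistent alternative, assuming ``genre`` is the only feature
and reusing the hypothetical ``X_train``, ``y_train``, and ``X_test`` from above: fit a
single :py:class:`~sklearn.preprocessing.OrdinalEncoder` on the training data and apply
it to all later data.

.. code-block:: python

    import numpy as np
    from sklearn.preprocessing import OrdinalEncoder

    enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
    enc.set_output(transform="pandas")
    # The genre -> code mapping is learned once, from the training data only.
    X_train[["genre"]] = enc.fit_transform(X_train[["genre"]])

    reg = xgb.XGBRegressor(feature_types=["c"], enable_categorical=True)
    reg.fit(X_train, y_train)

    # Reuse the same fitted encoder at inference; genres unseen during training are
    # mapped to NaN and treated as missing values.
    X_test[["genre"]] = enc.transform(X_test[["genre"]])
    reg.predict(X_test)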
*************
Miscellaneous

View File

@ -821,7 +821,7 @@ class DirectoryExcursion:
"""
    def __init__(self, path: Union[os.PathLike, str], cleanup: bool = False):
        self.path = path
        self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
        self.cleanup = cleanup

View File

@ -21,6 +21,7 @@ class LintersPaths:
"tests/python/test_data_iterator.py",
"tests/python/test_dmatrix.py",
"tests/python/test_dt.py",
"tests/python/test_demos.py",
"tests/python/test_predict.py",
"tests/python/test_quantile_dmatrix.py",
"tests/python/test_tree_regularization.py",
@ -41,6 +42,7 @@ class LintersPaths:
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/callbacks.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/sklearn_parallel.py",
"demo/guide-python/spark_estimator_examples.py",
@ -79,6 +81,7 @@ class LintersPaths:
"python-package/",
# tests
"tests/python/test_dt.py",
"tests/python/test_demos.py",
"tests/python/test_data_iterator.py",
"tests/python-gpu/test_gpu_data_iterator.py",
"tests/python-gpu/load_pickle.py",
@ -89,6 +92,8 @@ class LintersPaths:
"demo/json-model/json_parser.py",
"demo/guide-python/external_memory.py",
"demo/guide-python/cat_in_the_dat.py",
"demo/guide-python/categorical.py",
"demo/guide-python/cat_pipeline.py",
"demo/guide-python/feature_weights.py",
"demo/guide-python/individual_trees.py",
"demo/guide-python/quantile_regression.py",

View File

@ -11,129 +11,143 @@ from xgboost import testing as tm
pytestmark = tm.timeout(30)
DEMO_DIR = tm.demo_dir(__file__)
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, "guide-python")
CLI_DEMO_DIR = os.path.join(DEMO_DIR, "CLI")
def test_basic_walkthrough() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "basic_walkthrough.py")
    cmd = ["python", script]
    with tempfile.TemporaryDirectory() as tmpdir:
        subprocess.check_call(cmd, cwd=tmpdir)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "categorical.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)

@pytest.mark.skipif(**tm.no_pandas())
def test_cat_pipeline() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "cat_pipeline.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)

@pytest.mark.skipif(**tm.no_matplotlib())
def test_custom_multiclass_objective() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "custom_softmax.py")
    cmd = ["python", script, "--plot=0"]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_matplotlib())
def test_custom_rmsle_objective() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "custom_rmsle.py")
    cmd = ["python", script, "--plot=0"]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_matplotlib())
def test_feature_weights_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "feature_weights.py")
    cmd = ["python", script, "--plot=0"]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "sklearn_examples.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
    assert os.path.exists("best_calif.pkl")
    os.remove("best_calif.pkl")
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_parallel_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "sklearn_parallel.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
def test_sklearn_evals_result_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "sklearn_evals_result.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_boost_from_prediction_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "boost_from_prediction.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_predict_first_ntree_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "predict_first_ntree.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_individual_trees() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "individual_trees.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_predict_leaf_indices_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "predict_leaf_indices.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_generalized_linear_model_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "generalized_linear_model.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_cross_validation_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "cross_validation.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_external_memory_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "external_memory.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
def test_evals_result_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "evals_result.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.skipif(**tm.no_pandas())
def test_aft_demo() -> None:
    script = os.path.join(DEMO_DIR, "aft_survival", "aft_survival_demo.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
    assert os.path.exists("aft_model.json")
    os.remove("aft_model.json")
@pytest.mark.skipif(**tm.no_matplotlib())
def test_callbacks_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "callbacks.py")
    cmd = ["python", script, "--plot=0"]
    subprocess.check_call(cmd)
def test_continuation_demo() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "continuation.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
@ -141,14 +155,14 @@ def test_continuation_demo():
@pytest.mark.skipif(**tm.no_matplotlib())
def test_multioutput_reg() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "multioutput_regression.py")
    cmd = ["python", script, "--plot=0"]
    subprocess.check_call(cmd)
@pytest.mark.skipif(**tm.no_sklearn())
def test_quantile_reg() -> None:
    script = os.path.join(PYTHON_DEMO_DIR, "quantile_regression.py")
    cmd = ["python", script]
    subprocess.check_call(cmd)
@ -197,28 +211,30 @@ def test_json_model() -> None:
# - aft tuning is not tested due to an extra dependency.
def test_cli_regression_demo() -> None:
    reg_dir = os.path.join(CLI_DEMO_DIR, "regression")
    script = os.path.join(reg_dir, "mapfeat.py")
    cmd = ["python", script]
    subprocess.check_call(cmd, cwd=reg_dir)

    script = os.path.join(reg_dir, "mknfold.py")
    cmd = ["python", script, "machine.txt", "1"]
    subprocess.check_call(cmd, cwd=reg_dir)

    exe = os.path.join(DEMO_DIR, os.path.pardir, "xgboost")
    conf = os.path.join(reg_dir, "machine.conf")
    subprocess.check_call([exe, conf], cwd=reg_dir)
@pytest.mark.skipif(
    condition=sys.platform.startswith("win"), reason="Test requires sh execution."
)
def test_cli_binary_classification() -> None:
    cls_dir = os.path.join(CLI_DEMO_DIR, "binary_classification")
    with tm.DirectoryExcursion(cls_dir, cleanup=True):
        subprocess.check_call(["./runexp.sh"])
        os.remove("0002.model")
# year prediction is not tested because the data size is too large.
# rank is not tested as it requires the unrar command.