Demo for experimental categorical data support. (#7213)
This commit is contained in:
parent
3515931305
commit
d997c967d5
75
demo/guide-python/categorical.py
Normal file
75
demo/guide-python/categorical.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
"""Experimental support for categorical data. Since version 1.5, the XGBoost `gpu_hist` tree method
|
||||||
|
has experimental support for one-hot encoding based tree split.
|
||||||
|
|
||||||
|
Previously, users needed to run an encoder themselves before passing the data into XGBoost,
|
||||||
|
which creates a sparse matrix and potentially increases memory usage. This demo showcases
|
||||||
|
the experimental categorical data support; more advanced features are planned.
|
||||||
|
|
||||||
|
.. versionadded:: 1.5.0
|
||||||
|
|
||||||
|
"""
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import xgboost as xgb
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def make_categorical(
    n_samples: int, n_features: int, n_categories: int, onehot: bool
) -> Tuple[pd.DataFrame, pd.Series]:
    """Make some random data for demo.

    Parameters
    ----------
    n_samples :
        Number of rows to generate.
    n_features :
        Number of categorical feature columns in the returned frame.
    n_categories :
        Number of distinct categories per feature; values are drawn from
        ``[0, n_categories)``.
    onehot :
        When True, return the features one-hot encoded with
        ``pd.get_dummies`` instead of as ``category`` columns.

    Returns
    -------
    A ``(features, label)`` pair.  The label is the row-wise sum of all
    generated columns (the features plus one extra hidden column) plus one.
    """
    # Fixed seed: every call with the same arguments yields the same data, so
    # the encoded and non-encoded variants describe the same samples.
    rng = np.random.RandomState(1994)

    # One extra column is drawn; column "0" only seeds the label and is not
    # returned as a feature.
    pd_dict = {
        str(i): pd.Series(
            rng.randint(low=0, high=n_categories, size=n_samples), dtype=np.int64
        )
        for i in range(n_features + 1)
    }
    raw = pd.DataFrame(pd_dict)

    # Copy so the in-place additions below never write through to `raw`.
    label = raw.iloc[:, 0].copy()
    df = raw.iloc[:, 1:]
    for name in df.columns:
        label += df[name]
    label += 1

    df = df.astype("category")
    # Give every column the full category set, including values that happen
    # not to be drawn, so all columns share the same categories.
    categories = np.arange(0, n_categories)
    for col in df.columns:
        df[col] = df[col].cat.set_categories(categories)

    if onehot:
        return pd.get_dummies(df), label
    return df, label
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Train on categorical data two ways and check that the results agree.

    One regressor is fitted on ``category`` columns with
    ``enable_categorical=True``, the other on the same data one-hot encoded
    up front.  Their per-iteration training RMSE is asserted to match, then
    the summed prediction contributions are checked against the margin.
    """
    # Use builtin categorical data support
    # Must be pandas DataFrame or cudf DataFrame with categorical data
    X, y = make_categorical(100, 10, 4, False)
    # Specify `enable_categorical` to True.
    reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True)
    reg.fit(X, y, eval_set=[(X, y)])

    # Pass in already encoded data
    X_enc, y_enc = make_categorical(100, 10, 4, True)
    reg_enc = xgb.XGBRegressor(tree_method="gpu_hist")
    reg_enc.fit(X_enc, y_enc, eval_set=[(X_enc, y_enc)])

    reg_results = np.array(reg.evals_result()["validation_0"]["rmse"])
    reg_enc_results = np.array(reg_enc.evals_result()["validation_0"]["rmse"])

    # Check that they have the same results
    np.testing.assert_allclose(reg_results, reg_enc_results)

    # Convert to DMatrix for SHAP value
    booster: xgb.Booster = reg.get_booster()
    m = xgb.DMatrix(X, enable_categorical=True)  # specify categorical data support.
    shap_values = booster.predict(m, pred_contribs=True)
    margin = booster.predict(m, output_margin=True)
    # Summing the contributions over the last axis should reproduce the margin.
    np.testing.assert_allclose(np.sum(shap_values, axis=-1), margin, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -20,6 +20,12 @@ def test_update_process_demo():
|
|||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
def test_categorical_demo():
    """Run the categorical data demo script and fail on a non-zero exit."""
    script = os.path.join(td.PYTHON_DEMO_DIR, 'categorical.py')
    cmd = ['python', script]
    subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_dask())
|
@pytest.mark.skipif(**tm.no_dask())
|
||||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user