[doc] Mention data consistency for categorical features. (#9678)
This commit is contained in:
parent
5e6cb63a56
commit
3ca06ac51e
@ -11,10 +11,13 @@ https://www.kaggle.com/shahules/an-overview-of-encoding-techniques
|
|||||||
And the data can be found at:
|
And the data can be found at:
|
||||||
https://www.kaggle.com/shahules/an-overview-of-encoding-techniques/data
|
https://www.kaggle.com/shahules/an-overview-of-encoding-techniques/data
|
||||||
|
|
||||||
Also, see the tutorial for using XGBoost with categorical data:
|
.. versionadded:: 1.6.0
|
||||||
:doc:`/tutorials/categorical`.
|
|
||||||
|
|
||||||
.. versionadded 1.6.0
|
See Also
|
||||||
|
--------
|
||||||
|
- :doc:`Tutorial </tutorials/categorical>`
|
||||||
|
- :ref:`sphx_glr_python_examples_categorical.py`
|
||||||
|
- :ref:`sphx_glr_python_examples_cat_pipeline.py`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
145
demo/guide-python/cat_pipeline.py
Normal file
145
demo/guide-python/cat_pipeline.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
"""
|
||||||
|
Feature engineering pipeline for categorical data
|
||||||
|
=================================================
|
||||||
|
|
||||||
|
The script showcases how to keep the categorical data encoding consistent across
|
||||||
|
training and inference. There are many ways to attain the same goal; this script can be
|
||||||
|
used as a starting point.
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
- :doc:`Tutorial </tutorials/categorical>`
|
||||||
|
- :ref:`sphx_glr_python_examples_categorical.py`
|
||||||
|
- :ref:`sphx_glr_python_examples_cat_in_the_dat.py`
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.compose import make_column_selector, make_column_transformer
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.pipeline import make_pipeline
|
||||||
|
from sklearn.preprocessing import OrdinalEncoder
|
||||||
|
|
||||||
|
import xgboost as xgb
|
||||||
|
|
||||||
|
|
||||||
|
def make_example_data() -> Tuple[pd.DataFrame, pd.Series, List[str]]:
    """Generate a reproducible synthetic dataset for the demo.

    Returns
    -------
    X :
        Feature frame with three integer-valued categorical columns
        (``brand_id``, ``retailer_id``, ``category_id``) followed by ``price``,
        ``stock_status`` and ``on_sale``.
    y :
        The ``label`` column, sampled from a standard normal distribution.
    categorical_features :
        Names of the categorical columns in ``X``.

    """
    n_samples = 2048
    rng = np.random.default_rng(1994)

    # We have three categorical features, while the rest are numerical.
    categorical_features = ["brand_id", "retailer_id", "category_id"]

    # Draw the category values from the seeded generator too (half-open range
    # [32, 96)), so that every column of the generated data is reproducible.
    df = pd.DataFrame(
        rng.integers(32, 96, size=(n_samples, 3)),
        columns=categorical_features,
    )

    df["price"] = rng.integers(100, 200, size=(n_samples,))
    df["stock_status"] = rng.choice([True, False], n_samples)
    df["on_sale"] = rng.choice([True, False], n_samples)
    df["label"] = rng.normal(loc=0.0, scale=1.0, size=n_samples)

    X = df.drop(["label"], axis=1)
    y = df["label"]

    return X, y, categorical_features
|
||||||
|
|
||||||
|
|
||||||
|
def native() -> None:
    """Using the native XGBoost interface.

    Fits an ordinal encoder on the training split only, re-uses it to encode every
    other slice of data, trains a booster on the encoded training data and finally
    checks that freshly encoded rows give the same predictions as the rows that were
    encoded before training.

    """
    X, y, cat_feats = make_example_data()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=1994, test_size=0.2
    )

    # Create an encoder based on training data.
    enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
    enc.set_output(transform="pandas")
    enc = enc.fit(X_train[cat_feats])

    def enc_transform(X: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``X`` with categorical columns encoded by ``enc``."""
        # don't make change inplace so that we can have demonstrations for encoding
        X = X.copy()
        cat_cols = enc.transform(X[cat_feats])
        for i, name in enumerate(cat_feats):
            # The encoder emits NaN for categories it has not seen during fit, and
            # casting NaN to an integer fails. Map NaN to -1 first, which
            # `from_codes` interprets as a missing value.
            codes = cat_cols[name].fillna(-1).astype(np.int32)
            # create pd.Series based on the encoder
            cat_cols[name] = pd.Categorical.from_codes(
                codes=codes, categories=enc.categories_[i]
            )
        X[cat_feats] = cat_cols
        return X

    # Encode the data based on fitted encoder.
    X_train_enc = enc_transform(X_train)
    X_test_enc = enc_transform(X_test)
    # Train XGBoost model using the native interface.
    Xy_train = xgb.QuantileDMatrix(X_train_enc, y_train, enable_categorical=True)
    Xy_test = xgb.QuantileDMatrix(
        X_test_enc, y_test, enable_categorical=True, ref=Xy_train
    )
    booster = xgb.train({}, Xy_train)
    booster.predict(Xy_test)

    # Following shows that data are encoded consistently.

    # We first obtain result from newly encoded data
    predt0 = booster.inplace_predict(enc_transform(X_train.head(16)))
    # then we obtain result from already encoded data from training.
    predt1 = booster.inplace_predict(X_train_enc.head(16))

    np.testing.assert_allclose(predt0, predt1)
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline() -> None:
    """Using the sklearn pipeline.

    The ordinal encoder and the XGBoost regressor are chained in one scikit-learn
    pipeline, so the encoder fitted during training is re-used for prediction.

    """
    features, labels, cat_feats = make_example_data()

    train_X, _test_X, train_y, _test_y = train_test_split(
        features, labels, random_state=3, test_size=0.2
    )

    ordinal = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=np.nan)
    # all categorical feature names end with "_id"
    id_columns = make_column_selector(pattern=".*_id")
    transformer = make_column_transformer(
        (ordinal, id_columns),
        remainder="passthrough",
        verbose_feature_names_out=False,
    )
    # No need to set pandas output, we use `feature_types` to indicate the type of
    # features.

    # transformer.set_output(transform="pandas")

    # "c" marks a categorical feature, "q" a numerical one.
    feature_types = []
    for column in train_X.columns:
        feature_types.append("c" if column in cat_feats else "q")

    regressor = xgb.XGBRegressor(
        feature_types=feature_types, enable_categorical=True, n_estimators=10
    )
    model = make_pipeline(transformer, regressor)
    model.fit(train_X, train_y)

    # check XGBoost is using the feature type correctly.
    booster_types = regressor.get_booster().feature_types
    assert booster_types is not None
    for actual, expected in zip(booster_types, feature_types):
        assert actual == expected

    # Following shows that data are encoded consistently.

    # Predict on a slice of data that doesn't contain all the categories ...
    from_slice = model.predict(train_X.iloc[:16, :])
    # ... and on the dataframe that contains all the categories.
    from_full = model.predict(train_X)[:16]

    # The resulting encoding is the same
    np.testing.assert_allclose(from_slice, from_full)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run both demonstrations: the sklearn pipeline first, then the native interface.
    for demo in (pipeline, native):
        demo()
|
||||||
@ -8,11 +8,14 @@ In before, users need to run an encoder themselves before passing the data into
|
|||||||
which creates a sparse matrix and potentially increase memory usage. This demo
|
which creates a sparse matrix and potentially increase memory usage. This demo
|
||||||
showcases the experimental categorical data support, more advanced features are planned.
|
showcases the experimental categorical data support, more advanced features are planned.
|
||||||
|
|
||||||
Also, see :doc:`the tutorial </tutorials/categorical>` for using XGBoost with
|
|
||||||
categorical data.
|
|
||||||
|
|
||||||
.. versionadded:: 1.5.0
|
.. versionadded:: 1.5.0
|
||||||
|
|
||||||
|
See Also
|
||||||
|
--------
|
||||||
|
- :doc:`Tutorial </tutorials/categorical>`
|
||||||
|
- :ref:`sphx_glr_python_examples_cat_in_the_dat.py`
|
||||||
|
- :ref:`sphx_glr_python_examples_cat_pipeline.py`
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
@ -52,11 +55,13 @@ def make_categorical(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
# Use builtin categorical data support
|
# Use builtin categorical data support
|
||||||
# For scikit-learn interface, the input data must be pandas DataFrame or cudf
|
|
||||||
# DataFrame with categorical features
|
# For scikit-learn interface, the input data should be pandas DataFrame or cudf
|
||||||
|
# DataFrame with categorical features. If a numpy/cupy array is used instead, the
|
||||||
|
# `feature_types` for `XGBRegressor` should be set accordingly.
|
||||||
X, y = make_categorical(100, 10, 4, False)
|
X, y = make_categorical(100, 10, 4, False)
|
||||||
# Specify `enable_categorical` to True, also we use onehot encoding based split
|
# Specify `enable_categorical` to True, also we use onehot-encoding-based split here
|
||||||
# here for demonstration. For details see the document of `max_cat_to_onehot`.
|
# for demonstration. For details see the document of `max_cat_to_onehot`.
|
||||||
reg = xgb.XGBRegressor(
|
reg = xgb.XGBRegressor(
|
||||||
tree_method="hist", enable_categorical=True, max_cat_to_onehot=5, device="cuda"
|
tree_method="hist", enable_categorical=True, max_cat_to_onehot=5, device="cuda"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -137,7 +137,7 @@ To build and run C++ unit tests enable tests while running CMake:
|
|||||||
./testxgboost
|
./testxgboost
|
||||||
|
|
||||||
Flags like ``USE_CUDA``, ``USE_DMLC_GTEST`` are optional. For more info about how to build
|
Flags like ``USE_CUDA``, ``USE_DMLC_GTEST`` are optional. For more info about how to build
|
||||||
XGBoost from source, see :doc:`</build>`. One can also run all unit test using ctest tool
|
XGBoost from source, see :doc:`/build`. One can also run all unit tests using ctest tool
|
||||||
which provides higher flexibility. For example:
|
which provides higher flexibility. For example:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|||||||
@ -94,11 +94,11 @@ Using native interface
|
|||||||
**********************
|
**********************
|
||||||
|
|
||||||
The ``scikit-learn`` interface is user friendly, but lacks some features that are only
|
The ``scikit-learn`` interface is user friendly, but lacks some features that are only
|
||||||
available in native interface. For instance users cannot compute SHAP value directly or
|
available in the native interface. For instance, users cannot compute SHAP values directly.
|
||||||
use quantized :class:`DMatrix <xgboost.DMatrix>`. Also native interface supports data
|
Also, the native interface supports more data types. To use the native interface with
|
||||||
types other than dataframe, like ``numpy/cupy array``. To use the native interface with
|
categorical data, we need to pass a similar parameter to :class:`~xgboost.DMatrix` or
|
||||||
categorical data, we need to pass the similar parameter to :class:`DMatrix
|
:py:class:`~xgboost.QuantileDMatrix` and the :func:`train <xgboost.train>` function. For
|
||||||
<xgboost.DMatrix>` and the :func:`train <xgboost.train>` function. For dataframe input:
|
dataframe input:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
@ -117,7 +117,6 @@ SHAP value computation:
|
|||||||
# categorical features are listed as "c"
|
# categorical features are listed as "c"
|
||||||
print(booster.feature_types)
|
print(booster.feature_types)
|
||||||
|
|
||||||
|
|
||||||
For other types of input, like ``numpy array``, we can tell XGBoost about the feature
|
For other types of input, like ``numpy array``, we can tell XGBoost about the feature
|
||||||
types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatrix>`:
|
types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatrix>`:
|
||||||
|
|
||||||
@ -131,7 +130,31 @@ types by using the ``feature_types`` parameter in :class:`DMatrix <xgboost.DMatr
|
|||||||
|
|
||||||
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
|
For numerical data, the feature type can be ``"q"`` or ``"float"``, while for categorical
|
||||||
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
|
feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so
|
||||||
:class:`dask.Array <dask.Array>` can also be used for categorical data.
|
:class:`dask.Array <dask.Array>` can also be used for categorical data. Lastly, the
|
||||||
|
sklearn interface :py:class:`~xgboost.XGBRegressor` has the same parameter.
|
||||||
|
|
||||||
|
****************
|
||||||
|
Data Consistency
|
||||||
|
****************
|
||||||
|
|
||||||
|
XGBoost accepts parameters to indicate which features are considered categorical, either through the ``dtypes`` of a dataframe or through the ``feature_types`` parameter. However, XGBoost by itself doesn't store information on how categories are encoded in the first place. For instance, given an encoding schema that maps music genres to integer codes:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
{"acoustic": 0, "indie": 1, "blues": 2, "country": 3}
|
||||||
|
|
||||||
|
XGBoost doesn't know this mapping from the input and hence cannot store it in the model. The mapping usually happens in the users' data engineering pipeline with column transformers like :py:class:`sklearn.preprocessing.OrdinalEncoder`. To ensure correct results from XGBoost, users need to keep the pipeline for transforming data consistent across training and testing data. One should watch out for errors like:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
X_train["genre"] = X_train["genre"].astype("category")
|
||||||
|
reg = xgb.XGBRegressor(enable_categorical=True).fit(X_train, y_train)
|
||||||
|
|
||||||
|
# invalid encoding
|
||||||
|
X_test["genre"] = X_test["genre"].astype("category")
|
||||||
|
reg.predict(X_test)
|
||||||
|
|
||||||
|
In the above snippet, training data and test data are encoded separately, resulting in two different encoding schemas and invalid prediction results. See :ref:`sphx_glr_python_examples_cat_pipeline.py` for a worked example using an ordinal encoder.
|
||||||
|
|
||||||
*************
|
*************
|
||||||
Miscellaneous
|
Miscellaneous
|
||||||
|
|||||||
@ -821,7 +821,7 @@ class DirectoryExcursion:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: os.PathLike, cleanup: bool = False):
|
def __init__(self, path: Union[os.PathLike, str], cleanup: bool = False):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
|
self.curdir = os.path.normpath(os.path.abspath(os.path.curdir))
|
||||||
self.cleanup = cleanup
|
self.cleanup = cleanup
|
||||||
|
|||||||
@ -21,6 +21,7 @@ class LintersPaths:
|
|||||||
"tests/python/test_data_iterator.py",
|
"tests/python/test_data_iterator.py",
|
||||||
"tests/python/test_dmatrix.py",
|
"tests/python/test_dmatrix.py",
|
||||||
"tests/python/test_dt.py",
|
"tests/python/test_dt.py",
|
||||||
|
"tests/python/test_demos.py",
|
||||||
"tests/python/test_predict.py",
|
"tests/python/test_predict.py",
|
||||||
"tests/python/test_quantile_dmatrix.py",
|
"tests/python/test_quantile_dmatrix.py",
|
||||||
"tests/python/test_tree_regularization.py",
|
"tests/python/test_tree_regularization.py",
|
||||||
@ -41,6 +42,7 @@ class LintersPaths:
|
|||||||
"demo/guide-python/cat_in_the_dat.py",
|
"demo/guide-python/cat_in_the_dat.py",
|
||||||
"demo/guide-python/callbacks.py",
|
"demo/guide-python/callbacks.py",
|
||||||
"demo/guide-python/categorical.py",
|
"demo/guide-python/categorical.py",
|
||||||
|
"demo/guide-python/cat_pipeline.py",
|
||||||
"demo/guide-python/feature_weights.py",
|
"demo/guide-python/feature_weights.py",
|
||||||
"demo/guide-python/sklearn_parallel.py",
|
"demo/guide-python/sklearn_parallel.py",
|
||||||
"demo/guide-python/spark_estimator_examples.py",
|
"demo/guide-python/spark_estimator_examples.py",
|
||||||
@ -79,6 +81,7 @@ class LintersPaths:
|
|||||||
"python-package/",
|
"python-package/",
|
||||||
# tests
|
# tests
|
||||||
"tests/python/test_dt.py",
|
"tests/python/test_dt.py",
|
||||||
|
"tests/python/test_demos.py",
|
||||||
"tests/python/test_data_iterator.py",
|
"tests/python/test_data_iterator.py",
|
||||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||||
"tests/python-gpu/load_pickle.py",
|
"tests/python-gpu/load_pickle.py",
|
||||||
@ -89,6 +92,8 @@ class LintersPaths:
|
|||||||
"demo/json-model/json_parser.py",
|
"demo/json-model/json_parser.py",
|
||||||
"demo/guide-python/external_memory.py",
|
"demo/guide-python/external_memory.py",
|
||||||
"demo/guide-python/cat_in_the_dat.py",
|
"demo/guide-python/cat_in_the_dat.py",
|
||||||
|
"demo/guide-python/categorical.py",
|
||||||
|
"demo/guide-python/cat_pipeline.py",
|
||||||
"demo/guide-python/feature_weights.py",
|
"demo/guide-python/feature_weights.py",
|
||||||
"demo/guide-python/individual_trees.py",
|
"demo/guide-python/individual_trees.py",
|
||||||
"demo/guide-python/quantile_regression.py",
|
"demo/guide-python/quantile_regression.py",
|
||||||
|
|||||||
@ -11,129 +11,143 @@ from xgboost import testing as tm
|
|||||||
pytestmark = tm.timeout(30)
|
pytestmark = tm.timeout(30)
|
||||||
|
|
||||||
DEMO_DIR = tm.demo_dir(__file__)
|
DEMO_DIR = tm.demo_dir(__file__)
|
||||||
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, 'guide-python')
|
PYTHON_DEMO_DIR = os.path.join(DEMO_DIR, "guide-python")
|
||||||
CLI_DEMO_DIR = os.path.join(DEMO_DIR, 'CLI')
|
CLI_DEMO_DIR = os.path.join(DEMO_DIR, "CLI")
|
||||||
|
|
||||||
|
|
||||||
def test_basic_walkthrough():
|
def test_basic_walkthrough() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'basic_walkthrough.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "basic_walkthrough.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
subprocess.check_call(cmd, cwd=tmpdir)
|
subprocess.check_call(cmd, cwd=tmpdir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_custom_multiclass_objective():
|
def test_categorical() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'custom_softmax.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "categorical.py")
|
||||||
cmd = ['python', script, '--plot=0']
|
cmd = ["python", script]
|
||||||
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
|
def test_cat_pipeline() -> None:
|
||||||
|
script = os.path.join(PYTHON_DEMO_DIR, "cat_pipeline.py")
|
||||||
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||||
def test_custom_rmsle_objective():
|
def test_custom_multiclass_objective() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'custom_rmsle.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "custom_softmax.py")
|
||||||
cmd = ['python', script, '--plot=0']
|
cmd = ["python", script, "--plot=0"]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||||
def test_feature_weights_demo():
|
def test_custom_rmsle_objective() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'feature_weights.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "custom_rmsle.py")
|
||||||
cmd = ['python', script, '--plot=0']
|
cmd = ["python", script, "--plot=0"]
|
||||||
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||||
|
def test_feature_weights_demo() -> None:
|
||||||
|
script = os.path.join(PYTHON_DEMO_DIR, "feature_weights.py")
|
||||||
|
cmd = ["python", script, "--plot=0"]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_sklearn_demo():
|
def test_sklearn_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_examples.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "sklearn_examples.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
assert os.path.exists('best_calif.pkl')
|
assert os.path.exists("best_calif.pkl")
|
||||||
os.remove('best_calif.pkl')
|
os.remove("best_calif.pkl")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_sklearn_parallel_demo():
|
def test_sklearn_parallel_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_parallel.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "sklearn_parallel.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_sklearn_evals_result_demo():
|
def test_sklearn_evals_result_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'sklearn_evals_result.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "sklearn_evals_result.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_boost_from_prediction_demo():
|
def test_boost_from_prediction_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'boost_from_prediction.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "boost_from_prediction.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_predict_first_ntree_demo():
|
def test_predict_first_ntree_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'predict_first_ntree.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "predict_first_ntree.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_individual_trees():
|
def test_individual_trees() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'individual_trees.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "individual_trees.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_predict_leaf_indices_demo():
|
def test_predict_leaf_indices_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'predict_leaf_indices.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "predict_leaf_indices.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_generalized_linear_model_demo():
|
def test_generalized_linear_model_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'generalized_linear_model.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "generalized_linear_model.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_cross_validation_demo():
|
def test_cross_validation_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'cross_validation.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "cross_validation.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_external_memory_demo():
|
def test_external_memory_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'external_memory.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "external_memory.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_evals_result_demo():
|
def test_evals_result_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'evals_result.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "evals_result.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_aft_demo():
|
def test_aft_demo() -> None:
|
||||||
script = os.path.join(DEMO_DIR, 'aft_survival', 'aft_survival_demo.py')
|
script = os.path.join(DEMO_DIR, "aft_survival", "aft_survival_demo.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
assert os.path.exists('aft_model.json')
|
assert os.path.exists("aft_model.json")
|
||||||
os.remove('aft_model.json')
|
os.remove("aft_model.json")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||||
def test_callbacks_demo():
|
def test_callbacks_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'callbacks.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "callbacks.py")
|
||||||
cmd = ['python', script, '--plot=0']
|
cmd = ["python", script, "--plot=0"]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
def test_continuation_demo():
|
def test_continuation_demo() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, 'continuation.py')
|
script = os.path.join(PYTHON_DEMO_DIR, "continuation.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@ -141,14 +155,14 @@ def test_continuation_demo():
|
|||||||
@pytest.mark.skipif(**tm.no_matplotlib())
|
@pytest.mark.skipif(**tm.no_matplotlib())
|
||||||
def test_multioutput_reg() -> None:
|
def test_multioutput_reg() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, "multioutput_regression.py")
|
script = os.path.join(PYTHON_DEMO_DIR, "multioutput_regression.py")
|
||||||
cmd = ['python', script, "--plot=0"]
|
cmd = ["python", script, "--plot=0"]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_sklearn())
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
def test_quantile_reg() -> None:
|
def test_quantile_reg() -> None:
|
||||||
script = os.path.join(PYTHON_DEMO_DIR, "quantile_regression.py")
|
script = os.path.join(PYTHON_DEMO_DIR, "quantile_regression.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd)
|
subprocess.check_call(cmd)
|
||||||
|
|
||||||
|
|
||||||
@ -197,28 +211,30 @@ def test_json_model() -> None:
|
|||||||
# - aft tuning is not tested due to extra dependency.
|
# - aft tuning is not tested due to extra dependency.
|
||||||
|
|
||||||
|
|
||||||
def test_cli_regression_demo():
|
def test_cli_regression_demo() -> None:
|
||||||
reg_dir = os.path.join(CLI_DEMO_DIR, 'regression')
|
reg_dir = os.path.join(CLI_DEMO_DIR, "regression")
|
||||||
script = os.path.join(reg_dir, 'mapfeat.py')
|
script = os.path.join(reg_dir, "mapfeat.py")
|
||||||
cmd = ['python', script]
|
cmd = ["python", script]
|
||||||
subprocess.check_call(cmd, cwd=reg_dir)
|
subprocess.check_call(cmd, cwd=reg_dir)
|
||||||
|
|
||||||
script = os.path.join(reg_dir, 'mknfold.py')
|
script = os.path.join(reg_dir, "mknfold.py")
|
||||||
cmd = ['python', script, 'machine.txt', '1']
|
cmd = ["python", script, "machine.txt", "1"]
|
||||||
subprocess.check_call(cmd, cwd=reg_dir)
|
subprocess.check_call(cmd, cwd=reg_dir)
|
||||||
|
|
||||||
exe = os.path.join(DEMO_DIR, os.path.pardir, 'xgboost')
|
exe = os.path.join(DEMO_DIR, os.path.pardir, "xgboost")
|
||||||
conf = os.path.join(reg_dir, 'machine.conf')
|
conf = os.path.join(reg_dir, "machine.conf")
|
||||||
subprocess.check_call([exe, conf], cwd=reg_dir)
|
subprocess.check_call([exe, conf], cwd=reg_dir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(condition=sys.platform.startswith("win"),
|
@pytest.mark.skipif(
|
||||||
reason='Test requires sh execution.')
|
condition=sys.platform.startswith("win"), reason="Test requires sh execution."
|
||||||
def test_cli_binary_classification():
|
)
|
||||||
cls_dir = os.path.join(CLI_DEMO_DIR, 'binary_classification')
|
def test_cli_binary_classification() -> None:
|
||||||
|
cls_dir = os.path.join(CLI_DEMO_DIR, "binary_classification")
|
||||||
with tm.DirectoryExcursion(cls_dir, cleanup=True):
|
with tm.DirectoryExcursion(cls_dir, cleanup=True):
|
||||||
subprocess.check_call(['./runexp.sh'])
|
subprocess.check_call(["./runexp.sh"])
|
||||||
os.remove('0002.model')
|
os.remove("0002.model")
|
||||||
|
|
||||||
|
|
||||||
# year prediction is not tested due to data size being too large.
|
# year prediction is not tested due to data size being too large.
|
||||||
# rank is not tested as it requires unrar command.
|
# rank is not tested as it requires unrar command.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user