Initial support for quantile loss. (#8750)
- Add support for Python. - Add objective.
This commit is contained in:
124
demo/guide-python/quantile_regression.py
Normal file
124
demo/guide-python/quantile_regression.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Quantile Regression
|
||||
===================
|
||||
|
||||
The script is inspired by this awesome example in sklearn:
|
||||
https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
|
||||
|
||||
"""
|
||||
import argparse
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
def f(x: np.ndarray) -> np.ndarray:
    """Evaluate the ground-truth curve ``x * sin(x)`` element-wise."""
    sine = np.sin(x)
    return x * sine
|
||||
|
||||
|
||||
def quantile_loss(args: argparse.Namespace) -> None:
    """Train a quantile regression model and a squared-error model for comparison.

    Parameters
    ----------
    args :
        Parsed command line arguments. Only ``args.plot`` is read; when it is true,
        the predictions of both models are plotted with matplotlib.

    """
    rng = np.random.RandomState(1994)
    # Generate a synthetic dataset for the demo; the generation process is taken from
    # the sklearn example.
    X = np.atleast_2d(rng.uniform(0, 10.0, size=1000)).T
    expected_y = f(X).ravel()

    # Log-normal noise whose spread grows with x. Subtracting exp(sigma^2 / 2), the
    # mean of the log-normal distribution, centres the noise around zero.
    sigma = 0.5 + X.ravel() / 10.0
    noise = rng.lognormal(sigma=sigma) - np.exp(sigma**2.0 / 2.0)
    y = expected_y + noise

    # Train on the 0.05, 0.5 (median) and 0.95 quantiles. The model is similar to
    # multi-class and multi-target models.
    alpha = np.array([0.05, 0.5, 0.95])
    evals_result: Dict[str, Dict] = {}

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
    # We will be using the `hist` tree method, quantile DMatrix can be used to save
    # memory.
    # Do not use the `exact` tree method for quantile regression, otherwise the
    # performance might drop.
    #
    # Only the training split goes into the training matrix; otherwise the held-out
    # rows leak into training and the "Test" metric (and early stopping, which
    # monitors it) becomes meaningless.
    Xy = xgb.QuantileDMatrix(X_train, y_train)
    # use Xy as a reference
    Xy_test = xgb.QuantileDMatrix(X_test, y_test, ref=Xy)

    booster = xgb.train(
        {
            # Use the quantile objective function.
            "objective": "reg:quantileerror",
            "tree_method": "hist",
            "quantile_alpha": alpha,
            # Let's try not to overfit.
            "learning_rate": 0.01,
            "max_depth": 3,
            "min_child_weight": 16.0,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        # The evaluation result is a weighted average across multiple quantiles.
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=evals_result,
    )
    # Evenly spaced grid used for prediction by both models.
    xx = np.atleast_2d(np.linspace(0, 10, 1000)).T
    scores = booster.inplace_predict(xx)
    # dim 1 is the quantiles
    assert scores.shape[0] == xx.shape[0]
    assert scores.shape[1] == alpha.shape[0]

    y_lower = scores[:, 0]  # alpha=0.05
    y_med = scores[:, 1]  # alpha=0.5, median
    y_upper = scores[:, 2]  # alpha=0.95

    # Train a mse model for comparison. It gets its own history dict so that the
    # quantile model's evaluation history above is not overwritten.
    mse_evals_result: Dict[str, Dict] = {}
    booster = xgb.train(
        {
            "objective": "reg:squarederror",
            "tree_method": "hist",
            # Let's try not to overfit.
            "learning_rate": 0.01,
            "max_depth": 3,
            "min_child_weight": 16.0,
        },
        Xy,
        num_boost_round=32,
        early_stopping_rounds=2,
        evals=[(Xy, "Train"), (Xy_test, "Test")],
        evals_result=mse_evals_result,
    )
    y_pred = booster.inplace_predict(xx)

    if args.plot:
        # Imported lazily so that matplotlib is only required when plotting.
        from matplotlib import pyplot as plt

        plt.figure(figsize=(10, 10))
        plt.plot(xx, f(xx), "g:", linewidth=3, label=r"$f(x) = x\,\sin(x)$")
        plt.plot(X_test, y_test, "b.", markersize=10, label="Test observations")
        plt.plot(xx, y_med, "r-", label="Predicted median")
        plt.plot(xx, y_pred, "m-", label="Predicted mean")
        plt.plot(xx, y_upper, "k-")
        plt.plot(xx, y_lower, "k-")
        plt.fill_between(
            xx.ravel(), y_lower, y_upper, alpha=0.4, label="Predicted 90% interval"
        )
        plt.xlabel("$x$")
        plt.ylabel("$f(x)$")
        plt.ylim(-10, 25)
        plt.legend(loc="upper left")
        plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command line entry point: plotting is opt-in via the ``--plot`` flag.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--plot", action="store_true", help="Specify it to enable plotting the outputs."
    )
    quantile_loss(cli.parse_args())
|
||||
Reference in New Issue
Block a user