Demo for using custom objective with multi-target regression. (#7736)

This commit is contained in:
Jiaming Yuan 2022-03-20 17:44:25 +08:00 committed by GitHub
parent 996cc705af
commit cd55823112
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -8,43 +8,104 @@ https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regres
See :doc:`/tutorials/multioutput` for more information. See :doc:`/tutorials/multioutput` for more information.
""" """
import numpy as np
import xgboost as xgb
import argparse import argparse
from typing import Dict, Tuple, List
import numpy as np
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
import xgboost as xgb
def plot_predt(y, y_predt, name): def plot_predt(y: np.ndarray, y_predt: np.ndarray, name: str) -> None:
s = 25 s = 25
plt.scatter(y[:, 0], y[:, 1], c="navy", s=s, plt.scatter(y[:, 0], y[:, 1], c="navy", s=s, edgecolor="black", label="data")
edgecolor="black", label="data") plt.scatter(
plt.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s, y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s, edgecolor="black"
edgecolor="black") )
plt.xlim([-1, 2]) plt.xlim([-1, 2])
plt.ylim([-1, 2]) plt.ylim([-1, 2])
plt.show() plt.show()
def main(plot_result: bool): def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
"""Draw a circle with 2-dim coordinate as target variables.""" "Generate a sample dataset that y is a 2 dim circle."
rng = np.random.RandomState(1994) rng = np.random.RandomState(1994)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0) X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += (0.5 - rng.rand(20, 2)) y[::5, :] += 0.5 - rng.rand(20, 2)
y = y - y.min() y = y - y.min()
y = y / y.max() y = y / y.max()
return X, y
def rmse_model(plot_result: bool):
"""Draw a circle with 2-dim coordinate as target variables."""
X, y = gen_circle()
# Train a regressor on it # Train a regressor on it
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64) reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
reg.fit(X, y, eval_set=[(X, y)]) reg.fit(X, y, eval_set=[(X, y)])
y_predt = reg.predict(X) y_predt = reg.predict(X)
if plot_result: if plot_result:
plot_predt(y, y_predt, 'multi') plot_predt(y, y_predt, "multi")
def custom_rmse_model(plot_result: bool) -> None:
"""Train using Python implementation of Squared Error."""
# As the experimental support status, custom objective doesn't support matrix as
# gradient and hessian, which will be changed in future release.
def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
"""Compute the gradient squared error."""
y = dtrain.get_label().reshape(predt.shape)
return (predt - y).reshape(y.size)
def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
"""Compute the hessian for squared error."""
return np.ones(predt.shape).reshape(predt.size)
def squared_log(
predt: np.ndarray, dtrain: xgb.DMatrix
) -> Tuple[np.ndarray, np.ndarray]:
grad = gradient(predt, dtrain)
hess = hessian(predt, dtrain)
return grad, hess
def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
y = dtrain.get_label().reshape(predt.shape)
v = np.sqrt(np.sum(np.power(y - predt, 2)))
return "PyRMSE", v
X, y = gen_circle()
Xy = xgb.DMatrix(X, y)
results: Dict[str, Dict[str, List[float]]] = {}
# Make sure the `num_target` is passed to XGBoost when custom objective is used.
# When builtin objective is used, XGBoost can figure out the number of targets
# automatically.
booster = xgb.train(
{
"tree_method": "hist",
"num_target": y.shape[1],
},
dtrain=Xy,
num_boost_round=100,
obj=squared_log,
evals=[(Xy, "Train")],
evals_result=results,
custom_metric=rmse,
)
y_predt = booster.inplace_predict(X)
if plot_result:
plot_predt(y, y_predt, "multi")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--plot", choices=[0, 1], type=int, default=1) parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
main(args.plot == 1) # Train with builtin RMSE objective
rmse_model(args.plot == 1)
# Train with custom objective.
custom_rmse_model(args.plot == 1)