Fix callback in AFT viz demo. (#9333)

* Fix callback in AFT viz demo.

- Update the callback function.
- Add lint check.
This commit is contained in:
Jiaming Yuan 2023-06-26 22:35:02 +08:00 committed by GitHub
parent 6efe7c129f
commit cfa9c42eb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 49 deletions

View File

@ -11,33 +11,43 @@ import numpy as np
import xgboost as xgb import xgboost as xgb
plt.rcParams.update({'font.size': 13}) plt.rcParams.update({"font.size": 13})
# Function to visualize censored labels # Function to visualize censored labels
def plot_censored_labels(X, y_lower, y_upper): def plot_censored_labels(
def replace_inf(x, target_value): X: np.ndarray, y_lower: np.ndarray, y_upper: np.ndarray
) -> None:
def replace_inf(x: np.ndarray, target_value: float) -> np.ndarray:
x[np.isinf(x)] = target_value x[np.isinf(x)] = target_value
return x return x
plt.plot(X, y_lower, 'o', label='y_lower', color='blue')
plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia') plt.plot(X, y_lower, "o", label="y_lower", color="blue")
plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000), plt.plot(X, y_upper, "o", label="y_upper", color="fuchsia")
label='Range for y', color='gray') plt.vlines(
X,
ymin=replace_inf(y_lower, 0.01),
ymax=replace_inf(y_upper, 1000.0),
label="Range for y",
color="gray",
)
# Toy data # Toy data
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1)) X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf INF = np.inf
y_lower = np.array([ 10, 15, -INF, 30, 100]) y_lower = np.array([10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF]) y_upper = np.array([INF, INF, 20, 50, INF])
# Visualize toy data # Visualize toy data
plt.figure(figsize=(5, 4)) plt.figure(figsize=(5, 4))
plot_censored_labels(X, y_lower, y_upper) plot_censored_labels(X, y_lower, y_upper)
plt.ylim((6, 200)) plt.ylim((6, 200))
plt.legend(loc='lower right') plt.legend(loc="lower right")
plt.title('Toy data') plt.title("Toy data")
plt.xlabel('Input feature') plt.xlabel("Input feature")
plt.ylabel('Label') plt.ylabel("Label")
plt.yscale('log') plt.yscale("log")
plt.tight_layout() plt.tight_layout()
plt.show(block=True) plt.show(block=True)
@ -46,54 +56,83 @@ grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1))
# Train AFT model using XGBoost # Train AFT model using XGBoost
dmat = xgb.DMatrix(X) dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower) dmat.set_float_info("label_lower_bound", y_lower)
dmat.set_float_info('label_upper_bound', y_upper) dmat.set_float_info("label_upper_bound", y_upper)
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0} params = {"max_depth": 3, "objective": "survival:aft", "min_child_weight": 0}
accuracy_history = [] accuracy_history = []
def plot_intermediate_model_callback(env):
"""Custom callback to plot intermediate models"""
# Compute y_pred = prediction using the intermediate model, at current boosting iteration class PlotIntermediateModel(xgb.callback.TrainingCallback):
y_pred = env.model.predict(dmat) """Custom callback to plot intermediate models."""
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred) def __init__(self) -> None:
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100) super().__init__()
def after_iteration(
self,
model: xgb.Booster,
epoch: int,
evals_log: xgb.callback.TrainingCallback.EvalsLog,
) -> bool:
"""Run after training is finished."""
# Compute y_pred = prediction using the intermediate model, at current boosting
# iteration
y_pred = model.predict(dmat)
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper)
# includes the corresponding predicted label (y_pred)
acc = np.sum(
np.logical_and(y_pred >= y_lower, y_pred <= y_upper) / len(X) * 100
)
accuracy_history.append(acc) accuracy_history.append(acc)
# Plot ranged labels as well as predictions by the model # Plot ranged labels as well as predictions by the model
plt.subplot(5, 3, env.iteration + 1) plt.subplot(5, 3, epoch + 1)
plot_censored_labels(X, y_lower, y_upper) plot_censored_labels(X, y_lower, y_upper)
y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts)) y_pred_grid_pts = model.predict(xgb.DMatrix(grid_pts))
plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4) plt.plot(
plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8) grid_pts, y_pred_grid_pts, "r-", label="XGBoost AFT model", linewidth=4
)
plt.title("Iteration {}".format(epoch), x=0.5, y=0.8)
plt.xlim((0.8, 5.2)) plt.xlim((0.8, 5.2))
plt.ylim((1 if np.min(y_pred) < 6 else 6, 200)) plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
plt.yscale('log') plt.yscale("log")
return False
res = {}
plt.figure(figsize=(12,13)) res: xgb.callback.TrainingCallback.EvalsLog = {}
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res, plt.figure(figsize=(12, 13))
callbacks=[plot_intermediate_model_callback]) bst = xgb.train(
params,
dmat,
15,
[(dmat, "train")],
evals_result=res,
callbacks=[PlotIntermediateModel()],
)
plt.tight_layout() plt.tight_layout()
plt.legend(loc='lower center', ncol=4, plt.legend(
loc="lower center",
ncol=4,
bbox_to_anchor=(0.5, 0), bbox_to_anchor=(0.5, 0),
bbox_transform=plt.gcf().transFigure) bbox_transform=plt.gcf().transFigure,
)
plt.tight_layout() plt.tight_layout()
# Plot negative log likelihood over boosting iterations # Plot negative log likelihood over boosting iterations
plt.figure(figsize=(8,3)) plt.figure(figsize=(8, 3))
plt.subplot(1, 2, 1) plt.subplot(1, 2, 1)
plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik') plt.plot(res["train"]["aft-nloglik"], "b-o", label="aft-nloglik")
plt.xlabel('# Boosting Iterations') plt.xlabel("# Boosting Iterations")
plt.legend(loc='best') plt.legend(loc="best")
# Plot "accuracy" over boosting iterations # Plot "accuracy" over boosting iterations
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred) # the corresponding predicted label (y_pred)
plt.subplot(1, 2, 2) plt.subplot(1, 2, 2)
plt.plot(accuracy_history, 'r-o', label='Accuracy (%)') plt.plot(accuracy_history, "r-o", label="Accuracy (%)")
plt.xlabel('# Boosting Iterations') plt.xlabel("# Boosting Iterations")
plt.legend(loc='best') plt.legend(loc="best")
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()

View File

@ -37,6 +37,7 @@ class LintersPaths:
"demo/guide-python/quantile_regression.py", "demo/guide-python/quantile_regression.py",
"demo/guide-python/multioutput_regression.py", "demo/guide-python/multioutput_regression.py",
"demo/guide-python/learning_to_rank.py", "demo/guide-python/learning_to_rank.py",
"demo/aft_survival/aft_survival_viz_demo.py",
# CI # CI
"tests/ci_build/lint_python.py", "tests/ci_build/lint_python.py",
"tests/ci_build/test_r_package.py", "tests/ci_build/test_r_package.py",
@ -78,6 +79,7 @@ class LintersPaths:
"demo/guide-python/quantile_regression.py", "demo/guide-python/quantile_regression.py",
"demo/guide-python/multioutput_regression.py", "demo/guide-python/multioutput_regression.py",
"demo/guide-python/learning_to_rank.py", "demo/guide-python/learning_to_rank.py",
"demo/aft_survival/aft_survival_viz_demo.py",
# CI # CI
"tests/ci_build/lint_python.py", "tests/ci_build/lint_python.py",
"tests/ci_build/test_r_package.py", "tests/ci_build/test_r_package.py",
@ -114,7 +116,13 @@ def run_black(rel_path: str, fix: bool) -> bool:
@cd(PY_PACKAGE) @cd(PY_PACKAGE)
def run_isort(rel_path: str, fix: bool) -> bool: def run_isort(rel_path: str, fix: bool) -> bool:
# Isort gets confused when trying to find the config file, so specified explicitly. # Isort gets confused when trying to find the config file, so specified explicitly.
cmd = ["isort", "--settings-path", PY_PACKAGE, os.path.join(ROOT, rel_path)] cmd = [
"isort",
"--settings-path",
PY_PACKAGE,
f"--src={PY_PACKAGE}",
os.path.join(ROOT, rel_path),
]
if not fix: if not fix:
cmd += ["--check"] cmd += ["--check"]