diff --git a/demo/aft_survival/aft_survival_viz_demo.py b/demo/aft_survival/aft_survival_viz_demo.py
index a17c55edf..b925ca547 100644
--- a/demo/aft_survival/aft_survival_viz_demo.py
+++ b/demo/aft_survival/aft_survival_viz_demo.py
@@ -11,33 +11,43 @@
 import numpy as np
 import xgboost as xgb
 
-plt.rcParams.update({'font.size': 13})
+plt.rcParams.update({"font.size": 13})
+
 
 # Function to visualize censored labels
-def plot_censored_labels(X, y_lower, y_upper):
-    def replace_inf(x, target_value):
+def plot_censored_labels(
+    X: np.ndarray, y_lower: np.ndarray, y_upper: np.ndarray
+) -> None:
+    def replace_inf(x: np.ndarray, target_value: float) -> np.ndarray:
         x[np.isinf(x)] = target_value
         return x
-    plt.plot(X, y_lower, 'o', label='y_lower', color='blue')
-    plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia')
-    plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000),
-               label='Range for y', color='gray')
+
+    plt.plot(X, y_lower, "o", label="y_lower", color="blue")
+    plt.plot(X, y_upper, "o", label="y_upper", color="fuchsia")
+    plt.vlines(
+        X,
+        ymin=replace_inf(y_lower, 0.01),
+        ymax=replace_inf(y_upper, 1000.0),
+        label="Range for y",
+        color="gray",
+    )
+
 
 # Toy data
 X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
 INF = np.inf
-y_lower = np.array([ 10, 15, -INF, 30, 100])
-y_upper = np.array([INF, INF, 20, 50, INF])
+y_lower = np.array([10, 15, -INF, 30, 100])
+y_upper = np.array([INF, INF, 20, 50, INF])
 
 # Visualize toy data
 plt.figure(figsize=(5, 4))
 plot_censored_labels(X, y_lower, y_upper)
 plt.ylim((6, 200))
-plt.legend(loc='lower right')
-plt.title('Toy data')
-plt.xlabel('Input feature')
-plt.ylabel('Label')
-plt.yscale('log')
+plt.legend(loc="lower right")
+plt.title("Toy data")
+plt.xlabel("Input feature")
+plt.ylabel("Label")
+plt.yscale("log")
 plt.tight_layout()
 plt.show(block=True)
 
@@ -46,54 +56,83 @@
 grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1))
 
 # Train AFT model using XGBoost
 dmat = xgb.DMatrix(X)
-dmat.set_float_info('label_lower_bound', y_lower)
-dmat.set_float_info('label_upper_bound', y_upper)
-params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
+dmat.set_float_info("label_lower_bound", y_lower)
+dmat.set_float_info("label_upper_bound", y_upper)
+params = {"max_depth": 3, "objective": "survival:aft", "min_child_weight": 0}
 
 accuracy_history = []
-def plot_intermediate_model_callback(env):
-    """Custom callback to plot intermediate models"""
-    # Compute y_pred = prediction using the intermediate model, at current boosting iteration
-    y_pred = env.model.predict(dmat)
-    # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
-    # the corresponding predicted label (y_pred)
-    acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100)
-    accuracy_history.append(acc)
-    # Plot ranged labels as well as predictions by the model
-    plt.subplot(5, 3, env.iteration + 1)
-    plot_censored_labels(X, y_lower, y_upper)
-    y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts))
-    plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4)
-    plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8)
-    plt.xlim((0.8, 5.2))
-    plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
-    plt.yscale('log')
-res = {}
-plt.figure(figsize=(12,13))
-bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res,
-                callbacks=[plot_intermediate_model_callback])
+
+
+class PlotIntermediateModel(xgb.callback.TrainingCallback):
+    """Custom callback to plot intermediate models."""
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def after_iteration(
+        self,
+        model: xgb.Booster,
+        epoch: int,
+        evals_log: xgb.callback.TrainingCallback.EvalsLog,
+    ) -> bool:
+        """Run after each boosting iteration."""
+        # Compute y_pred = prediction using the intermediate model at the current
+        # boosting iteration
+        y_pred = model.predict(dmat)
+        # "Accuracy" = the percentage of data points whose ranged label (y_lower,
+        # y_upper) includes the corresponding predicted label (y_pred)
+        acc = np.sum(
+            np.logical_and(y_pred >= y_lower, y_pred <= y_upper) / len(X) * 100
+        )
+        accuracy_history.append(acc)
+
+        # Plot ranged labels as well as predictions by the model
+        plt.subplot(5, 3, epoch + 1)
+        plot_censored_labels(X, y_lower, y_upper)
+        y_pred_grid_pts = model.predict(xgb.DMatrix(grid_pts))
+        plt.plot(
+            grid_pts, y_pred_grid_pts, "r-", label="XGBoost AFT model", linewidth=4
+        )
+        plt.title("Iteration {}".format(epoch), x=0.5, y=0.8)
+        plt.xlim((0.8, 5.2))
+        plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
+        plt.yscale("log")
+        return False
+
+
+res: xgb.callback.TrainingCallback.EvalsLog = {}
+plt.figure(figsize=(12, 13))
+bst = xgb.train(
+    params,
+    dmat,
+    15,
+    [(dmat, "train")],
+    evals_result=res,
+    callbacks=[PlotIntermediateModel()],
+)
 plt.tight_layout()
-plt.legend(loc='lower center', ncol=4,
-           bbox_to_anchor=(0.5, 0),
-           bbox_transform=plt.gcf().transFigure)
+plt.legend(
+    loc="lower center",
+    ncol=4,
+    bbox_to_anchor=(0.5, 0),
+    bbox_transform=plt.gcf().transFigure,
+)
 plt.tight_layout()
 
 # Plot negative log likelihood over boosting iterations
-plt.figure(figsize=(8,3))
+plt.figure(figsize=(8, 3))
 plt.subplot(1, 2, 1)
-plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik')
-plt.xlabel('# Boosting Iterations')
-plt.legend(loc='best')
+plt.plot(res["train"]["aft-nloglik"], "b-o", label="aft-nloglik")
+plt.xlabel("# Boosting Iterations")
+plt.legend(loc="best")
 
 # Plot "accuracy" over boosting iterations
 # "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
 # the corresponding predicted label (y_pred)
 plt.subplot(1, 2, 2)
-plt.plot(accuracy_history, 'r-o', label='Accuracy (%)')
-plt.xlabel('# Boosting Iterations')
-plt.legend(loc='best')
+plt.plot(accuracy_history, "r-o", label="Accuracy (%)")
+plt.xlabel("# Boosting Iterations")
+plt.legend(loc="best")
 plt.tight_layout()
 plt.show()
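Note on the change above: the demo is ported from the legacy function-style callback (which received an `env` namedtuple) to the `xgb.callback.TrainingCallback` class interface, where `after_iteration` is invoked once per boosting round and returning True requests early stopping. As a minimal, self-contained sketch of that contract (illustrative only: it assumes XGBoost >= 1.3, where the interface was introduced, and uses synthetic data and a hypothetical callback name):

    import numpy as np
    import xgboost as xgb

    class PrintNloglik(xgb.callback.TrainingCallback):
        """Toy callback: print the AFT negative log likelihood each round."""

        def after_iteration(self, model, epoch, evals_log) -> bool:
            # evals_log is nested as {dataset_name: {metric_name: [history...]}}
            print(f"[{epoch}] aft-nloglik = {evals_log['train']['aft-nloglik'][-1]:.4f}")
            return False  # returning True would stop training early

    rng = np.random.default_rng(0)
    X = rng.uniform(size=(100, 2))
    dtrain = xgb.DMatrix(X)
    # Right-censored toy labels: lower bounds observed, upper bounds unknown
    dtrain.set_float_info("label_lower_bound", rng.uniform(1.0, 10.0, size=100))
    dtrain.set_float_info("label_upper_bound", np.full(100, np.inf))
    xgb.train(
        {"objective": "survival:aft"},
        dtrain,
        num_boost_round=5,
        evals=[(dtrain, "train")],
        callbacks=[PrintNloglik()],
    )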
diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py
index 90c52aad4..85ece676e 100644
--- a/tests/ci_build/lint_python.py
+++ b/tests/ci_build/lint_python.py
@@ -37,6 +37,7 @@ class LintersPaths:
         "demo/guide-python/quantile_regression.py",
         "demo/guide-python/multioutput_regression.py",
         "demo/guide-python/learning_to_rank.py",
+        "demo/aft_survival/aft_survival_viz_demo.py",
         # CI
         "tests/ci_build/lint_python.py",
         "tests/ci_build/test_r_package.py",
@@ -78,6 +79,7 @@ class LintersPaths:
         "demo/guide-python/quantile_regression.py",
         "demo/guide-python/multioutput_regression.py",
         "demo/guide-python/learning_to_rank.py",
+        "demo/aft_survival/aft_survival_viz_demo.py",
         # CI
         "tests/ci_build/lint_python.py",
         "tests/ci_build/test_r_package.py",
@@ -114,7 +116,13 @@ def run_black(rel_path: str, fix: bool) -> bool:
 @cd(PY_PACKAGE)
 def run_isort(rel_path: str, fix: bool) -> bool:
     # Isort gets confused when trying to find the config file, so specified explicitly.
-    cmd = ["isort", "--settings-path", PY_PACKAGE, os.path.join(ROOT, rel_path)]
+    cmd = [
+        "isort",
+        "--settings-path",
+        PY_PACKAGE,
+        f"--src={PY_PACKAGE}",
+        os.path.join(ROOT, rel_path),
+    ]
     if not fix:
         cmd += ["--check"]
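Note on the `--src` flag: it points isort at PY_PACKAGE so that the `xgboost` package is classified as first-party even when the linted file (such as the demo script added above) lives outside that directory; without it, an installed xgboost can be sorted into the third-party import block. A rough illustration via isort's Python API, assuming isort >= 5 and that this runs from the repository root (the "python-package" literal is an assumption mirroring what PY_PACKAGE resolves to):

    from pathlib import Path

    import isort

    # With no src_paths configured, `xgboost` is placed by whatever isort finds
    # on the import path, typically the THIRDPARTY section.
    print(isort.place_module("xgboost"))

    # With src_paths pointing at the package sources, the same module is
    # placed in the FIRSTPARTY section instead.
    cfg = isort.Config(src_paths=(Path("python-package"),))
    print(isort.place_module("xgboost", config=cfg))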