"""
|
|
Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.
|
|
|
|
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
|
|
starts out as a flat line and evolves into a step function in order to account for all ranged
|
|
labels.
|
|
"""
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt

plt.rcParams.update({'font.size': 13})

# Function to visualize censored labels
def plot_censored_labels(X, y_lower, y_upper):
    def replace_inf(x, target_value):
        # Work on a copy so the caller's label arrays are not mutated in place
        x = x.copy()
        x[np.isinf(x)] = target_value
        return x
    plt.plot(X, y_lower, 'o', label='y_lower', color='blue')
    plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia')
    plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000),
               label='Range for y', color='gray')

# Toy data
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf
y_lower = np.array([ 10,  15, -INF,  30, 100])
y_upper = np.array([INF, INF,   20,  50, INF])
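# How to read the bounds above (a clarifying note added to the demo): each data
# point carries a ranged label [y_lower, y_upper].
#   X=1, 2, 5: right-censored    -- true label is at least y_lower, e.g. [10, +inf)
#   X=3:       left-censored     -- true label is at most 20, i.e. (-inf, 20]
#   X=4:       interval-censored -- true label lies in [30, 50]
# An uncensored point would have y_lower == y_upper.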

# Visualize toy data
plt.figure(figsize=(5, 4))
plot_censored_labels(X, y_lower, y_upper)
plt.ylim((6, 200))
plt.legend(loc='lower right')
plt.title('Toy data')
plt.xlabel('Input feature')
plt.ylabel('Label')
plt.yscale('log')
plt.tight_layout()
plt.show(block=True)

# Grid of input points used to visualize the fitted model as a curve
grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1))

# Train AFT model using XGBoost
dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower)
dmat.set_float_info('label_upper_bound', y_upper)
params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
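# The AFT objective also takes distribution-specific parameters, which this demo
# leaves at their defaults. An explicit configuration might look like the sketch
# below (the values shown are the documented defaults):
#   params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0,
#             'aft_loss_distribution': 'normal',      # or 'logistic', 'extreme'
#             'aft_loss_distribution_scale': 1.0}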

accuracy_history = []
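# Note: the callback below uses the legacy callback interface, which passes a
# single `env` object carrying fields such as `model` and `iteration`. XGBoost
# 1.3+ replaced this with xgboost.callback.TrainingCallback; adapt accordingly
# if you run this demo on a newer version.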
def plot_intermediate_model_callback(env):
    """Custom callback to plot intermediate models"""
    # Compute y_pred = prediction using the intermediate model, at current boosting iteration
    y_pred = env.model.predict(dmat)
    # "Accuracy" = the percentage of data points whose ranged label (y_lower, y_upper)
    # includes the corresponding predicted label (y_pred)
    acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)) / len(X) * 100
    accuracy_history.append(acc)

    # Plot ranged labels as well as predictions by the model
    plt.subplot(5, 3, env.iteration + 1)
    plot_censored_labels(X, y_lower, y_upper)
    y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts))
    plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4)
    plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8)
    plt.xlim((0.8, 5.2))
    plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
    plt.yscale('log')

res = {}
plt.figure(figsize=(12, 13))
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res,
                callbacks=[plot_intermediate_model_callback])
plt.tight_layout()
plt.legend(loc='lower center', ncol=4,
           bbox_to_anchor=(0.5, 0),
           bbox_transform=plt.gcf().transFigure)
plt.tight_layout()
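# The trained booster can be used like any other XGBoost model. As a minimal
# usage sketch (not part of the original demo), predictions for unseen inputs
# are obtained the usual way:
#   X_new = np.array([1.5, 3.5]).reshape((-1, 1))
#   y_new = bst.predict(xgb.DMatrix(X_new))   # predicted survival times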

# Plot negative log likelihood (aft-nloglik, the default metric for the
# survival:aft objective) over boosting iterations
plt.figure(figsize=(8, 3))
plt.subplot(1, 2, 1)
plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')

# Plot "accuracy" over boosting iterations
# "Accuracy" = the percentage of data points whose ranged label (y_lower, y_upper)
# includes the corresponding predicted label (y_pred)
plt.subplot(1, 2, 2)
plt.plot(accuracy_history, 'r-o', label='Accuracy (%)')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')
plt.tight_layout()

plt.show()