Add Accelerated Failure Time loss for survival analysis task (#4763)

* [WIP] Add lower and upper bounds on the label for survival analysis

* Update test MetaInfo.SaveLoadBinary to account for extra two fields

* Don't clear qids_ for version 2 of MetaInfo

* Add SetInfo() and GetInfo() method for lower and upper bounds

* changes to aft

* Add parameter class for AFT; use enums to represent distribution and event type

* Add AFT metric

* Change from negative gradient to gradient

* changes to binomial loss

* changes to overflow

* changes to eps

* changes to code refactoring

* changes to code refactoring

* changes to code refactoring

* Re-factor survival analysis

* Remove aft namespace

* Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter

* Move function bodies out of AFTLoss, to reduce clutter

* Use smart pointer to store AFTDistribution and AFTLoss

* Rename AFTNoiseDistribution enum to AFTDistributionType for clarity

The enum class was not a distribution itself but a distribution type

* Add AFTDistribution::Create() method for convenience

* changes to extreme distribution

* changes to extreme distribution

* changes to extreme

* changes to extreme distribution

* changes to left censored

* deleted cout

* changes to x,mu and sd and code refactoring

* changes to print

* changes to hessian formula in censored and uncensored

* changes to variable names and pow

* changes to Logistic Pdf

* changes to parameter

* Expose lower and upper bound labels to R package

* Use example weights; normalize log likelihood metric

* changes to CHECK

* changes to logistic hessian to standard formula

* changes to logistic formula

* Comply with coding style guideline

* Revert back Rabit submodule

* Revert dmlc-core submodule

* Comply with coding style guideline (clang-tidy)

* Fix an error in AFTLoss::Gradient()

* Add missing files to amalgamation

* Address @RAMitchell's comment: minimize future change in MetaInfo interface

* Fix lint

* Fix compilation error on 32-bit target, when size_t == bst_uint

* Allocate sufficient memory to hold extra label info

* Use OpenMP to speed up

* Fix compilation on Windows

* Address reviewer's feedback

* Add unit tests for probability distributions

* Make Metric subclass of Configurable

* Address reviewer's feedback: Configure() AFT metric

* Add a dummy test for AFT metric configuration

* Complete AFT configuration test; remove debugging print

* Rename AFT parameters

* Clarify test comment

* Add a dummy test for AFT loss for uncensored case

* Fix a bug in AFT loss for uncensored labels

* Complete unit test for AFT loss metric

* Simplify unit tests for AFT metric

* Add unit test to verify aggregate output from AFT metric

* Use EXPECT_* instead of ASSERT_*, so that we run all unit tests

* Use aft_loss_param when serializing AFTObj

This is to be consistent with AFT metric

* Add unit tests for AFT Objective

* Fix OpenMP bug; clarify semantics for shared variables used in OpenMP loops

* Add comments

* Remove AFT prefix from probability distribution; put probability distribution in separate source file

* Add comments

* Define kPI and kEulerMascheroni in probability_distribution.h

* Add probability_distribution.cc to amalgamation

* Remove unnecessary diff

* Address reviewer's feedback: define variables where they're used

* Eliminate all INFs and NANs from AFT loss and gradient

* Add demo

* Add tutorial

* Fix lint

* Use 'survival:aft' to be consistent with 'survival:cox'

* Move sample data to demo/data

* Add visual demo with 1D toy data

* Add Python tests

Co-authored-by: Philip Cho <chohyu01@cs.washington.edu>
Avinash Barnwal 2020-03-25 16:52:51 -04:00, committed by GitHub
parent 1de36cdf1e, commit dcf439932a
21 changed files with 1789 additions and 15 deletions


@ -243,6 +243,18 @@ setinfo.xgb.DMatrix <- function(object, name, info, ...) {
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
return(TRUE)
}
if (name == "label_lower_bound") {
if (length(info) != nrow(object))
stop("The length of lower-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
return(TRUE)
}
if (name == "label_upper_bound") {
if (length(info) != nrow(object))
stop("The length of upper-bound labels must equal to the number of rows in the input data")
.Call(XGDMatrixSetInfo_R, object, name, as.numeric(info))
return(TRUE)
}
if (name == "weight") {
if (length(info) != nrow(object))
stop("The length of weights must equal to the number of rows in the input data")


@ -14,6 +14,7 @@
#include "../src/metric/elementwise_metric.cc"
#include "../src/metric/multiclass_metric.cc"
#include "../src/metric/rank_metric.cc"
#include "../src/metric/survival_metric.cc"
// objectives
#include "../src/objective/objective.cc"
@ -21,6 +22,7 @@
#include "../src/objective/multiclass_obj.cc"
#include "../src/objective/rank_obj.cc"
#include "../src/objective/hinge.cc"
#include "../src/objective/aft_obj.cc"
// gbms
#include "../src/gbm/gbm.cc"
@ -44,7 +46,7 @@
#include "../src/data/sparse_page_dmatrix.cc"
#endif
// tress
// trees
#include "../src/tree/param.cc"
#include "../src/tree/split_evaluator.cc"
#include "../src/tree/tree_model.cc"
@ -72,6 +74,8 @@
#include "../src/common/hist_util.cc"
#include "../src/common/json.cc"
#include "../src/common/io.cc"
#include "../src/common/survival_util.cc"
#include "../src/common/probability_distribution.cc"
#include "../src/common/version.cc"
// c_api


@ -0,0 +1,54 @@
"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
import xgboost as xgb
# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
df = pd.read_csv('../data/veterans_lung_cancer.csv')
print('Training data:')
print(df)
# Split features and labels
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)
# Split data into training and validation sets
rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
train_index, valid_index = next(rs.split(X))
dtrain = xgb.DMatrix(X.values[train_index, :])
dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
dvalid = xgb.DMatrix(X.values[valid_index, :])
dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index])
dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index])
# Train gradient boosted trees using AFT loss and metric
params = {'verbosity': 0,
          'objective': 'survival:aft',
          'eval_metric': 'aft-nloglik',
          'tree_method': 'hist',
          'learning_rate': 0.05,
          'aft_loss_distribution': 'normal',
          'aft_loss_distribution_scale': 1.20,
          'max_depth': 6,
          'lambda': 0.01,
          'alpha': 0.02}
bst = xgb.train(params, dtrain, num_boost_round=10000,
                evals=[(dtrain, 'train'), (dvalid, 'valid')],
                early_stopping_rounds=50)
# Run prediction on the validation set
df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index],
'Label (upper bound)': y_upper_bound[valid_index],
'Predicted label': bst.predict(dvalid)})
print(df)
# Show only data points with right-censored labels
print(df[np.isinf(df['Label (upper bound)'])])
# Save trained model
bst.save_model('aft_model.json')


@ -0,0 +1,78 @@
"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
to tune hyperparameters
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
import numpy as np
import xgboost as xgb
import optuna
# The Veterans' Administration Lung Cancer Trial
# The Statistical Analysis of Failure Time Data by Kalbfleisch J. and Prentice R (1980)
df = pd.read_csv('../data/veterans_lung_cancer.csv')
print('Training data:')
print(df)
# Split features and labels
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)
# Split data into training and validation sets
rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
train_index, valid_index = next(rs.split(X))
dtrain = xgb.DMatrix(X.values[train_index, :])
dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
dvalid = xgb.DMatrix(X.values[valid_index, :])
dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index])
dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index])
# Define hyperparameter search space
base_params = {'verbosity': 0,
               'objective': 'survival:aft',
               'eval_metric': 'aft-nloglik',
               'tree_method': 'hist'}  # Hyperparameters common to all trials

def objective(trial):
    params = {'learning_rate': trial.suggest_loguniform('learning_rate', 0.01, 1.0),
              'aft_loss_distribution': trial.suggest_categorical('aft_loss_distribution',
                                                                 ['normal', 'logistic', 'extreme']),
              'aft_loss_distribution_scale': trial.suggest_loguniform('aft_loss_distribution_scale', 0.1, 10.0),
              'max_depth': trial.suggest_int('max_depth', 3, 8),
              'lambda': trial.suggest_loguniform('lambda', 1e-8, 1.0),
              'alpha': trial.suggest_loguniform('alpha', 1e-8, 1.0)}  # Search space
    params.update(base_params)
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
    bst = xgb.train(params, dtrain, num_boost_round=10000,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
    if bst.best_iteration >= 25:
        return bst.best_score
    else:
        return np.inf  # Reject models with < 25 trees
# Run hyperparameter search
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=200)
print('Completed hyperparameter tuning with best aft-nloglik = {}.'.format(study.best_trial.value))
params = {}
params.update(base_params)
params.update(study.best_trial.params)
# Re-run training with the best hyperparameter combination
print('Re-running the best trial... params = {}'.format(params))
bst = xgb.train(params, dtrain, num_boost_round=10000,
evals=[(dtrain, 'train'), (dvalid, 'valid')],
early_stopping_rounds=50)
# Run prediction on the validation set
df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index],
'Label (upper bound)': y_upper_bound[valid_index],
'Predicted label': bst.predict(dvalid)})
print(df)
# Show only data points with right-censored labels
print(df[np.isinf(df['Label (upper bound)'])])
# Save trained model
bst.save_model('aft_best_model.json')


@ -0,0 +1,97 @@
"""
Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
starts out as a flat line and evolves into a step function in order to account for all ranged
labels.
"""
import numpy as np
import xgboost as xgb
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 13})
# Function to visualize censored labels
def plot_censored_labels(X, y_lower, y_upper):
    def replace_inf(x, target_value):
        x[np.isinf(x)] = target_value
        return x
    plt.plot(X, y_lower, 'o', label='y_lower', color='blue')
    plt.plot(X, y_upper, 'o', label='y_upper', color='fuchsia')
    plt.vlines(X, ymin=replace_inf(y_lower, 0.01), ymax=replace_inf(y_upper, 1000),
               label='Range for y', color='gray')
# Toy data
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf
y_lower = np.array([ 10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF])
# Visualize toy data
plt.figure(figsize=(5, 4))
plot_censored_labels(X, y_lower, y_upper)
plt.ylim((6, 200))
plt.legend(loc='lower right')
plt.title('Toy data')
plt.xlabel('Input feature')
plt.ylabel('Label')
plt.yscale('log')
plt.tight_layout()
plt.show(block=True)
# Will be used to visualize XGBoost model
grid_pts = np.linspace(0.8, 5.2, 1000).reshape((-1, 1))
# Train AFT model using XGBoost
dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower)
dmat.set_float_info('label_upper_bound', y_upper)
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
accuracy_history = []
def plot_intermediate_model_callback(env):
    """Custom callback to plot intermediate models"""
    # Compute y_pred = prediction using the intermediate model, at the current boosting iteration
    y_pred = env.model.predict(dmat)
    # "Accuracy" = percentage of data points whose ranged label (y_lower, y_upper) includes
    # the corresponding predicted label (y_pred)
    acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)) / len(X) * 100
    accuracy_history.append(acc)
    # Plot ranged labels as well as predictions by the model
    plt.subplot(5, 3, env.iteration + 1)
    plot_censored_labels(X, y_lower, y_upper)
    y_pred_grid_pts = env.model.predict(xgb.DMatrix(grid_pts))
    plt.plot(grid_pts, y_pred_grid_pts, 'r-', label='XGBoost AFT model', linewidth=4)
    plt.title('Iteration {}'.format(env.iteration), x=0.5, y=0.8)
    plt.xlim((0.8, 5.2))
    plt.ylim((1 if np.min(y_pred) < 6 else 6, 200))
    plt.yscale('log')
res = {}
plt.figure(figsize=(12,13))
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=res,
callbacks=[plot_intermediate_model_callback])
plt.tight_layout()
plt.legend(loc='lower center', ncol=4,
bbox_to_anchor=(0.5, 0),
bbox_transform=plt.gcf().transFigure)
plt.tight_layout()
# Plot negative log likelihood over boosting iterations
plt.figure(figsize=(8,3))
plt.subplot(1, 2, 1)
plt.plot(res['train']['aft-nloglik'], 'b-o', label='aft-nloglik')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')
# Plot "accuracy" over boosting iterations
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred)
plt.subplot(1, 2, 2)
plt.plot(accuracy_history, 'r-o', label='Accuracy (%)')
plt.xlabel('# Boosting Iterations')
plt.legend(loc='best')
plt.tight_layout()
plt.show()


@ -0,0 +1,138 @@
Survival_label_lower_bound,Survival_label_upper_bound,Age_in_years,Karnofsky_score,Months_from_Diagnosis,Celltype=adeno,Celltype=large,Celltype=smallcell,Celltype=squamous,Prior_therapy=no,Prior_therapy=yes,Treatment=standard,Treatment=test
72.0,72.0,69.0,60.0,7.0,0,0,0,1,1,0,1,0
411.0,411.0,64.0,70.0,5.0,0,0,0,1,0,1,1,0
228.0,228.0,38.0,60.0,3.0,0,0,0,1,1,0,1,0
126.0,126.0,63.0,60.0,9.0,0,0,0,1,0,1,1,0
118.0,118.0,65.0,70.0,11.0,0,0,0,1,0,1,1,0
10.0,10.0,49.0,20.0,5.0,0,0,0,1,1,0,1,0
82.0,82.0,69.0,40.0,10.0,0,0,0,1,0,1,1,0
110.0,110.0,68.0,80.0,29.0,0,0,0,1,1,0,1,0
314.0,314.0,43.0,50.0,18.0,0,0,0,1,1,0,1,0
100.0,inf,70.0,70.0,6.0,0,0,0,1,1,0,1,0
42.0,42.0,81.0,60.0,4.0,0,0,0,1,1,0,1,0
8.0,8.0,63.0,40.0,58.0,0,0,0,1,0,1,1,0
144.0,144.0,63.0,30.0,4.0,0,0,0,1,1,0,1,0
25.0,inf,52.0,80.0,9.0,0,0,0,1,0,1,1,0
11.0,11.0,48.0,70.0,11.0,0,0,0,1,0,1,1,0
30.0,30.0,61.0,60.0,3.0,0,0,1,0,1,0,1,0
384.0,384.0,42.0,60.0,9.0,0,0,1,0,1,0,1,0
4.0,4.0,35.0,40.0,2.0,0,0,1,0,1,0,1,0
54.0,54.0,63.0,80.0,4.0,0,0,1,0,0,1,1,0
13.0,13.0,56.0,60.0,4.0,0,0,1,0,1,0,1,0
123.0,inf,55.0,40.0,3.0,0,0,1,0,1,0,1,0
97.0,inf,67.0,60.0,5.0,0,0,1,0,1,0,1,0
153.0,153.0,63.0,60.0,14.0,0,0,1,0,0,1,1,0
59.0,59.0,65.0,30.0,2.0,0,0,1,0,1,0,1,0
117.0,117.0,46.0,80.0,3.0,0,0,1,0,1,0,1,0
16.0,16.0,53.0,30.0,4.0,0,0,1,0,0,1,1,0
151.0,151.0,69.0,50.0,12.0,0,0,1,0,1,0,1,0
22.0,22.0,68.0,60.0,4.0,0,0,1,0,1,0,1,0
56.0,56.0,43.0,80.0,12.0,0,0,1,0,0,1,1,0
21.0,21.0,55.0,40.0,2.0,0,0,1,0,0,1,1,0
18.0,18.0,42.0,20.0,15.0,0,0,1,0,1,0,1,0
139.0,139.0,64.0,80.0,2.0,0,0,1,0,1,0,1,0
20.0,20.0,65.0,30.0,5.0,0,0,1,0,1,0,1,0
31.0,31.0,65.0,75.0,3.0,0,0,1,0,1,0,1,0
52.0,52.0,55.0,70.0,2.0,0,0,1,0,1,0,1,0
287.0,287.0,66.0,60.0,25.0,0,0,1,0,0,1,1,0
18.0,18.0,60.0,30.0,4.0,0,0,1,0,1,0,1,0
51.0,51.0,67.0,60.0,1.0,0,0,1,0,1,0,1,0
122.0,122.0,53.0,80.0,28.0,0,0,1,0,1,0,1,0
27.0,27.0,62.0,60.0,8.0,0,0,1,0,1,0,1,0
54.0,54.0,67.0,70.0,1.0,0,0,1,0,1,0,1,0
7.0,7.0,72.0,50.0,7.0,0,0,1,0,1,0,1,0
63.0,63.0,48.0,50.0,11.0,0,0,1,0,1,0,1,0
392.0,392.0,68.0,40.0,4.0,0,0,1,0,1,0,1,0
10.0,10.0,67.0,40.0,23.0,0,0,1,0,0,1,1,0
8.0,8.0,61.0,20.0,19.0,1,0,0,0,0,1,1,0
92.0,92.0,60.0,70.0,10.0,1,0,0,0,1,0,1,0
35.0,35.0,62.0,40.0,6.0,1,0,0,0,1,0,1,0
117.0,117.0,38.0,80.0,2.0,1,0,0,0,1,0,1,0
132.0,132.0,50.0,80.0,5.0,1,0,0,0,1,0,1,0
12.0,12.0,63.0,50.0,4.0,1,0,0,0,0,1,1,0
162.0,162.0,64.0,80.0,5.0,1,0,0,0,1,0,1,0
3.0,3.0,43.0,30.0,3.0,1,0,0,0,1,0,1,0
95.0,95.0,34.0,80.0,4.0,1,0,0,0,1,0,1,0
177.0,177.0,66.0,50.0,16.0,0,1,0,0,0,1,1,0
162.0,162.0,62.0,80.0,5.0,0,1,0,0,1,0,1,0
216.0,216.0,52.0,50.0,15.0,0,1,0,0,1,0,1,0
553.0,553.0,47.0,70.0,2.0,0,1,0,0,1,0,1,0
278.0,278.0,63.0,60.0,12.0,0,1,0,0,1,0,1,0
12.0,12.0,68.0,40.0,12.0,0,1,0,0,0,1,1,0
260.0,260.0,45.0,80.0,5.0,0,1,0,0,1,0,1,0
200.0,200.0,41.0,80.0,12.0,0,1,0,0,0,1,1,0
156.0,156.0,66.0,70.0,2.0,0,1,0,0,1,0,1,0
182.0,inf,62.0,90.0,2.0,0,1,0,0,1,0,1,0
143.0,143.0,60.0,90.0,8.0,0,1,0,0,1,0,1,0
105.0,105.0,66.0,80.0,11.0,0,1,0,0,1,0,1,0
103.0,103.0,38.0,80.0,5.0,0,1,0,0,1,0,1,0
250.0,250.0,53.0,70.0,8.0,0,1,0,0,0,1,1,0
100.0,100.0,37.0,60.0,13.0,0,1,0,0,0,1,1,0
999.0,999.0,54.0,90.0,12.0,0,0,0,1,0,1,0,1
112.0,112.0,60.0,80.0,6.0,0,0,0,1,1,0,0,1
87.0,inf,48.0,80.0,3.0,0,0,0,1,1,0,0,1
231.0,inf,52.0,50.0,8.0,0,0,0,1,0,1,0,1
242.0,242.0,70.0,50.0,1.0,0,0,0,1,1,0,0,1
991.0,991.0,50.0,70.0,7.0,0,0,0,1,0,1,0,1
111.0,111.0,62.0,70.0,3.0,0,0,0,1,1,0,0,1
1.0,1.0,65.0,20.0,21.0,0,0,0,1,0,1,0,1
587.0,587.0,58.0,60.0,3.0,0,0,0,1,1,0,0,1
389.0,389.0,62.0,90.0,2.0,0,0,0,1,1,0,0,1
33.0,33.0,64.0,30.0,6.0,0,0,0,1,1,0,0,1
25.0,25.0,63.0,20.0,36.0,0,0,0,1,1,0,0,1
357.0,357.0,58.0,70.0,13.0,0,0,0,1,1,0,0,1
467.0,467.0,64.0,90.0,2.0,0,0,0,1,1,0,0,1
201.0,201.0,52.0,80.0,28.0,0,0,0,1,0,1,0,1
1.0,1.0,35.0,50.0,7.0,0,0,0,1,1,0,0,1
30.0,30.0,63.0,70.0,11.0,0,0,0,1,1,0,0,1
44.0,44.0,70.0,60.0,13.0,0,0,0,1,0,1,0,1
283.0,283.0,51.0,90.0,2.0,0,0,0,1,1,0,0,1
15.0,15.0,40.0,50.0,13.0,0,0,0,1,0,1,0,1
25.0,25.0,69.0,30.0,2.0,0,0,1,0,1,0,0,1
103.0,inf,36.0,70.0,22.0,0,0,1,0,0,1,0,1
21.0,21.0,71.0,20.0,4.0,0,0,1,0,1,0,0,1
13.0,13.0,62.0,30.0,2.0,0,0,1,0,1,0,0,1
87.0,87.0,60.0,60.0,2.0,0,0,1,0,1,0,0,1
2.0,2.0,44.0,40.0,36.0,0,0,1,0,0,1,0,1
20.0,20.0,54.0,30.0,9.0,0,0,1,0,0,1,0,1
7.0,7.0,66.0,20.0,11.0,0,0,1,0,1,0,0,1
24.0,24.0,49.0,60.0,8.0,0,0,1,0,1,0,0,1
99.0,99.0,72.0,70.0,3.0,0,0,1,0,1,0,0,1
8.0,8.0,68.0,80.0,2.0,0,0,1,0,1,0,0,1
99.0,99.0,62.0,85.0,4.0,0,0,1,0,1,0,0,1
61.0,61.0,71.0,70.0,2.0,0,0,1,0,1,0,0,1
25.0,25.0,70.0,70.0,2.0,0,0,1,0,1,0,0,1
95.0,95.0,61.0,70.0,1.0,0,0,1,0,1,0,0,1
80.0,80.0,71.0,50.0,17.0,0,0,1,0,1,0,0,1
51.0,51.0,59.0,30.0,87.0,0,0,1,0,0,1,0,1
29.0,29.0,67.0,40.0,8.0,0,0,1,0,1,0,0,1
24.0,24.0,60.0,40.0,2.0,1,0,0,0,1,0,0,1
18.0,18.0,69.0,40.0,5.0,1,0,0,0,0,1,0,1
83.0,inf,57.0,99.0,3.0,1,0,0,0,1,0,0,1
31.0,31.0,39.0,80.0,3.0,1,0,0,0,1,0,0,1
51.0,51.0,62.0,60.0,5.0,1,0,0,0,1,0,0,1
90.0,90.0,50.0,60.0,22.0,1,0,0,0,0,1,0,1
52.0,52.0,43.0,60.0,3.0,1,0,0,0,1,0,0,1
73.0,73.0,70.0,60.0,3.0,1,0,0,0,1,0,0,1
8.0,8.0,66.0,50.0,5.0,1,0,0,0,1,0,0,1
36.0,36.0,61.0,70.0,8.0,1,0,0,0,1,0,0,1
48.0,48.0,81.0,10.0,4.0,1,0,0,0,1,0,0,1
7.0,7.0,58.0,40.0,4.0,1,0,0,0,1,0,0,1
140.0,140.0,63.0,70.0,3.0,1,0,0,0,1,0,0,1
186.0,186.0,60.0,90.0,3.0,1,0,0,0,1,0,0,1
84.0,84.0,62.0,80.0,4.0,1,0,0,0,0,1,0,1
19.0,19.0,42.0,50.0,10.0,1,0,0,0,1,0,0,1
45.0,45.0,69.0,40.0,3.0,1,0,0,0,1,0,0,1
80.0,80.0,63.0,40.0,4.0,1,0,0,0,1,0,0,1
52.0,52.0,45.0,60.0,4.0,0,1,0,0,1,0,0,1
164.0,164.0,68.0,70.0,15.0,0,1,0,0,0,1,0,1
19.0,19.0,39.0,30.0,4.0,0,1,0,0,0,1,0,1
53.0,53.0,66.0,60.0,12.0,0,1,0,0,1,0,0,1
15.0,15.0,63.0,30.0,5.0,0,1,0,0,1,0,0,1
43.0,43.0,49.0,60.0,11.0,0,1,0,0,0,1,0,1
340.0,340.0,64.0,80.0,10.0,0,1,0,0,0,1,0,1
133.0,133.0,65.0,75.0,1.0,0,1,0,0,1,0,0,1
111.0,111.0,64.0,60.0,5.0,0,1,0,0,1,0,0,1
231.0,231.0,67.0,70.0,18.0,0,1,0,0,0,1,0,1
378.0,378.0,65.0,80.0,4.0,0,1,0,0,1,0,0,1
49.0,49.0,37.0,30.0,3.0,0,1,0,0,1,0,0,1


@ -0,0 +1,135 @@
###############################################
Survival Analysis with Accelerated Failure Time
###############################################
.. contents::
:local:
:backlinks: none
**************************
What is survival analysis?
**************************
**Survival analysis (regression)** models **time to an event of interest**. Survival analysis is a special kind of regression and differs from the conventional regression task as follows:
* The label is always positive, since you cannot wait a negative amount of time until the event occurs.
* The label may not be fully known, or **censored**, because "it takes time to measure time."
The second bullet point is crucial and we should dwell on it more. As you may have guessed from the name, one of the earliest applications of survival analysis was to model mortality of a given population. Let's take the `NCCTG Lung Cancer Dataset <https://stat.ethz.ch/R-manual/R-devel/library/survival/html/lung.html>`_ as an example. The first 8 columns represent features and the last column, Time to death, represents the label.
==== === === ======= ======== ========= ======== ======= ========================
Inst Age Sex ph.ecog ph.karno pat.karno meal.cal wt.loss **Time to death (days)**
==== === === ======= ======== ========= ======== ======= ========================
3 74 1 1 90 100 1175 N/A 306
3 68 1 0 90 90 1225 15 455
3 56 1 0 90 90 N/A 15 :math:`[1010, +\infty)`
5 57 1 1 90 60 1150 11 210
1 60 1 0 100 90 N/A 0 883
12 74 1 1 50 80 513 0 :math:`[1022, +\infty)`
7 68 2 2 70 60 384 10 310
==== === === ======= ======== ========= ======== ======= ========================
Take a close look at the label for the third patient. **His label is a range, not a single number.** The third patient's label is said to be **censored**, because for some reason the experimenters could not get a complete measurement for that label. One possible scenario: the patient survived the first 1010 days and walked out of the clinic on the 1011th day, so his death was not directly observed. Another possibility: The experiment was cut short (since you cannot run it forever) before his death could be observed. In any case, his label is :math:`[1010, +\infty)`, meaning his time to death can be any number that's higher than 1010, e.g. 2000, 3000, or 10000.
There are four kinds of censoring:
* **Uncensored**: the label is not censored and given as a single number.
* **Right-censored**: the label is of form :math:`[a, +\infty)`, where :math:`a` is the lower bound.
* **Left-censored**: the label is of form :math:`(-\infty, b]`, where :math:`b` is the upper bound.
* **Interval-censored**: the label is of form :math:`[a, b]`, where :math:`a` and :math:`b` are the lower and upper bounds, respectively.
Right-censoring is the most commonly used.
******************************
Accelerated Failure Time model
******************************
The **Accelerated Failure Time (AFT)** model is one of the most commonly used models in survival analysis. The model is of the following form:

.. math::

  \ln{Y} = \langle \mathbf{w}, \mathbf{x} \rangle + \sigma Z
where
* :math:`\mathbf{x}` is a vector in :math:`\mathbb{R}^d` representing the features.
* :math:`\mathbf{w}` is a vector consisting of :math:`d` coefficients, each corresponding to a feature.
* :math:`\langle \cdot, \cdot \rangle` is the usual dot product in :math:`\mathbb{R}^d`.
* :math:`\ln{(\cdot)}` is the natural logarithm.
* :math:`Y` and :math:`Z` are random variables.
- :math:`Y` is the output label.
- :math:`Z` is a random variable of a known probability distribution. Common choices are the normal distribution, the logistic distribution, and the extreme distribution. Intuitively, :math:`Z` represents the "noise" that pulls the prediction :math:`\langle \mathbf{w}, \mathbf{x} \rangle` away from the true log label :math:`\ln{Y}`.
* :math:`\sigma` is a parameter that scales the size of :math:`Z`.
Note that this model is a generalized form of a linear regression model :math:`Y = \langle \mathbf{w}, \mathbf{x} \rangle`. In order to make AFT work with gradient boosting, we revise the model as follows:
.. math::

  \ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`.
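To make the training goal concrete, the per-data-point negative log likelihood that ``survival:aft`` minimizes (and that the ``aft-nloglik`` metric reports) can be sketched as follows, mirroring the implementation in ``survival_util.cc`` in this commit. Here :math:`f_Z` and :math:`F_Z` denote the PDF and CDF of :math:`Z`, and :math:`\hat{y} = \mathcal{T}(\mathbf{x})` is the raw (log-scale) prediction:

.. math::

  -\ln{L} =
  \begin{cases}
    -\ln{\left[ \dfrac{1}{\sigma y} f_Z\!\left( \dfrac{\ln{y} - \hat{y}}{\sigma} \right) \right]}
      & \text{uncensored label } y \\
    -\ln{\left[ F_Z\!\left( \dfrac{\ln{b} - \hat{y}}{\sigma} \right)
              - F_Z\!\left( \dfrac{\ln{a} - \hat{y}}{\sigma} \right) \right]}
      & \text{censored label } [a, b]
  \end{cases}

For right-censored labels the first CDF term becomes 1 (since :math:`b = +\infty`), and for left-censored labels the second term becomes 0 (since :math:`a = -\infty`).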
**********
How to use
**********
The first step is to express the labels in the form of a range, so that **every data point has two numbers associated with it, namely the lower and upper bounds for the label.** For uncensored labels, use a degenerate interval of form :math:`[a, a]`.
.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718
================= ==================== =================== ===================
Censoring type Interval form Lower bound finite? Upper bound finite?
================= ==================== =================== ===================
Uncensored :math:`[a, a]` |tick| |tick|
Right-censored :math:`[a, +\infty)` |tick| |cross|
Left-censored :math:`(-\infty, b]` |cross| |tick|
Interval-censored :math:`[a, b]` |tick| |tick|
================= ==================== =================== ===================
Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound numbers in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`:
.. code-block:: python

  import numpy as np
  import xgboost as xgb

  # 4-by-2 Data matrix
  X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
  dtrain = xgb.DMatrix(X)

  # Associate ranged labels with the data matrix.
  # This example shows each kind of censored labels.
  #                          uncensored    right     left  interval
  y_lower_bound = np.array([        2.0,     3.0, -np.inf,      4.0])
  y_upper_bound = np.array([        2.0, +np.inf,     4.0,      5.0])
  dtrain.set_float_info('label_lower_bound', y_lower_bound)
  dtrain.set_float_info('label_upper_bound', y_upper_bound)
Now we are ready to invoke the training API:
.. code-block:: python

  params = {'objective': 'survival:aft',
            'eval_metric': 'aft-nloglik',
            'aft_loss_distribution': 'normal',
            'aft_loss_distribution_scale': 1.20,
            'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2}
  bst = xgb.train(params, dtrain, num_boost_round=5,
                  evals=[(dtrain, 'train')])
We set the ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik``, so that the log likelihood for the AFT model would be maximized. (XGBoost will actually minimize the negative log likelihood, hence the name ``aft-nloglik``.)
The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`.
Currently, you can choose from three probability distributions for ``aft_loss_distribution``:
========================= ===========================================
``aft_loss_distribution`` Probability Density Function (PDF)
========================= ===========================================
``normal`` :math:`\dfrac{\exp{(-z^2/2)}}{\sqrt{2\pi}}`
``logistic`` :math:`\dfrac{e^z}{(1+e^z)^2}`
``extreme`` :math:`e^z e^{-\exp{z}}`
========================= ===========================================
Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :func:`xgboost.train` with :class:`xgboost.DMatrix`.


@ -39,7 +39,7 @@ enum class DataType : uint8_t {
class MetaInfo {
public:
/*! \brief number of data fields in MetaInfo */
static constexpr uint64_t kNumField = 7;
static constexpr uint64_t kNumField = 9;
/*! \brief number of rows in the data */
uint64_t num_row_{0};
@ -62,6 +62,14 @@ class MetaInfo {
* can be used to specify initial prediction to boost from.
*/
HostDeviceVector<bst_float> base_margin_;
/*!
* \brief lower bound of the label, to be used for survival analysis (censored regression)
*/
HostDeviceVector<bst_float> labels_lower_bound_;
/*!
* \brief upper bound of the label, to be used for survival analysis (censored regression)
*/
HostDeviceVector<bst_float> labels_upper_bound_;
/*! \brief default constructor */
MetaInfo() = default;


@ -8,6 +8,7 @@
#define XGBOOST_METRIC_H_
#include <dmlc/registry.h>
#include <xgboost/model.h>
#include <xgboost/generic_parameters.h>
#include <xgboost/data.h>
#include <xgboost/base.h>
@ -23,7 +24,7 @@ namespace xgboost {
* \brief interface of evaluation metric used to evaluate model performance.
* This has nothing to do with training, but merely act as evaluation purpose.
*/
class Metric {
class Metric : public Configurable {
protected:
GenericParameter const* tparam_;
@ -34,6 +35,21 @@ class Metric {
*/
virtual void Configure(
const std::vector<std::pair<std::string, std::string> >& args) {}
/*!
* \brief Load configuration from JSON object
* By default, metric has no internal configuration;
* override this function to maintain internal configuration
* \param in JSON object containing the configuration
*/
virtual void LoadConfig(Json const& in) {}
/*!
* \brief Save configuration to JSON object
* By default, metric has no internal configuration;
* override this function to maintain internal configuration
* \param out pointer to output JSON object
*/
virtual void SaveConfig(Json* out) const {}
/*!
* \brief evaluate a specific metric
* \param preds prediction


@ -265,6 +265,10 @@ XGB_DLL int XGDMatrixGetFloatInfo(const DMatrixHandle handle,
vec = &info.weights_.HostVector();
} else if (!std::strcmp(field, "base_margin")) {
vec = &info.base_margin_.HostVector();
} else if (!std::strcmp(field, "label_lower_bound")) {
vec = &info.labels_lower_bound_.HostVector();
} else if (!std::strcmp(field, "label_upper_bound")) {
vec = &info.labels_upper_bound_.HostVector();
} else {
LOG(FATAL) << "Unknown float field name " << field;
}
@ -284,8 +288,7 @@ XGB_DLL int XGDMatrixGetUIntInfo(const DMatrixHandle handle,
if (!std::strcmp(field, "group_ptr")) {
vec = &info.group_ptr_;
} else {
LOG(FATAL) << "Unknown comp uint field name " << field
<< " with comparison " << std::strcmp(field, "group_ptr");
LOG(FATAL) << "Unknown uint field name " << field;
}
*out_len = static_cast<xgboost::bst_ulong>(vec->size());
*out_dptr = dmlc::BeginPtr(*vec);
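For reference, the two new float fields handled here are reachable from the Python package through the existing ``DMatrix.set_float_info`` / ``DMatrix.get_float_info`` methods, which route through ``XGDMatrixSetFloatInfo`` / ``XGDMatrixGetFloatInfo``. A minimal round-trip sketch (toy values only, not part of this diff):

import numpy as np
import xgboost as xgb

dmat = xgb.DMatrix(np.array([[1, -1], [-1, 1], [0, 1], [1, 0]]))
# Set the new survival fields ...
dmat.set_float_info('label_lower_bound', np.array([2.0, 3.0, -np.inf, 4.0]))
dmat.set_float_info('label_upper_bound', np.array([2.0, np.inf, 4.0, 5.0]))
# ... and read them back through the C API paths added above
print(dmat.get_float_info('label_lower_bound'))  # [ 2.  3. -inf  4.]
print(dmat.get_float_info('label_upper_bound'))  # [ 2. inf  4.  5.]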


@ -0,0 +1,107 @@
/*!
* Copyright 2020 by Contributors
* \file probability_distribution.cc
* \brief Implementation of a few useful probability distributions
* \author Avinash Barnwal and Hyunsu Cho
*/
#include <xgboost/logging.h>
#include <cmath>
#include "probability_distribution.h"
namespace xgboost {
namespace common {
ProbabilityDistribution* ProbabilityDistribution::Create(ProbabilityDistributionType dist) {
switch (dist) {
case ProbabilityDistributionType::kNormal:
return new NormalDist;
case ProbabilityDistributionType::kLogistic:
return new LogisticDist;
case ProbabilityDistributionType::kExtreme:
return new ExtremeDist;
default:
LOG(FATAL) << "Unknown distribution";
}
return nullptr;
}
double NormalDist::PDF(double z) {
const double pdf = std::exp(-z * z / 2) / std::sqrt(2 * probability_constant::kPI);
return pdf;
}
double NormalDist::CDF(double z) {
const double cdf = 0.5 * (1 + std::erf(z / std::sqrt(2)));
return cdf;
}
double NormalDist::GradPDF(double z) {
const double pdf = this->PDF(z);
const double grad = -1 * z * pdf;
return grad;
}
double NormalDist::HessPDF(double z) {
const double pdf = this->PDF(z);
const double hess = (z * z - 1) * pdf;
return hess;
}
double LogisticDist::PDF(double z) {
const double w = std::exp(z);
const double sqrt_denominator = 1 + w;
const double pdf
= (std::isinf(w) || std::isinf(w * w)) ? 0.0 : (w / (sqrt_denominator * sqrt_denominator));
return pdf;
}
double LogisticDist::CDF(double z) {
const double w = std::exp(z);
const double cdf = std::isinf(w) ? 1.0 : (w / (1 + w));
return cdf;
}
double LogisticDist::GradPDF(double z) {
const double pdf = this->PDF(z);
const double w = std::exp(z);
const double grad = std::isinf(w) ? 0.0 : pdf * (1 - w) / (1 + w);
return grad;
}
double LogisticDist::HessPDF(double z) {
const double pdf = this->PDF(z);
const double w = std::exp(z);
const double hess
= (std::isinf(w) || std::isinf(w * w)) ? 0.0 : pdf * (w * w - 4 * w + 1) / ((1 + w) * (1 + w));
return hess;
}
double ExtremeDist::PDF(double z) {
const double w = std::exp(z);
const double pdf = std::isinf(w) ? 0.0 : (w * std::exp(-w));
return pdf;
}
double ExtremeDist::CDF(double z) {
const double w = std::exp(z);
const double cdf = 1 - std::exp(-w);
return cdf;
}
double ExtremeDist::GradPDF(double z) {
const double pdf = this->PDF(z);
const double w = std::exp(z);
const double grad = std::isinf(w) ? 0.0 : ((1 - w) * pdf);
return grad;
}
double ExtremeDist::HessPDF(double z) {
const double pdf = this->PDF(z);
const double w = std::exp(z);
const double hess = (std::isinf(w) || std::isinf(w * w)) ? 0.0 : ((w * w - 3 * w + 1) * pdf);
return hess;
}
} // namespace common
} // namespace xgboost
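The closed forms above can be spot-checked against SciPy, whose standard ``norm``, ``logistic``, and ``gumbel_l`` (Gumbel minimum, i.e. the "extreme" choice here) distributions use the same standard parameterization. A small numerical sketch, not part of this diff, comparing a central difference of the CDF with the PDF; GradPDF and HessPDF can be checked the same way, mirroring the C++ unit test added at the end of this commit:

from scipy.stats import norm, logistic, gumbel_l

z, h = 0.3, 1e-6
for dist in (norm, logistic, gumbel_l):
    # d/dz CDF(z) should reproduce PDF(z)
    cdf_slope = (dist.cdf(z + h) - dist.cdf(z - h)) / (2 * h)
    print(dist.name, cdf_slope, dist.pdf(z))  # the two numbers should agree closely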


@ -0,0 +1,94 @@
/*!
* Copyright 2020 by Contributors
* \file probability_distribution.h
* \brief Implementation of a few useful probability distributions
* \author Avinash Barnwal and Hyunsu Cho
*/
#ifndef XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_
#define XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_
namespace xgboost {
namespace common {
namespace probability_constant {
/*! \brief Constant PI */
const double kPI = 3.14159265358979323846;
/*! \brief The Euler-Mascheroni constant */
const double kEulerMascheroni = 0.57721566490153286060651209008240243104215933593992;
} // namespace probability_constant
/*! \brief Enum encoding possible choices of probability distribution */
enum class ProbabilityDistributionType : int {
kNormal = 0, kLogistic = 1, kExtreme = 2
};
/*! \brief Interface for a probability distribution */
class ProbabilityDistribution {
public:
/*!
* \brief Evaluate Probability Density Function (PDF) at a particular point
* \param z point at which to evaluate PDF
* \return Value of PDF evaluated
*/
virtual double PDF(double z) = 0;
/*!
* \brief Evaluate Cumulative Distribution Function (CDF) at a particular point
* \param z point at which to evaluate CDF
* \return Value of CDF evaluated
*/
virtual double CDF(double z) = 0;
/*!
* \brief Evaluate first derivative of PDF at a particular point
* \param z point at which to evaluate first derivative of PDF
* \return Value of first derivative of PDF evaluated
*/
virtual double GradPDF(double z) = 0;
/*!
* \brief Evaluate second derivative of PDF at a particular point
* \param z point at which to evaluate second derivative of PDF
* \return Value of second derivative of PDF evaluated
*/
virtual double HessPDF(double z) = 0;
/*!
* \brief Factory function to instantiate a new probability distribution object
* \param dist kind of probability distribution
* \return Reference to the newly created probability distribution object
*/
static ProbabilityDistribution* Create(ProbabilityDistributionType dist);
};
/*! \brief The (standard) normal distribution */
class NormalDist : public ProbabilityDistribution {
public:
double PDF(double z) override;
double CDF(double z) override;
double GradPDF(double z) override;
double HessPDF(double z) override;
};
/*! \brief The (standard) logistic distribution */
class LogisticDist : public ProbabilityDistribution {
public:
double PDF(double z) override;
double CDF(double z) override;
double GradPDF(double z) override;
double HessPDF(double z) override;
};
/*! \brief The extreme distribution, also known as the Gumbel (minimum) distribution */
class ExtremeDist : public ProbabilityDistribution {
public:
double PDF(double z) override;
double CDF(double z) override;
double GradPDF(double z) override;
double HessPDF(double z) override;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_PROBABILITY_DISTRIBUTION_H_

src/common/survival_util.cc (new file)

@ -0,0 +1,146 @@
/*!
* Copyright 2019 by Contributors
* \file survival_util.cc
* \brief Utility functions, useful for implementing objective and metric functions for survival
* analysis
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
*/
#include <dmlc/registry.h>
#include <algorithm>
#include <cmath>
#include "survival_util.h"
/*
- Formulas are motivated by the following document:
http://members.cbio.mines-paristech.fr/~thocking/survival.pdf
- Detailed derivation of the loss, gradient, and hessian:
https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf
*/
namespace xgboost {
namespace common {
DMLC_REGISTER_PARAMETER(AFTParam);
double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;
double cost;
if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(pdf / (sigma * y_lower), eps));
} else { // censored; now check what type of censorship we have
double z_u, z_l, cdf_u, cdf_l;
if (std::isinf(y_upper)) { // right-censored
cdf_u = 1;
} else { // left-censored or interval-censored
z_u = (log_y_upper - y_pred) / sigma;
cdf_u = dist_->CDF(z_u);
}
if (std::isinf(y_lower)) { // left-censored
cdf_l = 0;
} else { // right-censored or interval-censored
z_l = (log_y_lower - y_pred) / sigma;
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, to avoid INF or NAN
cost = -std::log(std::max(cdf_u - cdf_l, eps));
}
return cost;
}
double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
double gradient;
const double eps = 1e-12;
if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = grad_pdf / (sigma * std::max(pdf, eps));
} else { // censored; now check what type of censorship we have
double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
cdf_u = dist_->CDF(z_u);
}
if (std::isinf(y_lower)) { // left-censored
pdf_l = 0;
cdf_l = 0;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
}
// Regularize the denominator with eps, so that gradient doesn't get too big
gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps));
}
return gradient;
}
double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
const double log_y_lower = std::log(y_lower);
const double log_y_upper = std::log(y_upper);
const double eps = 1e-12;
double hessian;
if (y_lower == y_upper) { // uncensored
const double z = (log_y_lower - y_pred) / sigma;
const double pdf = dist_->PDF(z);
const double grad_pdf = dist_->GradPDF(z);
const double hess_pdf = dist_->HessPDF(z);
// Regularize the denominator with eps, so that gradient doesn't get too big
hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2))
/ (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2));
} else { // censored; now check what type of censorship we have
double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
if (std::isinf(y_upper)) { // right-censored
pdf_u = 0;
cdf_u = 1;
grad_pdf_u = 0;
} else { // interval-censored or left-censored
z_u = (log_y_upper - y_pred) / sigma;
pdf_u = dist_->PDF(z_u);
cdf_u = dist_->CDF(z_u);
grad_pdf_u = dist_->GradPDF(z_u);
}
if (std::isinf(y_lower)) { // left-censored
pdf_l = 0;
cdf_l = 0;
grad_pdf_l = 0;
} else { // interval-censored or right-censored
z_l = (log_y_lower - y_pred) / sigma;
pdf_l = dist_->PDF(z_l);
cdf_l = dist_->CDF(z_l);
grad_pdf_l = dist_->GradPDF(z_l);
}
const double cdf_diff = cdf_u - cdf_l;
const double pdf_diff = pdf_u - pdf_l;
const double grad_diff = grad_pdf_u - grad_pdf_l;
// Regularize the denominator with eps, so that gradient doesn't get too big
const double cdf_diff_thresh = std::max(cdf_diff, eps);
const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
const double sqrt_denominator = sigma * cdf_diff_thresh;
const double denominator = sqrt_denominator * sqrt_denominator;
hessian = numerator / denominator;
}
return hessian;
}
} // namespace common
} // namespace xgboost
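As a cross-check, the loss above can be reproduced in a few lines of SciPy for the normal-distribution case. The helper below is an illustrative sketch (the name ``aft_nloglik_normal`` is not part of this commit); ``y_pred`` is on the log scale, matching the raw prediction used here:

import numpy as np
from scipy.stats import norm

def aft_nloglik_normal(y_lower, y_upper, y_pred, sigma, eps=1e-12):
    # Negative log likelihood of one data point under survival:aft with normal noise
    if y_lower == y_upper:  # uncensored
        z = (np.log(y_lower) - y_pred) / sigma
        return -np.log(max(norm.pdf(z) / (sigma * y_lower), eps))
    # Censored: difference of CDFs at the (log) bounds; infinite bounds give 1 and 0
    cdf_u = 1.0 if np.isinf(y_upper) else norm.cdf((np.log(y_upper) - y_pred) / sigma)
    cdf_l = 0.0 if np.isinf(y_lower) else norm.cdf((np.log(y_lower) - y_pred) / sigma)
    return -np.log(max(cdf_u - cdf_l, eps))

# Right-censored label [100, +inf) with raw prediction 4.5 and sigma 1.2
print(aft_nloglik_normal(100.0, np.inf, 4.5, 1.2))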


@ -0,0 +1,85 @@
/*!
* Copyright 2019 by Contributors
* \file survival_util.h
* \brief Utility functions, useful for implementing objective and metric functions for survival
* analysis
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
*/
#ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
#define XGBOOST_COMMON_SURVIVAL_UTIL_H_
#include <xgboost/parameter.h>
#include <memory>
#include "probability_distribution.h"
DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType);
namespace xgboost {
namespace common {
/*! \brief Parameter structure for AFT loss and metric */
struct AFTParam : public XGBoostParameter<AFTParam> {
/*! \brief Choice of probability distribution for the noise term in AFT */
ProbabilityDistributionType aft_loss_distribution;
/*! \brief Scaling factor to be applied to the distribution */
float aft_loss_distribution_scale;
DMLC_DECLARE_PARAMETER(AFTParam) {
DMLC_DECLARE_FIELD(aft_loss_distribution)
.set_default(ProbabilityDistributionType::kNormal)
.add_enum("normal", ProbabilityDistributionType::kNormal)
.add_enum("logistic", ProbabilityDistributionType::kLogistic)
.add_enum("extreme", ProbabilityDistributionType::kExtreme)
.describe("Choice of distribution for the noise term in "
"Accelerated Failure Time model");
DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
.set_default(1.0f)
.describe("Scaling factor used to scale the distribution in "
"Accelerated Failure Time model");
}
};
/*! \brief The AFT loss function */
class AFTLoss {
private:
std::unique_ptr<ProbabilityDistribution> dist_;
public:
/*!
* \brief Constructor for AFT loss function
* \param dist Choice of probability distribution for the noise term in AFT
*/
explicit AFTLoss(ProbabilityDistributionType dist) {
dist_.reset(ProbabilityDistribution::Create(dist));
}
public:
/*!
* \brief Compute the AFT loss
* \param y_lower Lower bound for the true label
* \param y_upper Upper bound for the true label
* \param y_pred Predicted label
* \param sigma Scaling factor to be applied to the distribution of the noise term
*/
double Loss(double y_lower, double y_upper, double y_pred, double sigma);
/*!
* \brief Compute the gradient of the AFT loss
* \param y_lower Lower bound for the true label
* \param y_upper Upper bound for the true label
* \param y_pred Predicted label
* \param sigma Scaling factor to be applied to the distribution of the noise term
*/
double Gradient(double y_lower, double y_upper, double y_pred, double sigma);
/*!
* \brief Compute the hessian of the AFT loss
* \param y_lower Lower bound for the true label
* \param y_upper Upper bound for the true label
* \param y_pred Predicted label
* \param sigma Scaling factor to be applied to the distribution of the noise term
*/
double Hessian(double y_lower, double y_upper, double y_pred, double sigma);
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_


@ -134,7 +134,7 @@ void MetaInfo::Clear() {
* Binary serialization format for MetaInfo:
*
* | name | type | is_scalar | num_row | num_col | value |
* |-------------+----------+-----------+---------+---------+-----------------|
* |--------------------+----------+-----------+---------+---------+-------------------------|
* | num_row | kUInt64 | True | NA | NA | ${num_row_} |
* | num_col | kUInt64 | True | NA | NA | ${num_col_} |
* | num_nonzero | kUInt64 | True | NA | NA | ${num_nonzero_} |
@ -142,6 +142,8 @@ void MetaInfo::Clear() {
* | group_ptr | kUInt32 | False | ${size} | 1 | ${group_ptr_} |
* | weights | kFloat32 | False | ${size} | 1 | ${weights_} |
* | base_margin | kFloat32 | False | ${size} | 1 | ${base_margin_} |
* | labels_lower_bound | kFloat32 | False | ${size} | 1 | ${labels_lower_bound_} |
* | labels_upper_bound | kFloat32 | False | ${size} | 1 | ${labels_upper_bound_} |
*
* Note that the scalar fields (is_scalar=True) will have num_row and num_col missing.
* Also notice the difference between the saved name and the name used in `SetInfo':
@ -164,6 +166,10 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
{weights_.Size(), 1}, weights_); ++field_cnt;
SaveVectorField(fo, u8"base_margin", DataType::kFloat32,
{base_margin_.Size(), 1}, base_margin_); ++field_cnt;
SaveVectorField(fo, u8"labels_lower_bound", DataType::kFloat32,
{labels_lower_bound_.Size(), 1}, labels_lower_bound_); ++field_cnt;
SaveVectorField(fo, u8"labels_upper_bound", DataType::kFloat32,
{labels_upper_bound_.Size(), 1}, labels_upper_bound_); ++field_cnt;
CHECK_EQ(field_cnt, kNumField) << "Wrong number of fields";
}
@ -195,6 +201,8 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
LoadVectorField(fi, u8"group_ptr", DataType::kUInt32, &group_ptr_);
LoadVectorField(fi, u8"weights", DataType::kFloat32, &weights_);
LoadVectorField(fi, u8"base_margin", DataType::kFloat32, &base_margin_);
LoadVectorField(fi, u8"labels_lower_bound", DataType::kFloat32, &labels_lower_bound_);
LoadVectorField(fi, u8"labels_upper_bound", DataType::kFloat32, &labels_upper_bound_);
}
// try to load group information from file, if exists
@ -268,8 +276,18 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
} else if (!std::strcmp(key, "label_lower_bound")) {
auto& labels = labels_lower_bound_.HostVector();
labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else if (!std::strcmp(key, "label_upper_bound")) {
auto& labels = labels_upper_bound_.HostVector();
labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
}
}


@ -0,0 +1,106 @@
/*!
* Copyright 2019 by Contributors
* \file survival_metric.cc
* \brief Metrics for survival analysis
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
*/
#include <rabit/rabit.h>
#include <xgboost/metric.h>
#include <xgboost/host_device_vector.h>
#include <dmlc/registry.h>
#include <cmath>
#include <memory>
#include <vector>
#include <limits>
#include "xgboost/json.h"
#include "../common/math.h"
#include "../common/survival_util.h"
using AFTParam = xgboost::common::AFTParam;
using AFTLoss = xgboost::common::AFTLoss;
namespace xgboost {
namespace metric {
// tag this file, used to force static linking later.
DMLC_REGISTRY_FILE_TAG(survival_metric);
/*! \brief Negative log likelihood of Accelerated Failure Time model */
struct EvalAFT : public Metric {
public:
explicit EvalAFT(const char* param) {}
void Configure(const Args& args) override {
param_.UpdateAllowUnknown(args);
loss_.reset(new AFTLoss(param_.aft_loss_distribution));
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(this->Name());
out["aft_loss_param"] = toJson(param_);
}
void LoadConfig(Json const& in) override {
fromJson(in["aft_loss_param"], &param_);
}
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
CHECK_NE(info.labels_lower_bound_.Size(), 0U)
<< "y_lower cannot be empty";
CHECK_NE(info.labels_upper_bound_.Size(), 0U)
<< "y_higher cannot be empty";
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
/* Compute negative log likelihood for each data point and compute weighted average */
const auto& yhat = preds.HostVector();
const auto& y_lower = info.labels_lower_bound_.HostVector();
const auto& y_upper = info.labels_upper_bound_.HostVector();
const auto& weights = info.weights_.HostVector();
const bool is_null_weight = weights.empty();
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
CHECK_LE(yhat.size(), static_cast<size_t>(std::numeric_limits<omp_ulong>::max()))
<< "yhat is too big";
const omp_ulong nsize = static_cast<omp_ulong>(yhat.size());
double nloglik_sum = 0.0;
double weight_sum = 0.0;
#pragma omp parallel for default(none) \
firstprivate(nsize, is_null_weight, aft_loss_distribution_scale) \
shared(weights, y_lower, y_upper, yhat) reduction(+:nloglik_sum, weight_sum)
for (omp_ulong i = 0; i < nsize; ++i) {
// If weights are empty, data is unweighted so we use 1.0 everywhere
const double w = is_null_weight ? 1.0 : weights[i];
const double loss
= loss_->Loss(y_lower[i], y_upper[i], yhat[i], aft_loss_distribution_scale);
nloglik_sum += w * loss;
weight_sum += w;
}
double dat[2]{nloglik_sum, weight_sum};
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
return static_cast<bst_float>(dat[0] / dat[1]);
}
const char* Name() const override {
return "aft-nloglik";
}
private:
AFTParam param_;
std::unique_ptr<AFTLoss> loss_;
};
XGBOOST_REGISTER_METRIC(AFT, "aft-nloglik")
.describe("Negative log likelihood of Accelerated Failure Time model.")
.set_body([](const char* param) { return new EvalAFT(param); });
} // namespace metric
} // namespace xgboost
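Putting the pieces together, the metric reduces the per-row losses to a single number: each worker accumulates a weighted sum of negative log likelihoods and a sum of weights, ``rabit::Allreduce`` adds the two partial sums across workers, and the reported ``aft-nloglik`` is their ratio. A toy sketch of that reduction (the numbers are illustrative only):

import numpy as np

nloglik = np.array([0.8, 1.3, 2.1, 0.4])  # per-row losses from AFTLoss::Loss
w = np.array([1.0, 1.0, 2.0, 0.5])        # example weights (1.0 each when no weights are given)
print(np.sum(w * nloglik) / np.sum(w))    # the value reported as aft-nloglik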

src/objective/aft_obj.cc (new file)

@ -0,0 +1,119 @@
/*!
* Copyright 2015 by Contributors
* \file aft_obj.cc
* \brief Definition of AFT loss for survival analysis.
*/
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <limits>
#include <algorithm>
#include <memory>
#include <utility>
#include <cmath>
#include "xgboost/json.h"
#include "../common/math.h"
#include "../common/random.h"
#include "../common/survival_util.h"
using AFTParam = xgboost::common::AFTParam;
using AFTLoss = xgboost::common::AFTLoss;
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(aft_obj);
class AFTObj : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
loss_.reset(new AFTLoss(param_.aft_loss_distribution));
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
/* Boilerplate */
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
const auto& yhat = preds.HostVector();
const auto& y_lower = info.labels_lower_bound_.HostVector();
const auto& y_upper = info.labels_upper_bound_.HostVector();
const auto& weights = info.weights_.HostVector();
const bool is_null_weight = weights.empty();
out_gpair->Resize(yhat.size());
std::vector<GradientPair>& gpair = out_gpair->HostVector();
CHECK_LE(yhat.size(), static_cast<size_t>(std::numeric_limits<omp_ulong>::max()))
<< "yhat is too big";
const omp_ulong nsize = static_cast<omp_ulong>(yhat.size());
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
#pragma omp parallel for default(none) \
firstprivate(nsize, is_null_weight, aft_loss_distribution_scale) \
shared(weights, y_lower, y_upper, yhat, gpair)
for (omp_ulong i = 0; i < nsize; ++i) {
// If weights are empty, data is unweighted so we use 1.0 everywhere
const double w = is_null_weight ? 1.0 : weights[i];
const double grad = loss_->Gradient(y_lower[i], y_upper[i],
yhat[i], aft_loss_distribution_scale);
const double hess = loss_->Hessian(y_lower[i], y_upper[i],
yhat[i], aft_loss_distribution_scale);
gpair[i] = GradientPair(grad * w, hess * w);
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
// Trees give us a prediction in log scale, so exponentiate
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for default(none) firstprivate(ndata) shared(preds)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
// do nothing here, since the AFT metric expects untransformed prediction score
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "aft-nloglik";
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("survival:aft");
out["aft_loss_param"] = toJson(param_);
}
void LoadConfig(Json const& in) override {
fromJson(in["aft_loss_param"], &param_);
loss_.reset(new AFTLoss(param_.aft_loss_distribution));
}
private:
AFTParam param_;
std::unique_ptr<AFTLoss> loss_;
};
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(AFTObj, "survival:aft")
.describe("AFT loss function")
.set_body([]() { return new AFTObj(); });
} // namespace obj
} // namespace xgboost
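A brief sketch of the model the objective above implements (the standard AFT formulation; the per-distribution formulas live in ../common/survival_util.h and are not reproduced here). The label is assumed to satisfy

\[
\ln T \;=\; \hat{y} + \sigma Z, \qquad Z \sim \text{normal, logistic, or extreme},
\]

so \(\exp(\hat{y})\) is the predicted survival time; this is why PredTransform() exponentiates the raw margin and ProbToMargin() takes the log of base_score. GetGradient() fills each GradientPair with the first and second derivatives of the per-example negative log likelihood with respect to the raw prediction, scaled by the example weight:

\[
g_i \;=\; w_i \, \frac{\partial \ell_i}{\partial \hat{y}_i},
\qquad
h_i \;=\; w_i \, \frac{\partial^2 \ell_i}{\partial \hat{y}_i^2}.
\]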

@@ -0,0 +1,121 @@
/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>
#include <memory>
#include <cmath>
#include "xgboost/logging.h"
#include "../../../src/common/probability_distribution.h"
namespace xgboost {
namespace common {
TEST(ProbabilityDistribution, DistributionGeneric) {
// Assert d/dx CDF = PDF, d/dx PDF = GradPDF, d/dx GradPDF = HessPDF
// Do this for every distribution type
for (auto type : {ProbabilityDistributionType::kNormal, ProbabilityDistributionType::kLogistic,
ProbabilityDistributionType::kExtreme}) {
std::unique_ptr<ProbabilityDistribution> dist{ ProbabilityDistribution::Create(type) };
double integral_of_pdf = dist->CDF(-2.0);
double integral_of_grad_pdf = dist->PDF(-2.0);
double integral_of_hess_pdf = dist->GradPDF(-2.0);
// Perform numerical differentiation and integration
// Enumerate 4000 grid points in range [-2, 2]
for (int i = 0; i <= 4000; ++i) {
const double x = static_cast<double>(i) / 1000.0 - 2.0;
// Numerical differentiation (p. 246, Numerical Analysis 2nd ed. by Timothy Sauer)
EXPECT_NEAR((dist->CDF(x + 1e-5) - dist->CDF(x - 1e-5)) / 2e-5, dist->PDF(x), 6e-11);
EXPECT_NEAR((dist->PDF(x + 1e-5) - dist->PDF(x - 1e-5)) / 2e-5, dist->GradPDF(x), 6e-11);
EXPECT_NEAR((dist->GradPDF(x + 1e-5) - dist->GradPDF(x - 1e-5)) / 2e-5,
dist->HessPDF(x), 6e-11);
// Numerical integration using Trapezoid Rule (p. 257, Sauer)
integral_of_pdf += 5e-4 * (dist->PDF(x - 1e-3) + dist->PDF(x));
integral_of_grad_pdf += 5e-4 * (dist->GradPDF(x - 1e-3) + dist->GradPDF(x));
integral_of_hess_pdf += 5e-4 * (dist->HessPDF(x - 1e-3) + dist->HessPDF(x));
EXPECT_NEAR(integral_of_pdf, dist->CDF(x), 2e-4);
EXPECT_NEAR(integral_of_grad_pdf, dist->PDF(x), 2e-4);
EXPECT_NEAR(integral_of_hess_pdf, dist->GradPDF(x), 2e-4);
}
}
}
TEST(ProbabilityDistribution, NormalDist) {
std::unique_ptr<ProbabilityDistribution> dist{
ProbabilityDistribution::Create(ProbabilityDistributionType::kNormal)
};
// "Three-sigma rule" (https://en.wikipedia.org/wiki/689599.7_rule)
// 68% of values are within 1 standard deviation away from the mean
// 95% of values are within 2 standard deviation away from the mean
// 99.7% of values are within 3 standard deviation away from the mean
EXPECT_NEAR(dist->CDF(0.5) - dist->CDF(-0.5), 0.3829, 0.00005);
EXPECT_NEAR(dist->CDF(1.0) - dist->CDF(-1.0), 0.6827, 0.00005);
EXPECT_NEAR(dist->CDF(1.5) - dist->CDF(-1.5), 0.8664, 0.00005);
EXPECT_NEAR(dist->CDF(2.0) - dist->CDF(-2.0), 0.9545, 0.00005);
EXPECT_NEAR(dist->CDF(2.5) - dist->CDF(-2.5), 0.9876, 0.00005);
EXPECT_NEAR(dist->CDF(3.0) - dist->CDF(-3.0), 0.9973, 0.00005);
EXPECT_NEAR(dist->CDF(3.5) - dist->CDF(-3.5), 0.9995, 0.00005);
EXPECT_NEAR(dist->CDF(4.0) - dist->CDF(-4.0), 0.9999, 0.00005);
}
TEST(ProbabilityDistribution, LogisticDist) {
std::unique_ptr<ProbabilityDistribution> dist{
ProbabilityDistribution::Create(ProbabilityDistributionType::kLogistic)
};
/**
* Enforce known properties of the logistic distribution.
* (https://en.wikipedia.org/wiki/Logistic_distribution)
**/
// Enumerate 4000 grid points in range [-2, 2]
for (int i = 0; i <= 4000; ++i) {
const double x = static_cast<double>(i) / 1000.0 - 2.0;
// PDF = 1/4 * sech(x/2)**2
const double sech_x = 1.0 / std::cosh(x * 0.5); // hyperbolic secant at x/2
EXPECT_NEAR(0.25 * sech_x * sech_x, dist->PDF(x), 1e-15);
// CDF = 1/2 + 1/2 * tanh(x/2)
EXPECT_NEAR(0.5 + 0.5 * std::tanh(x * 0.5), dist->CDF(x), 1e-15);
}
}
TEST(ProbabilityDistribution, ExtremeDist) {
std::unique_ptr<ProbabilityDistribution> dist{
ProbabilityDistribution::Create(ProbabilityDistributionType::kExtreme)
};
/**
* Enforce known properties of the extreme distribution (also known as Gumbel distribution).
* The mean is the negative of the Euler-Mascheroni constant.
* The variance is 1/6 * pi**2. (https://mathworld.wolfram.com/GumbelDistribution.html)
**/
// Enumerate 25000 grid points in range [-20, 5].
// Compute the mean (expected value) of the distribution using numerical integration.
// Nearly all mass of the extreme distribution is concentrated between -20 and 5,
// so numerically integrating x*PDF(x) over [-20, 5] gives a good estimate of the mean.
double mean = 0.0;
for (int i = 0; i <= 25000; ++i) {
const double x = static_cast<double>(i) / 1000.0 - 20.0;
// Numerical integration using Trapezoid Rule (p. 257, Sauer)
mean += 5e-4 * ((x - 1e-3) * dist->PDF(x - 1e-3) + x * dist->PDF(x));
}
EXPECT_NEAR(mean, -probability_constant::kEulerMascheroni, 1e-7);
// Enumerate 25000 grid points in range [-20, 5].
// Compute the variance of the distribution using numerical integration.
// Nearly all mass of the extreme distribution is concentrated between -20 and 5,
// so numerically integrating (x-mean)^2 * PDF(x) over [-20, 5] gives a good estimate of the variance.
double variance = 0.0;
for (int i = 0; i <= 25000; ++i) {
const double x = static_cast<double>(i) / 1000.0 - 20.0;
// Numerical integration using Trapezoid Rule (p. 257, Sauer)
variance += 5e-4 * ((x - 1e-3 - mean) * (x - 1e-3 - mean) * dist->PDF(x - 1e-3)
+ (x - mean) * (x - mean) * dist->PDF(x));
}
EXPECT_NEAR(variance, probability_constant::kPI * probability_constant::kPI / 6.0, 1e-6);
}
} // namespace common
} // namespace xgboost
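For reference, the standardized densities the tests above exercise; the logistic identities and the extreme-distribution moments below are exactly the properties asserted in LogisticDist and ExtremeDist:

\[
\begin{aligned}
\text{normal:} \quad & f(z) = \tfrac{1}{\sqrt{2\pi}} e^{-z^2/2}, & F(z) &= \Phi(z), \\
\text{logistic:} \quad & f(z) = \frac{e^{-z}}{(1+e^{-z})^2} = \tfrac{1}{4}\operatorname{sech}^2\!\tfrac{z}{2}, & F(z) &= \frac{1}{1+e^{-z}} = \tfrac{1}{2} + \tfrac{1}{2}\tanh\tfrac{z}{2}, \\
\text{extreme:} \quad & f(z) = e^{z} e^{-e^{z}}, & F(z) &= 1 - e^{-e^{z}},
\end{aligned}
\]

where the extreme distribution is the minimum-type Gumbel, with mean \(-\gamma\) (Euler-Mascheroni constant) and variance \(\pi^2/6\), matching the numerical checks above.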

@@ -0,0 +1,169 @@
/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <string>
#include <limits>
#include <cmath>
#include "xgboost/metric.h"
#include "xgboost/logging.h"
#include "../helpers.h"
#include "../../../src/common/survival_util.h"
namespace xgboost {
namespace common {
/**
* Reference values obtained from
* https://github.com/avinashbarnwal/GSOC-2019/blob/master/AFT/R/combined_assignment.R
**/
TEST(Metric, AFTNegLogLik) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT metric is CPU only
/**
* Test aggregate output from the AFT metric over a small test data set.
* This is unlike AFTLoss.* tests, which verify metric values over individual data points.
**/
MetaInfo info;
info.num_row_ = 4;
info.labels_lower_bound_.HostVector()
= { 100.0f, -std::numeric_limits<bst_float>::infinity(), 60.0f, 16.0f };
info.labels_upper_bound_.HostVector()
= { 100.0f, 20.0f, std::numeric_limits<bst_float>::infinity(), 200.0f };
info.weights_.HostVector() = std::vector<bst_float>();
HostDeviceVector<bst_float> preds(4, std::log(64));
struct TestCase {
std::string dist_type;
bst_float reference_value;
};
for (const auto& test_case : std::vector<TestCase>{ {"normal", 2.1508f}, {"logistic", 2.1804f},
{"extreme", 2.0706f} }) {
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &lparam));
metric->Configure({ {"aft_loss_distribution", test_case.dist_type},
{"aft_loss_distribution_scale", "1.0"} });
EXPECT_NEAR(metric->Eval(preds, info, false), test_case.reference_value, 1e-4);
}
}
// Test configuration of AFT metric
TEST(AFTNegLogLikMetric, Configuration) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT metric is CPU only
std::unique_ptr<Metric> metric(Metric::Create("aft-nloglik", &lparam));
metric->Configure({{"aft_loss_distribution", "normal"}, {"aft_loss_distribution_scale", "10"}});
// Configuration round-trip test
Json j_obj{ Object() };
metric->SaveConfig(&j_obj);
auto aft_param_json = j_obj["aft_loss_param"];
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution"]), "normal");
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution_scale"]), "10");
}
/**
* AFTLoss.* tests verify metric values over individual data points.
**/
// Generate prediction values ranging from 2**1 to 2**15, using grid points evenly spaced in log scale
// Then check the loss at each point against the reference values
static inline void CheckLossOverGridPoints(
double true_label_lower_bound,
double true_label_upper_bound,
ProbabilityDistributionType dist_type,
const std::vector<double>& reference_values) {
const int num_point = 20;
const double log_y_low = 1.0;
const double log_y_high = 15.0;
std::unique_ptr<AFTLoss> loss(new AFTLoss(dist_type));
CHECK_EQ(num_point, reference_values.size());
for (int i = 0; i < num_point; ++i) {
const double y_pred
= std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low);
const double loss_val
= loss->Loss(true_label_lower_bound, true_label_upper_bound, std::log(y_pred), 1.0);
EXPECT_NEAR(loss_val, reference_values[i], 1e-4);
}
}
TEST(AFTLoss, Uncensored) {
// Given label 100, compute the AFT loss for various prediction values
const double true_label_lower_bound = 100.0;
const double true_label_upper_bound = true_label_lower_bound;
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kNormal,
{ 13.1761, 11.3085, 9.7017, 8.3558, 7.2708, 6.4466, 5.8833, 5.5808, 5.5392, 5.7585, 6.2386,
6.9795, 7.9813, 9.2440, 10.7675, 12.5519, 14.5971, 16.9032, 19.4702, 22.2980 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kLogistic,
{ 8.5568, 8.0720, 7.6038, 7.1620, 6.7612, 6.4211, 6.1659, 6.0197, 5.9990, 6.1064, 6.3293,
6.6450, 7.0289, 7.4594, 7.9205, 8.4008, 8.8930, 9.3926, 9.8966, 10.4033 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kExtreme,
{ 27.6310, 27.6310, 19.7177, 13.0281, 9.2183, 7.1365, 6.0916, 5.6688, 5.6195, 5.7941, 6.1031,
6.4929, 6.9310, 7.3981, 7.8827, 8.3778, 8.8791, 9.3842, 9.8916, 10.4003 });
}
TEST(AFTLoss, LeftCensored) {
// Given label (-inf, 20], compute the AFT loss for various prediction values
const double true_label_lower_bound = -std::numeric_limits<double>::infinity();
const double true_label_upper_bound = 20.0;
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kNormal,
{ 0.0107, 0.0373, 0.1054, 0.2492, 0.5068, 0.9141, 1.5003, 2.2869, 3.2897, 4.5196, 5.9846,
7.6902, 9.6405, 11.8385, 14.2867, 16.9867, 19.9399, 23.1475, 26.6103, 27.6310 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kLogistic,
{ 0.0953, 0.1541, 0.2451, 0.3804, 0.5717, 0.8266, 1.1449, 1.5195, 1.9387, 2.3902, 2.8636,
3.3512, 3.8479, 4.3500, 4.8556, 5.3632, 5.8721, 6.3817, 6.8918, 7.4021 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kExtreme,
{ 0.0000, 0.0025, 0.0277, 0.1225, 0.3195, 0.6150, 0.9862, 1.4094, 1.8662, 2.3441, 2.8349,
3.3337, 3.8372, 4.3436, 4.8517, 5.3609, 5.8707, 6.3808, 6.8912, 7.4018 });
}
TEST(AFTLoss, RightCensored) {
// Given label [60, +inf), compute the AFT loss for various prediction values
const double true_label_lower_bound = 60.0;
const double true_label_upper_bound = std::numeric_limits<double>::infinity();
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kNormal,
{ 8.0000, 6.2537, 4.7487, 3.4798, 2.4396, 1.6177, 0.9993, 0.5638, 0.2834, 0.1232, 0.0450,
0.0134, 0.0032, 0.0006, 0.0001, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kLogistic,
{ 3.4340, 2.9445, 2.4683, 2.0125, 1.5871, 1.2041, 0.8756, 0.6099, 0.4083, 0.2643, 0.1668,
0.1034, 0.0633, 0.0385, 0.0233, 0.0140, 0.0084, 0.0051, 0.0030, 0.0018 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kExtreme,
{ 27.6310, 18.0015, 10.8018, 6.4817, 3.8893, 2.3338, 1.4004, 0.8403, 0.5042, 0.3026, 0.1816,
0.1089, 0.0654, 0.0392, 0.0235, 0.0141, 0.0085, 0.0051, 0.0031, 0.0018 });
}
TEST(AFTLoss, IntervalCensored) {
// Given label [16, 200], compute the AFT loss for various prediction values
const double true_label_lower_bound = 16.0;
const double true_label_upper_bound = 200.0;
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kNormal,
{ 3.9746, 2.8415, 1.9319, 1.2342, 0.7335, 0.4121, 0.2536, 0.2470, 0.3919, 0.6982, 1.1825,
1.8622, 2.7526, 3.8656, 5.2102, 6.7928, 8.6183, 10.6901, 13.0108, 15.5826 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kLogistic,
{ 2.2906, 1.8578, 1.4667, 1.1324, 0.8692, 0.6882, 0.5948, 0.5909, 0.6764, 0.8499, 1.1061,
1.4348, 1.8215, 2.2511, 2.7104, 3.1891, 3.6802, 4.1790, 4.6825, 5.1888 });
CheckLossOverGridPoints(true_label_lower_bound, true_label_upper_bound,
ProbabilityDistributionType::kExtreme,
{ 8.0000, 4.8004, 2.8805, 1.7284, 1.0372, 0.6231, 0.3872, 0.3031, 0.3740, 0.5839, 0.8995,
1.2878, 1.7231, 2.1878, 2.6707, 3.1647, 3.6653, 4.1699, 4.6770, 5.1856 });
}
} // namespace common
} // namespace xgboost
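The reference values above are consistent with the usual AFT likelihood. Writing \(z = (\ln t - \hat{y})/\sigma\) for an observed time \(t\) (and \(z_l\), \(z_u\) for the values obtained from the lower/upper label bounds), a sketch of the per-example loss being tested is

\[
\ell \;=\;
\begin{cases}
-\ln\!\left[\dfrac{f(z)}{\sigma\, t}\right] & \text{uncensored } (y_{\mathrm{lower}} = y_{\mathrm{upper}} = t), \\[1ex]
-\ln\!\left[1 - F(z_l)\right] & \text{right-censored } [y_{\mathrm{lower}}, +\infty), \\[1ex]
-\ln F(z_u) & \text{left-censored } (-\infty, y_{\mathrm{upper}}], \\[1ex]
-\ln\!\left[F(z_u) - F(z_l)\right] & \text{interval-censored } [y_{\mathrm{lower}}, y_{\mathrm{upper}}],
\end{cases}
\]

with \(f\) and \(F\) the PDF and CDF of the chosen distribution. The repeated value 27.6310 in several tables appears to be the cap applied when the likelihood underflows (approximately \(-\ln 10^{-12}\)), consistent with keeping INF/NAN out of the loss.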

@@ -0,0 +1,174 @@
/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>
#include <memory>
#include <vector>
#include <limits>
#include <cmath>
#include "xgboost/objective.h"
#include "xgboost/logging.h"
#include "../helpers.h"
#include "../../../src/common/survival_util.h"
namespace xgboost {
namespace common {
TEST(Objective, AFTObjConfiguration) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
std::unique_ptr<ObjFunction> objective(ObjFunction::Create("survival:aft", &lparam));
objective->Configure({ {"aft_loss_distribution", "logistic"},
{"aft_loss_distribution_scale", "5"} });
// Configuration round-trip test
Json j_obj{ Object() };
objective->SaveConfig(&j_obj);
EXPECT_EQ(get<String>(j_obj["name"]), "survival:aft");
auto aft_param_json = j_obj["aft_loss_param"];
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution"]), "logistic");
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution_scale"]), "5");
}
/**
* Verify that gradient pair (gpair) is computed correctly for various prediction values.
* Reference values obtained from
* https://github.com/avinashbarnwal/GSOC-2019/blob/master/AFT/R/combined_assignment.R
**/
// Generate prediction values ranging from 2**1 to 2**15, using grid points evenly spaced in log scale
// Then check the gradient and hessian at each point against the reference values
static inline void CheckGPairOverGridPoints(
ObjFunction* obj,
bst_float true_label_lower_bound,
bst_float true_label_upper_bound,
const std::string& dist_type,
const std::vector<bst_float>& expected_grad,
const std::vector<bst_float>& expected_hess,
float ftol = 1e-4f) {
const int num_point = 20;
const double log_y_low = 1.0;
const double log_y_high = 15.0;
obj->Configure({ {"aft_loss_distribution", dist_type},
{"aft_loss_distribution_scale", "1"} });
MetaInfo info;
info.num_row_ = num_point;
info.labels_lower_bound_.HostVector()
= std::vector<bst_float>(num_point, true_label_lower_bound);
info.labels_upper_bound_.HostVector()
= std::vector<bst_float>(num_point, true_label_upper_bound);
info.weights_.HostVector() = std::vector<bst_float>();
std::vector<bst_float> preds(num_point);
for (int i = 0; i < num_point; ++i) {
preds[i] = std::log(std::pow(2.0, i * (log_y_high - log_y_low) / (num_point - 1) + log_y_low));
}
HostDeviceVector<GradientPair> out_gpair;
obj->GetGradient(HostDeviceVector<bst_float>(preds), info, 1, &out_gpair);
const auto& gpair = out_gpair.HostVector();
CHECK_EQ(num_point, expected_grad.size());
CHECK_EQ(num_point, expected_hess.size());
for (int i = 0; i < num_point; ++i) {
EXPECT_NEAR(gpair[i].GetGrad(), expected_grad[i], ftol);
EXPECT_NEAR(gpair[i].GetHess(), expected_hess[i], ftol);
}
}
TEST(Objective, AFTObjGPairUncensoredLabels) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "normal",
{ -3.9120f, -3.4013f, -2.8905f, -2.3798f, -1.8691f, -1.3583f, -0.8476f, -0.3368f, 0.1739f,
0.6846f, 1.1954f, 1.7061f, 2.2169f, 2.7276f, 3.2383f, 3.7491f, 4.2598f, 4.7706f, 5.2813f,
5.7920f },
{ 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f,
1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f, 1.0000f });
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "logistic",
{ -0.9608f, -0.9355f, -0.8948f, -0.8305f, -0.7327f, -0.5910f, -0.4001f, -0.1668f, 0.0867f,
0.3295f, 0.5354f, 0.6927f, 0.8035f, 0.8773f, 0.9245f, 0.9540f, 0.9721f, 0.9832f, 0.9899f,
0.9939f },
{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f,
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f });
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme",
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f,
0.9969f },
{ 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
}
TEST(Objective, AFTObjGPairLeftCensoredLabels) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "normal",
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f },
{ 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f,
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f },
2e-4);
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "logistic",
{ 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f,
0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f },
{ 0.0826f, 0.1224f, 0.1701f, 0.2163f, 0.2458f, 0.2461f, 0.2170f, 0.1709f, 0.1232f, 0.0832f,
0.0538f, 0.0338f, 0.0209f, 0.0127f, 0.0077f, 0.0047f, 0.0028f, 0.0017f, 0.0010f, 0.0006f });
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "extreme",
{ 0.0005f, 0.0149f, 0.1011f, 0.2815f, 0.4881f, 0.6610f, 0.7847f, 0.8665f, 0.9183f, 0.9504f,
0.9700f, 0.9820f, 0.9891f, 0.9935f, 0.9961f, 0.9976f, 0.9986f, 0.9992f, 0.9995f, 0.9997f },
{ 0.0041f, 0.0747f, 0.2731f, 0.4059f, 0.3829f, 0.2901f, 0.1973f, 0.1270f, 0.0793f, 0.0487f,
0.0296f, 0.0179f, 0.0108f, 0.0065f, 0.0039f, 0.0024f, 0.0014f, 0.0008f, 0.0005f, 0.0003f });
}
TEST(Objective, AFTObjGPairRightCensoredLabels) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "normal",
{ -3.6583f, -3.1815f, -2.7135f, -2.2577f, -1.8190f, -1.4044f, -1.0239f, -0.6905f, -0.4190f,
-0.2209f, -0.0973f, -0.0346f, -0.0097f, -0.0021f, -0.0004f, -0.0000f, -0.0000f, -0.0000f,
-0.0000f, -0.0000f },
{ 0.9407f, 0.9259f, 0.9057f, 0.8776f, 0.8381f, 0.7821f, 0.7036f, 0.5970f, 0.4624f, 0.3128f,
0.1756f, 0.0780f, 0.0265f, 0.0068f, 0.0013f, 0.0002f, 0.0000f, 0.0000f, 0.0000f, 0.0000f });
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "logistic",
{ -0.9677f, -0.9474f, -0.9153f, -0.8663f, -0.7955f, -0.7000f, -0.5834f, -0.4566f, -0.3352f,
-0.2323f, -0.1537f, -0.0982f, -0.0614f, -0.0377f, -0.0230f, -0.0139f, -0.0084f, -0.0051f,
-0.0030f, -0.0018f },
{ 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f,
0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f });
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "extreme",
{ -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
-0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f,
-0.0031f, -0.0018f },
{ 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
}
TEST(Objective, AFTObjGPairIntervalCensoredLabels) {
auto lparam = CreateEmptyGenericParam(-1); // currently AFT objective is CPU only
std::unique_ptr<ObjFunction> obj(ObjFunction::Create("survival:aft", &lparam));
CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "normal",
{ -2.4435f, -1.9965f, -1.5691f, -1.1679f, -0.7990f, -0.4649f, -0.1596f, 0.1336f, 0.4370f,
0.7682f, 1.1340f, 1.5326f, 1.9579f, 2.4035f, 2.8639f, 3.3351f, 3.8143f, 4.2995f, 4.7891f,
5.2822f },
{ 0.8909f, 0.8579f, 0.8134f, 0.7557f, 0.6880f, 0.6221f, 0.5789f, 0.5769f, 0.6171f, 0.6818f,
0.7500f, 0.8088f, 0.8545f, 0.8884f, 0.9131f, 0.9312f, 0.9446f, 0.9547f, 0.9624f, 0.9684f });
CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "logistic",
{ -0.8790f, -0.8112f, -0.7153f, -0.5893f, -0.4375f, -0.2697f, -0.0955f, 0.0800f, 0.2545f,
0.4232f, 0.5768f, 0.7054f, 0.8040f, 0.8740f, 0.9210f, 0.9513f, 0.9703f, 0.9820f, 0.9891f,
0.9934f },
{ 0.1086f, 0.1588f, 0.2176f, 0.2745f, 0.3164f, 0.3374f, 0.3433f, 0.3434f, 0.3384f, 0.3191f,
0.2789f, 0.2229f, 0.1637f, 0.1125f, 0.0737f, 0.0467f, 0.0290f, 0.0177f, 0.0108f, 0.0065f });
CheckGPairOverGridPoints(obj.get(), 16.0f, 200.0f, "extreme",
{ -8.0000f, -4.8004f, -2.8805f, -1.7284f, -1.0371f, -0.6168f, -0.3140f, -0.0121f, 0.2841f,
0.5261f, 0.6989f, 0.8132f, 0.8857f, 0.9306f, 0.9581f, 0.9747f, 0.9848f, 0.9909f, 0.9945f,
0.9967f },
{ 8.0000f, 4.8004f, 2.8805f, 1.7284f, 1.0380f, 0.6567f, 0.5727f, 0.6033f, 0.5384f, 0.4051f,
0.2757f, 0.1776f, 0.1110f, 0.0682f, 0.0415f, 0.0251f, 0.0151f, 0.0091f, 0.0055f, 0.0033f });
}
} // namespace common
} // namespace xgboost
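The expected_grad / expected_hess tables are tied to the loss tables in the metric tests: with unit weights (the weight vector is left empty), a sketch of the quantities being checked is

\[
\texttt{expected\_grad}[i] = \frac{\partial \ell}{\partial \hat{y}}\bigg|_{\hat{y}_i},
\qquad
\texttt{expected\_hess}[i] = \frac{\partial^2 \ell}{\partial \hat{y}^2}\bigg|_{\hat{y}_i},
\]

up to the clipping applied in survival_util.h to avoid numerical overflow, which is visible in a few boundary entries. For example, in the uncensored normal case with \(\sigma = 1\), \(\ell = \tfrac{1}{2}(\ln y - \hat{y})^2 + \text{const}\), so the gradient is \(\hat{y} - \ln y\) and the Hessian is identically 1, matching the first block of AFTObjGPairUncensoredLabels.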

@@ -0,0 +1,90 @@
import testing as tm
import pytest
import numpy as np
import xgboost as xgb
import json
from pathlib import Path
dpath = Path('demo/data')
def test_aft_survival_toy_data():
# See demo/aft_survival/aft_survival_viz_demo.py
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
INF = np.inf
y_lower = np.array([ 10, 15, -INF, 30, 100])
y_upper = np.array([INF, INF, 20, 50, INF])
dmat = xgb.DMatrix(X)
dmat.set_float_info('label_lower_bound', y_lower)
dmat.set_float_info('label_upper_bound', y_upper)
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
# the corresponding predicted label (y_pred)
acc_rec = []
def my_callback(env):
y_pred = env.model.predict(dmat)
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)) / len(X)
acc_rec.append(acc)
evals_result = {}
params = {'max_depth': 3, 'objective':'survival:aft', 'min_child_weight': 0}
bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
callbacks=[my_callback])
nloglik_rec = evals_result['train']['aft-nloglik']
# AFT metric (negative log likelihood) improves monotonically
assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[1:]))
# "Accuracy" improve monotonically.
# Over time, XGBoost model makes predictions that fall within given label ranges.
assert all(p <= q for p, q in zip(acc_rec, acc_rec[1:]))
assert acc_rec[-1] == 1.0
def gather_split_thresholds(tree):
if 'split_condition' in tree:
return (gather_split_thresholds(tree['children'][0])
| gather_split_thresholds(tree['children'][1])
| {tree['split_condition']})
return set()
# Only 2.5, 3.5, and 4.5 are used as split thresholds.
model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')]
for tree in model_json:
assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
@pytest.mark.skipif(**tm.no_pandas())
def test_aft_survival_demo_data():
import pandas as pd
df = pd.read_csv(dpath / 'veterans_lung_cancer.csv')
y_lower_bound = df['Survival_label_lower_bound']
y_upper_bound = df['Survival_label_upper_bound']
X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)
dtrain = xgb.DMatrix(X)
dtrain.set_float_info('label_lower_bound', y_lower_bound)
dtrain.set_float_info('label_upper_bound', y_upper_bound)
base_params = {'verbosity': 0,
'objective': 'survival:aft',
'eval_metric': 'aft-nloglik',
'tree_method': 'hist',
'learning_rate': 0.05,
'aft_loss_distribution_scale': 1.20,
'max_depth': 6,
'lambda': 0.01,
'alpha': 0.02}
nloglik_rec = {}
dists = ['normal', 'logistic', 'extreme']
for dist in dists:
params = base_params
params.update({'aft_loss_distribution': dist})
evals_result = {}
bst = xgb.train(params, dtrain, num_boost_round=500, evals=[(dtrain, 'train')],
evals_result=evals_result)
nloglik_rec[dist] = evals_result['train']['aft-nloglik']
# AFT metric (negative log likelihood) improves monotonically
assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][1:]))
# For this data, the normal distribution works best
assert nloglik_rec['normal'][-1] < 5.0
assert nloglik_rec['logistic'][-1] > 5.0
assert nloglik_rec['extreme'][-1] > 5.0
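To close the loop with the C++ tests above, here is a minimal Python sketch (not part of the diff) showing how the four censoring types are encoded through the two new label fields; the bound pairs mirror the rows used in the AFTNegLogLik metric test, and the feature matrix is a made-up placeholder.

import numpy as np
import xgboost as xgb

# One row per censoring type: uncensored, left-, right-, and interval-censored.
X = np.arange(4, dtype=float).reshape((-1, 1))      # dummy feature matrix (placeholder)
y_lower = np.array([100.0, -np.inf, 60.0, 16.0])    # lower == upper => exact (uncensored) label
y_upper = np.array([100.0, 20.0, np.inf, 200.0])    # -inf / +inf   => open-ended interval

dtrain = xgb.DMatrix(X)
dtrain.set_float_info('label_lower_bound', y_lower)
dtrain.set_float_info('label_upper_bound', y_upper)

params = {'objective': 'survival:aft',
          'eval_metric': 'aft-nloglik',
          'aft_loss_distribution': 'normal',
          'aft_loss_distribution_scale': 1.0}
bst = xgb.train(params, dtrain, num_boost_round=10, evals=[(dtrain, 'train')])
print(bst.predict(dtrain))  # predictions come back as survival times in the original scale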