* [WIP] Add lower and upper bounds on the label for survival analysis * Update test MetaInfo.SaveLoadBinary to account for extra two fields * Don't clear qids_ for version 2 of MetaInfo * Add SetInfo() and GetInfo() method for lower and upper bounds * changes to aft * Add parameter class for AFT; use enum's to represent distribution and event type * Add AFT metric * changes to neg grad to grad * changes to binomial loss * changes to overflow * changes to eps * changes to code refactoring * changes to code refactoring * changes to code refactoring * Re-factor survival analysis * Remove aft namespace * Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter * Move function bodies out of AFTLoss, to reduce clutter * Use smart pointer to store AFTDistribution and AFTLoss * Rename AFTNoiseDistribution enum to AFTDistributionType for clarity The enum class was not a distribution itself but a distribution type * Add AFTDistribution::Create() method for convenience * changes to extreme distribution * changes to extreme distribution * changes to extreme * changes to extreme distribution * changes to left censored * deleted cout * changes to x,mu and sd and code refactoring * changes to print * changes to hessian formula in censored and uncensored * changes to variable names and pow * changes to Logistic Pdf * changes to parameter * Expose lower and upper bound labels to R package * Use example weights; normalize log likelihood metric * changes to CHECK * changes to logistic hessian to standard formula * changes to logistic formula * Comply with coding style guideline * Revert back Rabit submodule * Revert dmlc-core submodule * Comply with coding style guideline (clang-tidy) * Fix an error in AFTLoss::Gradient() * Add missing files to amalgamation * Address @RAMitchell's comment: minimize future change in MetaInfo interface * Fix lint * Fix compilation error on 32-bit target, when size_t == bst_uint * Allocate sufficient memory to hold extra label info * Use 
OpenMP to speed up * Fix compilation on Windows * Address reviewer's feedback * Add unit tests for probability distributions * Make Metric subclass of Configurable * Address reviewer's feedback: Configure() AFT metric * Add a dummy test for AFT metric configuration * Complete AFT configuration test; remove debugging print * Rename AFT parameters * Clarify test comment * Add a dummy test for AFT loss for uncensored case * Fix a bug in AFT loss for uncensored labels * Complete unit test for AFT loss metric * Simplify unit tests for AFT metric * Add unit test to verify aggregate output from AFT metric * Use EXPECT_* instead of ASSERT_*, so that we run all unit tests * Use aft_loss_param when serializing AFTObj This is to be consistent with AFT metric * Add unit tests for AFT Objective * Fix OpenMP bug; clarify semantics for shared variables used in OpenMP loops * Add comments * Remove AFT prefix from probability distribution; put probability distribution in separate source file * Add comments * Define kPI and kEulerMascheroni in probability_distribution.h * Add probability_distribution.cc to amalgamation * Remove unnecessary diff * Address reviewer's feedback: define variables where they're used * Eliminate all INFs and NANs from AFT loss and gradient * Add demo * Add tutorial * Fix lint * Use 'survival:aft' to be consistent with 'survival:cox' * Move sample data to demo/data * Add visual demo with 1D toy data * Add Python tests Co-authored-by: Philip Cho <chohyu01@cs.washington.edu>
91 lines
3.6 KiB
Python
91 lines
3.6 KiB
Python
import testing as tm
|
|
import pytest
|
|
import numpy as np
|
|
import xgboost as xgb
|
|
import json
|
|
from pathlib import Path
|
|
|
|
dpath = Path('demo/data')
|
|
|
|
def test_aft_survival_toy_data():
    """Train an AFT model on a 1D toy dataset and verify that the training loss
    decreases and that predictions converge into the given label ranges.

    See demo/aft_survival/aft_survival_viz_demo.py for the matching visual demo.
    """
    X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
    INF = np.inf
    # Interval-censored labels: each true label lies in [y_lower[i], y_upper[i]].
    y_lower = np.array([ 10,  15, -INF, 30, 100])
    y_upper = np.array([INF, INF,   20, 50, INF])

    dmat = xgb.DMatrix(X)
    dmat.set_float_info('label_lower_bound', y_lower)
    dmat.set_float_info('label_upper_bound', y_upper)

    # "Accuracy" = the fraction of data points whose ranged label (y_lower, y_upper)
    # includes the corresponding predicted label (y_pred)
    acc_rec = []

    def my_callback(env):
        y_pred = env.model.predict(dmat)
        # Count predictions inside their interval, then normalize once.
        # (Original divided the boolean array by len(X) before summing; the
        # result is identical, this form is just clearer.)
        acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)) / len(X)
        acc_rec.append(acc)

    evals_result = {}
    params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
    bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
                    callbacks=[my_callback])

    nloglik_rec = evals_result['train']['aft-nloglik']
    # AFT metric (negative log likelihood) should improve monotonically.
    # BUG FIX: the original compared against nloglik_rec[:1], which zips only the
    # first element against itself and makes the assertion vacuous. Compare
    # consecutive entries, as the accuracy check below already does.
    assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[1:]))
    # "Accuracy" improves monotonically.
    # Over time, the XGBoost model makes predictions that fall within given label ranges.
    assert all(p <= q for p, q in zip(acc_rec, acc_rec[1:]))
    assert acc_rec[-1] == 1.0

    def gather_split_thresholds(tree):
        # Recursively collect the set of split thresholds used in one dumped
        # tree (a nested dict parsed from the JSON model dump).
        if 'split_condition' in tree:
            return (gather_split_thresholds(tree['children'][0])
                    | gather_split_thresholds(tree['children'][1])
                    | {tree['split_condition']})
        return set()  # leaf node: no threshold

    # Only 2.5, 3.5, and 4.5 are used as split thresholds.
    model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')]
    for tree in model_json:
        assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
|
|
|
|
@pytest.mark.skipif(**tm.no_pandas())
def test_aft_survival_demo_data():
    """Train AFT models with all three supported distributions on the veterans
    lung cancer dataset and check the final training loss of each.
    """
    import pandas as pd
    df = pd.read_csv(dpath / 'veterans_lung_cancer.csv')

    y_lower_bound = df['Survival_label_lower_bound']
    y_upper_bound = df['Survival_label_upper_bound']
    X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

    dtrain = xgb.DMatrix(X)
    dtrain.set_float_info('label_lower_bound', y_lower_bound)
    dtrain.set_float_info('label_upper_bound', y_upper_bound)

    base_params = {'verbosity': 0,
                   'objective': 'survival:aft',
                   'eval_metric': 'aft-nloglik',
                   'tree_method': 'hist',
                   'learning_rate': 0.05,
                   'aft_loss_distribution_scale': 1.20,
                   'max_depth': 6,
                   'lambda': 0.01,
                   'alpha': 0.02}
    nloglik_rec = {}
    dists = ['normal', 'logistic', 'extreme']
    for dist in dists:
        # BUG FIX: copy base_params instead of aliasing it; the original
        # `params = base_params` made params.update() mutate the shared base
        # dict on every iteration.
        params = dict(base_params)
        params.update({'aft_loss_distribution': dist})
        evals_result = {}
        bst = xgb.train(params, dtrain, num_boost_round=500, evals=[(dtrain, 'train')],
                        evals_result=evals_result)
        nloglik_rec[dist] = evals_result['train']['aft-nloglik']
        # AFT metric (negative log likelihood) should improve monotonically.
        # BUG FIX: the original zipped against nloglik_rec[dist][:1], pairing only
        # the first element with itself (vacuous). Compare consecutive entries.
        assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][1:]))

    # For this data, the normal distribution works the best
    assert nloglik_rec['normal'][-1] < 5.0
    assert nloglik_rec['logistic'][-1] > 5.0
    assert nloglik_rec['extreme'][-1] > 5.0
|