xgboost/tests/python/test_survival.py
Avinash Barnwal dcf439932a
Add Accelerated Failure Time loss for survival analysis task (#4763)
* [WIP] Add lower and upper bounds on the label for survival analysis

* Update test MetaInfo.SaveLoadBinary to account for extra two fields

* Don't clear qids_ for version 2 of MetaInfo

* Add SetInfo() and GetInfo() method for lower and upper bounds

* changes to aft

* Add parameter class for AFT; use enums to represent distribution and event type

* Add AFT metric

* changes to neg grad to grad

* changes to binomial loss

* changes to overflow

* changes to eps

* changes to code refactoring

* changes to code refactoring

* changes to code refactoring

* Re-factor survival analysis

* Remove aft namespace

* Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter

* Move function bodies out of AFTLoss, to reduce clutter

* Use smart pointer to store AFTDistribution and AFTLoss

* Rename AFTNoiseDistribution enum to AFTDistributionType for clarity

The enum class was not a distribution itself but a distribution type

* Add AFTDistribution::Create() method for convenience

* changes to extreme distribution

* changes to extreme distribution

* changes to extreme

* changes to extreme distribution

* changes to left censored

* deleted cout

* changes to x,mu and sd and code refactoring

* changes to print

* changes to hessian formula in censored and uncensored

* changes to variable names and pow

* changes to Logistic Pdf

* changes to parameter

* Expose lower and upper bound labels to R package

* Use example weights; normalize log likelihood metric

* changes to CHECK

* changes to logistic hessian to standard formula

* changes to logistic formula

* Comply with coding style guideline

* Revert back Rabit submodule

* Revert dmlc-core submodule

* Comply with coding style guideline (clang-tidy)

* Fix an error in AFTLoss::Gradient()

* Add missing files to amalgamation

* Address @RAMitchell's comment: minimize future change in MetaInfo interface

* Fix lint

* Fix compilation error on 32-bit target, when size_t == bst_uint

* Allocate sufficient memory to hold extra label info

* Use OpenMP to speed up

* Fix compilation on Windows

* Address reviewer's feedback

* Add unit tests for probability distributions

* Make Metric subclass of Configurable

* Address reviewer's feedback: Configure() AFT metric

* Add a dummy test for AFT metric configuration

* Complete AFT configuration test; remove debugging print

* Rename AFT parameters

* Clarify test comment

* Add a dummy test for AFT loss for uncensored case

* Fix a bug in AFT loss for uncensored labels

* Complete unit test for AFT loss metric

* Simplify unit tests for AFT metric

* Add unit test to verify aggregate output from AFT metric

* Use EXPECT_* instead of ASSERT_*, so that we run all unit tests

* Use aft_loss_param when serializing AFTObj

This is to be consistent with AFT metric

* Add unit tests for AFT Objective

* Fix OpenMP bug; clarify semantics for shared variables used in OpenMP loops

* Add comments

* Remove AFT prefix from probability distribution; put probability distribution in separate source file

* Add comments

* Define kPI and kEulerMascheroni in probability_distribution.h

* Add probability_distribution.cc to amalgamation

* Remove unnecessary diff

* Address reviewer's feedback: define variables where they're used

* Eliminate all INFs and NANs from AFT loss and gradient

* Add demo

* Add tutorial

* Fix lint

* Use 'survival:aft' to be consistent with 'survival:cox'

* Move sample data to demo/data

* Add visual demo with 1D toy data

* Add Python tests

Co-authored-by: Philip Cho <chohyu01@cs.washington.edu>
2020-03-25 13:52:51 -07:00

91 lines
3.6 KiB
Python

import testing as tm
import pytest
import numpy as np
import xgboost as xgb
import json
from pathlib import Path
# Relative location of the demo datasets read by the tests in this module.
dpath = Path('demo') / 'data'
def test_aft_survival_toy_data():
    """Train ``survival:aft`` on five 1-D points with ranged labels.

    Checks that, over 15 boosting rounds:

    * the ``aft-nloglik`` training metric never increases from one round
      to the next,
    * the fraction of predictions falling inside their label range never
      decreases and ends at exactly 1.0,
    * every split threshold in the dumped trees is one of 2.5, 3.5, 4.5.
    """
    # See demo/aft_survival/aft_survival_viz_demo.py
    X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
    INF = np.inf
    y_lower = np.array([10, 15, -INF, 30, 100])
    y_upper = np.array([INF, INF, 20, 50, INF])

    dmat = xgb.DMatrix(X)
    dmat.set_float_info('label_lower_bound', y_lower)
    dmat.set_float_info('label_upper_bound', y_upper)

    # "Accuracy" = the fraction of data points whose ranged label
    # (y_lower, y_upper) includes the corresponding predicted label (y_pred)
    acc_rec = []

    def my_callback(env):
        """Record the in-range fraction of predictions after each round."""
        y_pred = env.model.predict(dmat)
        in_range = np.logical_and(y_pred >= y_lower, y_pred <= y_upper)
        # Count first, divide once: keeps the value exactly k / len(X) so the
        # final ``== 1.0`` check does not depend on float summation order.
        acc_rec.append(np.sum(in_range) / len(X))

    evals_result = {}
    params = {'max_depth': 3, 'objective': 'survival:aft', 'min_child_weight': 0}
    bst = xgb.train(params, dmat, 15, [(dmat, 'train')], evals_result=evals_result,
                    callbacks=[my_callback])

    nloglik_rec = evals_result['train']['aft-nloglik']
    # AFT metric (negative log likelihood) improves monotonically.
    # Pair each round with the NEXT one (``[1:]``); slicing ``[:1]`` would only
    # compare the first value with itself and make the assertion vacuous.
    assert all(p >= q for p, q in zip(nloglik_rec, nloglik_rec[1:]))
    # "Accuracy" improves monotonically.
    # Over time, XGBoost model makes predictions that fall within given label ranges.
    assert all(p <= q for p, q in zip(acc_rec, acc_rec[1:]))
    assert acc_rec[-1] == 1.0

    def gather_split_thresholds(tree):
        """Return the set of ``split_condition`` values in a JSON tree dump."""
        if 'split_condition' in tree:
            return (gather_split_thresholds(tree['children'][0])
                    | gather_split_thresholds(tree['children'][1])
                    | {tree['split_condition']})
        # Leaf node: contributes no thresholds.
        return set()

    # Only 2.5, 3.5, and 4.5 are used as split thresholds.
    model_json = [json.loads(e) for e in bst.get_dump(dump_format='json')]
    for tree in model_json:
        assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
@pytest.mark.skipif(**tm.no_pandas())
def test_aft_survival_demo_data():
    """Train ``survival:aft`` on the veterans lung cancer demo data.

    For each of the 'normal', 'logistic' and 'extreme' loss distributions,
    trains 500 rounds and checks that the ``aft-nloglik`` training metric
    never increases from one round to the next. Finally checks the last
    metric value of each distribution against the 5.0 threshold.
    """
    import pandas as pd
    df = pd.read_csv(dpath / 'veterans_lung_cancer.csv')

    y_lower_bound = df['Survival_label_lower_bound']
    y_upper_bound = df['Survival_label_upper_bound']
    X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

    dtrain = xgb.DMatrix(X)
    dtrain.set_float_info('label_lower_bound', y_lower_bound)
    dtrain.set_float_info('label_upper_bound', y_upper_bound)

    base_params = {'verbosity': 0,
                   'objective': 'survival:aft',
                   'eval_metric': 'aft-nloglik',
                   'tree_method': 'hist',
                   'learning_rate': 0.05,
                   'aft_loss_distribution_scale': 1.20,
                   'max_depth': 6,
                   'lambda': 0.01,
                   'alpha': 0.02}
    nloglik_rec = {}
    dists = ['normal', 'logistic', 'extreme']
    for dist in dists:
        # Build a fresh dict per distribution so base_params is never mutated.
        params = dict(base_params, aft_loss_distribution=dist)
        evals_result = {}
        xgb.train(params, dtrain, num_boost_round=500, evals=[(dtrain, 'train')],
                  evals_result=evals_result)
        nloglik_rec[dist] = evals_result['train']['aft-nloglik']
        # AFT metric (negative log likelihood) improves monotonically.
        # Pair each round with the NEXT one (``[1:]``); slicing ``[:1]`` would
        # only compare the first value with itself.
        assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][1:]))

    # For this data, normal distribution works the best
    assert nloglik_rec['normal'][-1] < 5.0
    assert nloglik_rec['logistic'][-1] > 5.0
    assert nloglik_rec['extreme'][-1] > 5.0