
###############################################
Survival Analysis with Accelerated Failure Time
###############################################
.. contents::
  :local:
  :backlinks: none

**************************
What is survival analysis?
**************************
**Survival analysis (regression)** models **time to an event of interest**. Survival analysis is a special kind of regression and differs from the conventional regression task as follows:

* The label is always positive, since you cannot wait a negative amount of time until the event occurs.
* The label may not be fully known, or **censored**, because "it takes time to measure time."

The second bullet point is crucial, so let's dwell on it a little longer. As you may have guessed from the name, one of the earliest applications of survival analysis was modeling the mortality of a given population. Let's take the `NCCTG Lung Cancer Dataset <https://stat.ethz.ch/R-manual/R-devel/library/survival/html/lung.html>`_ as an example. The first 8 columns represent features and the last column, Time to death, represents the label.

==== === === ======= ======== ========= ======== ======= ========================
Inst Age Sex ph.ecog ph.karno pat.karno meal.cal wt.loss **Time to death (days)**
==== === === ======= ======== ========= ======== ======= ========================
3    74  1   1       90       100       1175     N/A     306
3    68  1   0       90       90        1225     15      455
3    56  1   0       90       90        N/A      15      :math:`[1010, +\infty)`
5    57  1   1       90       60        1150     11      210
1    60  1   0       100      90        N/A      0       883
12   74  1   1       50       80        513      0       :math:`[1022, +\infty)`
7    68  2   2       70       60        384      10      310
==== === === ======= ======== ========= ======== ======= ========================

Take a close look at the label for the third patient. **His label is a range, not a single number.** The third patient's label is said to be **censored**, because for some reason the experimenters could not get a complete measurement for that label. One possible scenario: the patient survived the first 1010 days and walked out of the clinic on the 1011th day, so his death was not directly observed. Another possibility: the experiment was cut short (since you cannot run it forever) before his death could be observed. In any case, his label is :math:`[1010, +\infty)`, meaning his time to death can be any number higher than 1010, e.g. 2000, 3000, or 10000.

There are four kinds of censoring:

* **Uncensored**: the label is not censored and is given as a single number.
* **Right-censored**: the label is of form :math:`[a, +\infty)`, where :math:`a` is the lower bound.
* **Left-censored**: the label is of form :math:`(-\infty, b]`, where :math:`b` is the upper bound.
* **Interval-censored**: the label is of form :math:`[a, b]`, where :math:`a` and :math:`b` are the lower and upper bounds, respectively.

Right-censoring is the most common case in practice.

******************************
Accelerated Failure Time model
******************************
The **Accelerated Failure Time (AFT)** model is one of the most commonly used models in survival analysis. The model is of the following form:

.. math::

  \ln{Y} = \langle \mathbf{w}, \mathbf{x} \rangle + \sigma Z

where

* :math:`\mathbf{x}` is a vector in :math:`\mathbb{R}^d` representing the features.
* :math:`\mathbf{w}` is a vector consisting of :math:`d` coefficients, each corresponding to a feature.
* :math:`\langle \cdot, \cdot \rangle` is the usual dot product in :math:`\mathbb{R}^d`.
* :math:`\ln{(\cdot)}` is the natural logarithm.
* :math:`Y` and :math:`Z` are random variables.

  - :math:`Y` is the output label.
  - :math:`Z` is a random variable of a known probability distribution. Common choices are the normal distribution, the logistic distribution, and the extreme distribution. Intuitively, :math:`Z` represents the "noise" that pulls the prediction :math:`\langle \mathbf{w}, \mathbf{x} \rangle` away from the true log label :math:`\ln{Y}`.

* :math:`\sigma` is a parameter that scales the size of :math:`Z`.

Note that this model is a generalized form of a linear regression model :math:`Y = \langle \mathbf{w}, \mathbf{x} \rangle`. In order to make AFT work with gradient boosting, we revise the model as follows:

.. math::

  \ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z

where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. The goal for XGBoost is then to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`.

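To make the objective concrete, here is a minimal sketch (plain Python, standard library only) of the per-example negative log likelihood for the normal choice of :math:`Z`: for an uncensored label the density of :math:`\ln{Y}` is used, and for a censored label the probability that :math:`\ln{Y}` falls inside the given interval. The function names are illustrative, not part of XGBoost's API, and XGBoost's actual implementation additionally guards against overflow and degenerate values.

```python
import math

def normal_pdf(z):
    # Density of the standard normal distribution
    return math.exp(-z * z / 2.0) / math.sqrt(2.0 * math.pi)

def normal_cdf(z):
    # CDF of the standard normal, via the error function
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def aft_nloglik(pred, y_lower, y_upper, sigma=1.0):
    """Negative log likelihood of one example under the AFT model with
    normal Z. `pred` plays the role of T(x), the raw ensemble output."""
    if y_lower == y_upper:
        # Uncensored: density of ln(Y) = pred + sigma * Z evaluated at ln(y)
        z = (math.log(y_lower) - pred) / sigma
        return -math.log(normal_pdf(z) / sigma)
    # Censored: probability that ln(Y) lands inside the label interval
    cdf_lo = 0.0 if y_lower <= 0 else normal_cdf((math.log(y_lower) - pred) / sigma)
    cdf_hi = 1.0 if math.isinf(y_upper) else normal_cdf((math.log(y_upper) - pred) / sigma)
    return -math.log(cdf_hi - cdf_lo)
```

Swapping in the logistic or extreme distribution only changes the PDF/CDF pair; the interval logic stays the same.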
**********
How to use
**********
The first step is to express the labels in the form of a range, so that **every data point has two numbers associated with it, namely the lower and upper bounds for the label.** For uncensored labels, use a degenerate interval of form :math:`[a, a]`.

.. |tick| unicode:: U+2714
.. |cross| unicode:: U+2718

================= ==================== =================== ===================
Censoring type    Interval form        Lower bound finite? Upper bound finite?
================= ==================== =================== ===================
Uncensored        :math:`[a, a]`       |tick|              |tick|
Right-censored    :math:`[a, +\infty)` |tick|              |cross|
Left-censored     :math:`(-\infty, b]` |cross|             |tick|
Interval-censored :math:`[a, b]`       |tick|              |tick|
================= ==================== =================== ===================

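Many survival datasets, including the lung cancer data above, ship as a time column plus a status flag rather than explicit bounds. The conversion is mechanical; here is a short sketch, assuming a flag encoding of ``1`` for an observed event (uncensored) and ``0`` for right-censored (the encoding is an assumption for illustration, not something XGBoost prescribes):

```python
import numpy as np

def bounds_from_status(time, status):
    """Turn (time, status) pairs into lower/upper bound arrays.

    status == 1: event observed (uncensored), label is [t, t]
    status == 0: right-censored, label is [t, +inf)
    """
    time = np.asarray(time, dtype=float)
    status = np.asarray(status)
    y_lower = time.copy()
    y_upper = np.where(status == 1, time, np.inf)
    return y_lower, y_upper
```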
Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound numbers in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # 4-by-2 Data matrix
    X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
    dtrain = xgb.DMatrix(X)

    # Associate ranged labels with the data matrix.
    # This example shows each kind of censored label.
    #                          uncensored    right     left  interval
    y_lower_bound = np.array([        2.0,      3.0, -np.inf,     4.0])
    y_upper_bound = np.array([        2.0,  +np.inf,     4.0,     5.0])
    dtrain.set_float_info('label_lower_bound', y_lower_bound)
    dtrain.set_float_info('label_upper_bound', y_upper_bound)

Now we are ready to invoke the training API:

.. code-block:: python

    params = {'objective': 'survival:aft',
              'eval_metric': 'aft-nloglik',
              'aft_loss_distribution': 'normal',
              'aft_loss_distribution_scale': 1.20,
              'tree_method': 'hist', 'learning_rate': 0.05, 'max_depth': 2}
    # dvalid is assumed to be a validation DMatrix prepared the same way as dtrain
    bst = xgb.train(params, dtrain, num_boost_round=5,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')])

We set the ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik``, so that the log likelihood for the AFT model is maximized. (XGBoost actually minimizes the negative log likelihood, hence the name ``aft-nloglik``.)
The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`.
Currently, you can choose from three probability distributions for ``aft_loss_distribution``:

========================= ===========================================
``aft_loss_distribution`` Probability Density Function (PDF)
========================= ===========================================
``normal``                :math:`\dfrac{\exp{(-z^2/2)}}{\sqrt{2\pi}}`
``logistic``              :math:`\dfrac{e^z}{(1+e^z)^2}`
``extreme``               :math:`e^z e^{-\exp{z}}`
========================= ===========================================

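As a sanity check, the three densities from the table can be written out directly. A short standalone sketch in plain Python (illustrative only, not code from XGBoost):

```python
import math

def normal_pdf(z):
    # exp(-z^2/2) / sqrt(2*pi)
    return math.exp(-z * z / 2.0) / math.sqrt(2.0 * math.pi)

def logistic_pdf(z):
    # e^z / (1 + e^z)^2
    return math.exp(z) / (1.0 + math.exp(z)) ** 2

def extreme_pdf(z):
    # e^z * exp(-e^z); a Gumbel-type density
    return math.exp(z) * math.exp(-math.exp(z))
```

For instance, ``logistic_pdf(0.0)`` is ``0.25``, the peak of the logistic density.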
Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :func:`xgboost.train` with :class:`xgboost.DMatrix`.