Add R code to AFT tutorial [skip ci] (#5486)

This commit is contained in:
Philip Hyunsu Cho 2020-04-04 13:06:12 -07:00 committed by GitHub
parent 15800107ad
commit 30e94ddd04
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -90,6 +90,7 @@ Interval-censored :math:`[a, b]` |tick| |tick|
Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound number in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`:
.. code-block:: python
:caption: Python
import numpy as np
import xgboost as xgb
@ -105,10 +106,29 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a
y_upper_bound = np.array([ 2.0, +np.inf, 4.0, 5.0])
dtrain.set_float_info('label_lower_bound', y_lower_bound)
dtrain.set_float_info('label_upper_bound', y_upper_bound)
.. code-block:: r
:caption: R
library(xgboost)
# 4-by-2 Data matrix
X <- matrix(c(1., -1., -1., 1., 0., 1., 1., 0.),
nrow=4, ncol=2, byrow=TRUE)
dtrain <- xgb.DMatrix(X)
# Associate ranged labels with the data matrix.
# This example shows each kind of censored labels.
# uncensored right left interval
y_lower_bound <- c( 2., 3., -Inf, 4.)
y_upper_bound <- c( 2., +Inf, 4., 5.)
setinfo(dtrain, 'label_lower_bound', y_lower_bound)
setinfo(dtrain, 'label_upper_bound', y_upper_bound)
Now we are ready to invoke the training API:
.. code-block:: python
:caption: Python
params = {'objective': 'survival:aft',
'eval_metric': 'aft-nloglik',
@ -118,6 +138,19 @@ Now we are ready to invoke the training API:
bst = xgb.train(params, dtrain, num_boost_round=5,
evals=[(dtrain, 'train'), (dvalid, 'valid')])
.. code-block:: r
:caption: R
params <- list(objective='survival:aft',
eval_metric='aft-nloglik',
aft_loss_distribution='normal',
aft_loss_distribution_scale=1.20,
tree_method='hist',
learning_rate=0.05,
max_depth=2)
watchlist <- list(train = dtrain)
bst <- xgb.train(params, dtrain, nrounds=5, watchlist)
We set ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik``, so that the log likelihood for the AFT model would be maximized. (XGBoost will actually minimize the negative log likelihood, hence the name ``aft-nloglik``.)
The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`.