Add R code to AFT tutorial [skip ci] (#5486)

parent 15800107ad, commit 30e94ddd04
@ -90,6 +90,7 @@ Interval-censored :math:`[a, b]` |tick| |tick|

Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) and the upper bound numbers in another array (call it ``y_upper_bound``). The ranged labels are associated with a data matrix object via calls to :meth:`xgboost.DMatrix.set_float_info`:

.. code-block:: python
  :caption: Python

  import numpy as np
  import xgboost as xgb
@ -105,10 +106,29 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a

  y_upper_bound = np.array([ 2.0, +np.inf, 4.0, 5.0])
  dtrain.set_float_info('label_lower_bound', y_lower_bound)
  dtrain.set_float_info('label_upper_bound', y_upper_bound)
.. code-block:: r
  :caption: R

  library(xgboost)

  # 4-by-2 data matrix
  X <- matrix(c(1., -1., -1., 1., 0., 1., 1., 0.),
              nrow=4, ncol=2, byrow=TRUE)
  dtrain <- xgb.DMatrix(X)

  # Associate ranged labels with the data matrix.
  # This example shows each kind of censored label.
  #                  uncensored  right  left  interval
  y_lower_bound <- c(        2.,    3., -Inf,       4.)
  y_upper_bound <- c(        2.,  +Inf,   4.,       5.)
  setinfo(dtrain, 'label_lower_bound', y_lower_bound)
  setinfo(dtrain, 'label_upper_bound', y_upper_bound)
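The mapping from bound pairs to censoring types can be sanity-checked with a small helper. The function below is purely illustrative and not part of the XGBoost API: equal bounds mean an uncensored label, an infinite upper bound means right-censoring, an infinite lower bound means left-censoring, and finite unequal bounds mean interval-censoring.

```python
import numpy as np

def censoring_type(lo, hi):
    """Classify a ranged label [lo, hi]. Illustrative helper, not part of XGBoost."""
    if lo == hi:
        return 'uncensored'
    if np.isinf(hi) and not np.isinf(lo):
        return 'right-censored'
    if np.isinf(lo) and not np.isinf(hi):
        return 'left-censored'
    return 'interval-censored'

y_lower_bound = np.array([2.0, 3.0, -np.inf, 4.0])
y_upper_bound = np.array([2.0, +np.inf, 4.0, 5.0])
print([censoring_type(lo, hi) for lo, hi in zip(y_lower_bound, y_upper_bound)])
# ['uncensored', 'right-censored', 'left-censored', 'interval-censored']
```

Applied to the example labels above, the helper reproduces the four censoring kinds annotated in the R comments.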
Now we are ready to invoke the training API:

.. code-block:: python
  :caption: Python

  params = {'objective': 'survival:aft',
            'eval_metric': 'aft-nloglik',
@ -118,6 +138,19 @@ Now we are ready to invoke the training API:

  bst = xgb.train(params, dtrain, num_boost_round=5,
                  evals=[(dtrain, 'train'), (dvalid, 'valid')])

.. code-block:: r
  :caption: R

  params <- list(objective='survival:aft',
                 eval_metric='aft-nloglik',
                 aft_loss_distribution='normal',
                 aft_loss_distribution_scale=1.20,
                 tree_method='hist',
                 learning_rate=0.05,
                 max_depth=2)
  watchlist <- list(train = dtrain)
  bst <- xgb.train(params, dtrain, nrounds=5, watchlist)
We set the ``objective`` parameter to ``survival:aft`` and ``eval_metric`` to ``aft-nloglik``, so that the log likelihood of the AFT model is maximized. (XGBoost actually minimizes the negative log likelihood, hence the name ``aft-nloglik``.)

The parameter ``aft_loss_distribution`` corresponds to the distribution of the :math:`Z` term in the AFT model, and ``aft_loss_distribution_scale`` corresponds to the scaling factor :math:`\sigma`.
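To make the roles of :math:`Z` and :math:`\sigma` concrete, the sketch below computes the negative log likelihood of a single censored label under the AFT model :math:`\ln Y = \hat{y} + \sigma Z` with :math:`Z \sim \mathrm{Normal}(0, 1)`. This is a simplification for illustration only: the names ``pred_log_time`` and ``aft_nloglik_censored`` are my own, and the real ``aft-nloglik`` metric additionally handles uncensored points via the density and applies numerical safeguards.

```python
import math

def norm_cdf(x):
    # Standard normal CDF via the error function (stdlib only).
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def aft_nloglik_censored(y_lower, y_upper, pred_log_time, sigma=1.20):
    """Negative log likelihood of one censored label [y_lower, y_upper]
    under ln(Y) = pred_log_time + sigma * Z, Z ~ Normal(0, 1).
    Simplified sketch; not XGBoost's actual implementation."""
    # An infinite upper bound contributes CDF = 1; a nonpositive or
    # infinite lower bound contributes CDF = 0.
    z_hi = math.inf if math.isinf(y_upper) else (math.log(y_upper) - pred_log_time) / sigma
    z_lo = -math.inf if (y_lower <= 0 or math.isinf(y_lower)) else (math.log(y_lower) - pred_log_time) / sigma
    prob = norm_cdf(z_hi) - norm_cdf(z_lo)
    return -math.log(prob)

# Right-censored label [3, +inf) from the example above:
print(aft_nloglik_censored(3.0, math.inf, pred_log_time=1.0, sigma=1.20))
```

Moving ``pred_log_time`` closer to (or beyond) the censoring interval raises the probability mass the model assigns to the label and therefore lowers the loss, which is exactly what training on ``aft-nloglik`` encourages.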