[doc] Display survival demos in sphinx doc. [skip ci] (#8328)

Jiaming Yuan 2022-10-13 20:51:23 +08:00 committed by GitHub
parent 3ef1703553
commit 4633b476e9
8 changed files with 31 additions and 15 deletions

View File

@@ -0,0 +1,5 @@
+Survival Analysis Walkthrough
+=============================
+
+This is a collection of examples for using the XGBoost Python package for training
+survival models. For an introduction, see :doc:`/tutorials/aft_survival_analysis`.

View File

@@ -1,6 +1,10 @@
 """
-Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model
+Demo for survival analysis (regression).
+========================================
+
+Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model.
+
 """
 import os
 from sklearn.model_selection import ShuffleSplit
 import pandas as pd

View File

@@ -1,6 +1,10 @@
 """
-Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
-to tune hyperparameters
+Demo for survival analysis (regression) with Optuna.
+====================================================
+
+Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model,
+using Optuna to tune hyperparameters.
+
 """
 from sklearn.model_selection import ShuffleSplit
 import pandas as pd
@@ -45,7 +49,7 @@ def objective(trial):
     params.update(base_params)
     pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
     bst = xgb.train(params, dtrain, num_boost_round=10000,
                     evals=[(dtrain, 'train'), (dvalid, 'valid')],
                     early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
     if bst.best_iteration >= 25:
         return bst.best_score
@@ -63,7 +67,7 @@ params.update(study.best_trial.params)
 # Re-run training with the best hyperparameter combination
 print('Re-running the best trial... params = {}'.format(params))
 bst = xgb.train(params, dtrain, num_boost_round=10000,
                 evals=[(dtrain, 'train'), (dvalid, 'valid')],
                 early_stopping_rounds=50)

 # Run prediction on the validation set
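For context when reading the two hunks above: the demo's ``objective`` function is driven by an Optuna study. A minimal sketch of that driver, assuming ``optuna`` is installed; the pruner choice and trial count here are illustrative assumptions, not values taken from this diff:

.. code-block:: python

   import optuna

   # Hypothetical driver for the demo's objective(); objective() returns the
   # validation aft-nloglik, so the study minimizes it.
   study = optuna.create_study(
       direction="minimize",
       pruner=optuna.pruners.MedianPruner(n_warmup_steps=5),  # assumed pruner
   )
   study.optimize(objective, n_trials=20)  # objective is defined in the demo
   print("Best params:", study.best_trial.params)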

View File

@@ -1,9 +1,10 @@
 """
 Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.
+=========================================================================================

-This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
-starts out as a flat line and evolves into a step function in order to account for all ranged
-labels.
+This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble
+model starts out as a flat line and evolves into a step function in order to account for
+all ranged labels.
 """
 import numpy as np
 import xgboost as xgb
@@ -57,7 +58,7 @@ def plot_intermediate_model_callback(env):
     # the corresponding predicted label (y_pred)
     acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100)
     accuracy_history.append(acc)

     # Plot ranged labels as well as predictions by the model
     plt.subplot(5, 3, env.iteration + 1)
     plot_censored_labels(X, y_lower, y_upper)
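The callback in this hunk relies on a ``plot_censored_labels`` helper defined elsewhere in the demo. A plausible sketch of such a helper, assuming matplotlib; the styling and the clipping bound are assumptions, not the demo's actual implementation:

.. code-block:: python

   import matplotlib.pyplot as plt
   import numpy as np

   def plot_censored_labels(X, y_lower, y_upper, clip=10.0):
       """Hypothetical helper: draw each ranged label as a vertical segment."""
       # Right-censored labels have y_upper == +inf; clip so they stay drawable.
       lo = np.clip(y_lower, -clip, clip)
       hi = np.clip(y_upper, -clip, clip)
       plt.vlines(X.ravel(), lo, hi, colors="gray", linewidth=1.0)
       plt.scatter(X.ravel(), lo, marker="_", color="black")
       plt.scatter(X.ravel(), hi, marker="_", color="black")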

View File

@@ -93,9 +93,9 @@ extensions = [
 sphinx_gallery_conf = {
     # path to your example scripts
-    "examples_dirs": ["../demo/guide-python", "../demo/dask"],
+    "examples_dirs": ["../demo/guide-python", "../demo/dask", "../demo/aft_survival"],
     # path to where to save gallery generated output
-    "gallery_dirs": ["python/examples", "python/dask-examples"],
+    "gallery_dirs": ["python/examples", "python/dask-examples", "python/survival-examples"],
     "matplotlib_animations": True,
 }
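This change works because sphinx-gallery pairs the two lists by position: each ``examples_dirs`` entry is rendered into the ``gallery_dirs`` entry at the same index. A trimmed-down sketch of the same wiring, keeping only the pair this commit adds (the extension name is real; everything else is pared back for illustration):

.. code-block:: python

   # Sketch of the relevant part of a Sphinx conf.py. sphinx-gallery scans each
   # examples_dirs entry for .py demo scripts and writes generated RST pages
   # into the gallery_dirs entry at the matching index.
   extensions = ["sphinx_gallery.gen_gallery"]

   sphinx_gallery_conf = {
       "examples_dirs": ["../demo/aft_survival"],     # source demo scripts
       "gallery_dirs": ["python/survival-examples"],  # generated gallery output
   }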

View File

@@ -1,2 +1,3 @@
 examples
 dask-examples
+survival-examples

View File

@@ -15,3 +15,4 @@ Contents
    model
    examples/index
    dask-examples/index
+   survival-examples/index
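For readers unfamiliar with this file: the entries above sit inside an RST ``toctree`` directive, so the new line adds the generated gallery index to the Python package docs. A sketch of the likely surrounding structure; the directive options are assumptions:

.. code-block:: rst

   .. toctree::
      :maxdepth: 1

      model
      examples/index
      dask-examples/index
      survival-examples/index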

View File

@@ -98,7 +98,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a
     # 4-by-2 Data matrix
     X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
     dtrain = xgb.DMatrix(X)

     # Associate ranged labels with the data matrix.
     # This example shows each kind of censored labels.
     #                  uncensored    right    left    interval
@@ -109,7 +109,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a
 .. code-block:: r
    :caption: R

    library(xgboost)
    # 4-by-2 Data matrix
@@ -165,4 +165,4 @@ Currently, you can choose from three probability distributions for ``aft_loss_distribution``:
 ``extreme``               :math:`e^z e^{-\exp{z}}`
 ========================= ===========================================

-Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`.
+Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`. For a collection of Python examples, see :doc:`/python/survival-examples/index`.
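To make that last point concrete, a minimal end-to-end sketch of setting ranged labels through ``xgboost.DMatrix``, reusing the tutorial's 4-by-2 toy matrix; the parameter values below are illustrative, not prescribed by the tutorial:

.. code-block:: python

   import numpy as np
   import xgboost as xgb

   X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
   dtrain = xgb.DMatrix(X)
   # One label of each kind: uncensored, right-, left-, and interval-censored.
   dtrain.set_float_info("label_lower_bound", np.array([2.0, 3.0, 0.0, 4.0]))
   dtrain.set_float_info("label_upper_bound", np.array([2.0, np.inf, 4.0, 5.0]))

   params = {
       "objective": "survival:aft",
       "eval_metric": "aft-nloglik",
       "aft_loss_distribution": "normal",   # or 'logistic', 'extreme'
       "aft_loss_distribution_scale": 1.0,  # illustrative value
   }
   bst = xgb.train(params, dtrain, num_boost_round=5)
   print(bst.predict(dtrain))  # predicted survival times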