From 4633b476e9eed94f691985287a6dea7ea76e2e47 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 13 Oct 2022 20:51:23 +0800 Subject: [PATCH] [doc] Display survival demos in sphinx doc. [skip ci] (#8328) --- demo/aft_survival/README.rst | 5 +++++ demo/aft_survival/aft_survival_demo.py | 6 +++++- demo/aft_survival/aft_survival_demo_with_optuna.py | 12 ++++++++---- demo/aft_survival/aft_survival_viz_demo.py | 9 +++++---- doc/conf.py | 4 ++-- doc/python/.gitignore | 3 ++- doc/python/index.rst | 1 + doc/tutorials/aft_survival_analysis.rst | 6 +++--- 8 files changed, 31 insertions(+), 15 deletions(-) create mode 100644 demo/aft_survival/README.rst diff --git a/demo/aft_survival/README.rst b/demo/aft_survival/README.rst new file mode 100644 index 000000000..7b75ae072 --- /dev/null +++ b/demo/aft_survival/README.rst @@ -0,0 +1,5 @@ +Survival Analysis Walkthrough +============================= + +This is a collection of examples for using the XGBoost Python package for training +survival models. For an introduction, see :doc:`/tutorials/aft_survival_analysis` diff --git a/demo/aft_survival/aft_survival_demo.py b/demo/aft_survival/aft_survival_demo.py index 0a659e79e..7046548b3 100644 --- a/demo/aft_survival/aft_survival_demo.py +++ b/demo/aft_survival/aft_survival_demo.py @@ -1,6 +1,10 @@ """ -Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model +Demo for survival analysis (regression). +======================================== + +Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model.
""" + import os from sklearn.model_selection import ShuffleSplit import pandas as pd diff --git a/demo/aft_survival/aft_survival_demo_with_optuna.py b/demo/aft_survival/aft_survival_demo_with_optuna.py index 117be8ba1..a6cf2aaf6 100644 --- a/demo/aft_survival/aft_survival_demo_with_optuna.py +++ b/demo/aft_survival/aft_survival_demo_with_optuna.py @@ -1,6 +1,10 @@ """ -Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna -to tune hyperparameters +Demo for survival analysis (regression) with Optuna. +==================================================== + +Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, +using Optuna to tune hyperparameters + """ from sklearn.model_selection import ShuffleSplit import pandas as pd @@ -45,7 +49,7 @@ def objective(trial): params.update(base_params) pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik') bst = xgb.train(params, dtrain, num_boost_round=10000, - evals=[(dtrain, 'train'), (dvalid, 'valid')], + evals=[(dtrain, 'train'), (dvalid, 'valid')], early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback]) if bst.best_iteration >= 25: return bst.best_score @@ -63,7 +67,7 @@ params.update(study.best_trial.params) # Re-run training with the best hyperparameter combination print('Re-running the best trial... params = {}'.format(params)) bst = xgb.train(params, dtrain, num_boost_round=10000, - evals=[(dtrain, 'train'), (dvalid, 'valid')], + evals=[(dtrain, 'train'), (dvalid, 'valid')], early_stopping_rounds=50) # Run prediction on the validation set diff --git a/demo/aft_survival/aft_survival_viz_demo.py b/demo/aft_survival/aft_survival_viz_demo.py index fe622f9e2..beb0db40c 100644 --- a/demo/aft_survival/aft_survival_viz_demo.py +++ b/demo/aft_survival/aft_survival_viz_demo.py @@ -1,9 +1,10 @@ """ Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model. 
+========================================================================================= -This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model -starts out as a flat line and evolves into a step function in order to account for all ranged -labels. +This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble +model starts out as a flat line and evolves into a step function in order to account for +all ranged labels. """ import numpy as np import xgboost as xgb @@ -57,7 +58,7 @@ def plot_intermediate_model_callback(env): # the corresponding predicted label (y_pred) acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100) accuracy_history.append(acc) - + # Plot ranged labels as well as predictions by the model plt.subplot(5, 3, env.iteration + 1) plot_censored_labels(X, y_lower, y_upper) diff --git a/doc/conf.py b/doc/conf.py index 83ad3e3db..bb51a7f8f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -93,9 +93,9 @@ extensions = [ sphinx_gallery_conf = { # path to your example scripts - "examples_dirs": ["../demo/guide-python", "../demo/dask"], + "examples_dirs": ["../demo/guide-python", "../demo/dask", "../demo/aft_survival"], # path to where to save gallery generated output - "gallery_dirs": ["python/examples", "python/dask-examples"], + "gallery_dirs": ["python/examples", "python/dask-examples", "python/survival-examples"], "matplotlib_animations": True, } diff --git a/doc/python/.gitignore b/doc/python/.gitignore index 843a492dd..bb0916d77 100644 --- a/doc/python/.gitignore +++ b/doc/python/.gitignore @@ -1,2 +1,3 @@ examples -dask-examples \ No newline at end of file +dask-examples +survival-examples \ No newline at end of file diff --git a/doc/python/index.rst b/doc/python/index.rst index cffc8a7fd..60608700b 100644 --- a/doc/python/index.rst +++ b/doc/python/index.rst @@ -15,3 +15,4 @@ Contents model examples/index dask-examples/index + survival-examples/index diff --git 
a/doc/tutorials/aft_survival_analysis.rst b/doc/tutorials/aft_survival_analysis.rst index adce5c3d0..4530c0749 100644 --- a/doc/tutorials/aft_survival_analysis.rst +++ b/doc/tutorials/aft_survival_analysis.rst @@ -98,7 +98,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a # 4-by-2 Data matrix X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]]) dtrain = xgb.DMatrix(X) - + # Associate ranged labels with the data matrix. # This example shows each kind of censored labels. # uncensored right left interval @@ -109,7 +109,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a .. code-block:: r :caption: R - + library(xgboost) # 4-by-2 Data matrix @@ -165,4 +165,4 @@ Currently, you can choose from three probability distributions for ``aft_loss_di ``extreme`` :math:`e^z e^{-\exp{z}}` ========================= =========================================== -Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`. +Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`. For a collection of Python examples, see :doc:`/python/survival-examples/index`.