[doc] Display survival demos in sphinx doc. [skip ci] (#8328)
This commit is contained in:
5
demo/aft_survival/README.rst
Normal file
5
demo/aft_survival/README.rst
Normal file
@@ -0,0 +1,5 @@
|
||||
Survival Analysis Walkthrough
|
||||
=============================
|
||||
|
||||
This is a collection of examples for using the XGBoost Python package for training
|
||||
survival models. For an introduction, see :doc:`/tutorials/aft_survival_analysis`
|
||||
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model
|
||||
Demo for survival analysis (regression).
|
||||
========================================
|
||||
|
||||
Demo for survival analysis (regression). using Accelerated Failure Time (AFT) model.
|
||||
"""
|
||||
|
||||
import os
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
import pandas as pd
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""
|
||||
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
|
||||
to tune hyperparameters
|
||||
Demo for survival analysis (regression) with Optuna.
|
||||
====================================================
|
||||
|
||||
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model,
|
||||
using Optuna to tune hyperparameters
|
||||
|
||||
"""
|
||||
from sklearn.model_selection import ShuffleSplit
|
||||
import pandas as pd
|
||||
@@ -45,7 +49,7 @@ def objective(trial):
|
||||
params.update(base_params)
|
||||
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
|
||||
bst = xgb.train(params, dtrain, num_boost_round=10000,
|
||||
evals=[(dtrain, 'train'), (dvalid, 'valid')],
|
||||
evals=[(dtrain, 'train'), (dvalid, 'valid')],
|
||||
early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
|
||||
if bst.best_iteration >= 25:
|
||||
return bst.best_score
|
||||
@@ -63,7 +67,7 @@ params.update(study.best_trial.params)
|
||||
# Re-run training with the best hyperparameter combination
|
||||
print('Re-running the best trial... params = {}'.format(params))
|
||||
bst = xgb.train(params, dtrain, num_boost_round=10000,
|
||||
evals=[(dtrain, 'train'), (dvalid, 'valid')],
|
||||
evals=[(dtrain, 'train'), (dvalid, 'valid')],
|
||||
early_stopping_rounds=50)
|
||||
|
||||
# Run prediction on the validation set
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""
|
||||
Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.
|
||||
=========================================================================================
|
||||
|
||||
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
|
||||
starts out as a flat line and evolves into a step function in order to account for all ranged
|
||||
labels.
|
||||
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble
|
||||
model starts out as a flat line and evolves into a step function in order to account for
|
||||
all ranged labels.
|
||||
"""
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
@@ -57,7 +58,7 @@ def plot_intermediate_model_callback(env):
|
||||
# the corresponding predicted label (y_pred)
|
||||
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100)
|
||||
accuracy_history.append(acc)
|
||||
|
||||
|
||||
# Plot ranged labels as well as predictions by the model
|
||||
plt.subplot(5, 3, env.iteration + 1)
|
||||
plot_censored_labels(X, y_lower, y_upper)
|
||||
|
||||
Reference in New Issue
Block a user