[doc] Display survival demos in sphinx doc. [skip ci] (#8328)

This commit is contained in:
Jiaming Yuan
2022-10-13 20:51:23 +08:00
committed by GitHub
parent 3ef1703553
commit 4633b476e9
8 changed files with 31 additions and 15 deletions

View File

@@ -1,6 +1,10 @@
"""
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
to tune hyperparameters
Demo for survival analysis (regression) with Optuna.
====================================================
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model,
using Optuna to tune hyperparameters
"""
from sklearn.model_selection import ShuffleSplit
import pandas as pd
@@ -45,7 +49,7 @@ def objective(trial):
params.update(base_params)
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
bst = xgb.train(params, dtrain, num_boost_round=10000,
evals=[(dtrain, 'train'), (dvalid, 'valid')],
evals=[(dtrain, 'train'), (dvalid, 'valid')],
early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
if bst.best_iteration >= 25:
return bst.best_score
@@ -63,7 +67,7 @@ params.update(study.best_trial.params)
# Re-run training with the best hyperparameter combination
print('Re-running the best trial... params = {}'.format(params))
bst = xgb.train(params, dtrain, num_boost_round=10000,
evals=[(dtrain, 'train'), (dvalid, 'valid')],
evals=[(dtrain, 'train'), (dvalid, 'valid')],
early_stopping_rounds=50)
# Run prediction on the validation set