Add Accelerated Failure Time loss for survival analysis task (#4763)
* [WIP] Add lower and upper bounds on the label for survival analysis
* Update test MetaInfo.SaveLoadBinary to account for extra two fields
* Don't clear qids_ for version 2 of MetaInfo
* Add SetInfo() and GetInfo() method for lower and upper bounds
* changes to aft
* Add parameter class for AFT; use enum's to represent distribution and event type
* Add AFT metric
* changes to neg grad to grad
* changes to binomial loss
* changes to overflow
* changes to eps
* changes to code refactoring
* changes to code refactoring
* changes to code refactoring
* Re-factor survival analysis
* Remove aft namespace
* Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter
* Move function bodies out of AFTLoss, to reduce clutter
* Use smart pointer to store AFTDistribution and AFTLoss
* Rename AFTNoiseDistribution enum to AFTDistributionType for clarity
  The enum class was not a distribution itself but a distribution type
* Add AFTDistribution::Create() method for convenience
* changes to extreme distribution
* changes to extreme distribution
* changes to extreme
* changes to extreme distribution
* changes to left censored
* deleted cout
* changes to x,mu and sd and code refactoring
* changes to print
* changes to hessian formula in censored and uncensored
* changes to variable names and pow
* changes to Logistic Pdf
* changes to parameter
* Expose lower and upper bound labels to R package
* Use example weights; normalize log likelihood metric
* changes to CHECK
* changes to logistic hessian to standard formula
* changes to logistic formula
* Comply with coding style guideline
* Revert back Rabit submodule
* Revert dmlc-core submodule
* Comply with coding style guideline (clang-tidy)
* Fix an error in AFTLoss::Gradient()
* Add missing files to amalgamation
* Address @RAMitchell's comment: minimize future change in MetaInfo interface
* Fix lint
* Fix compilation error on 32-bit target, when size_t == bst_uint
* Allocate sufficient memory to hold extra label info
* Use OpenMP to speed up
* Fix compilation on Windows
* Address reviewer's feedback
* Add unit tests for probability distributions
* Make Metric subclass of Configurable
* Address reviewer's feedback: Configure() AFT metric
* Add a dummy test for AFT metric configuration
* Complete AFT configuration test; remove debugging print
* Rename AFT parameters
* Clarify test comment
* Add a dummy test for AFT loss for uncensored case
* Fix a bug in AFT loss for uncensored labels
* Complete unit test for AFT loss metric
* Simplify unit tests for AFT metric
* Add unit test to verify aggregate output from AFT metric
* Use EXPECT_* instead of ASSERT_*, so that we run all unit tests
* Use aft_loss_param when serializing AFTObj
  This is to be consistent with AFT metric
* Add unit tests for AFT Objective
* Fix OpenMP bug; clarify semantics for shared variables used in OpenMP loops
* Add comments
* Remove AFT prefix from probability distribution; put probability distribution in separate source file
* Add comments
* Define kPI and kEulerMascheroni in probability_distribution.h
* Add probability_distribution.cc to amalgamation
* Remove unnecessary diff
* Address reviewer's feedback: define variables where they're used
* Eliminate all INFs and NANs from AFT loss and gradient
* Add demo
* Add tutorial
* Fix lint
* Use 'survival:aft' to be consistent with 'survival:cox'
* Move sample data to demo/data
* Add visual demo with 1D toy data
* Add Python tests

Co-authored-by: Philip Cho <chohyu01@cs.washington.edu>
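As a rough illustration of the kind of loss the commit message describes, the sketch below computes an AFT negative log-likelihood for one label interval under normally distributed noise. It is not the code added by this commit: the helper name `aft_nloglik`, the parameters `pred`, `sigma`, and `eps`, and the clamping strategy are all assumptions made for the example, and it assumes the model ln(Y) = pred + sigma * Z with Z standard normal.

```python
import math


def normal_pdf(z):
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def normal_cdf(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def aft_nloglik(lower, upper, pred, sigma=1.0, eps=1e-12):
    """Hypothetical helper: negative log-likelihood of one label interval.

    Assumes ln(Y) = pred + sigma * Z with Z standard normal, so `pred` is a
    predicted log survival time. `eps` is an illustrative clamp that keeps the
    logarithm finite; it is not taken from the commit.
    """
    if lower == upper:
        # Uncensored: density of Y at the observed time.
        z = (math.log(lower) - pred) / sigma
        likelihood = normal_pdf(z) / (sigma * lower)
    else:
        # Censored: probability mass of the interval [lower, upper].
        cdf_upper = 1.0 if math.isinf(upper) else normal_cdf((math.log(upper) - pred) / sigma)
        cdf_lower = 0.0 if lower <= 0.0 else normal_cdf((math.log(lower) - pred) / sigma)
        likelihood = cdf_upper - cdf_lower
    return -math.log(max(likelihood, eps))


uncensored = aft_nloglik(1.0, 1.0, pred=0.0)           # label observed exactly
right_censored = aft_nloglik(1.0, math.inf, pred=0.0)  # only a lower bound known
```

With `pred=0.0` and `sigma=1.0`, the uncensored case reduces to half the log of 2π and the right-censored case to the log of 2, which makes the two branches easy to check by hand.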
138
demo/data/veterans_lung_cancer.csv
Normal file
@@ -0,0 +1,138 @@
Survival_label_lower_bound,Survival_label_upper_bound,Age_in_years,Karnofsky_score,Months_from_Diagnosis,Celltype=adeno,Celltype=large,Celltype=smallcell,Celltype=squamous,Prior_therapy=no,Prior_therapy=yes,Treatment=standard,Treatment=test
72.0,72.0,69.0,60.0,7.0,0,0,0,1,1,0,1,0
411.0,411.0,64.0,70.0,5.0,0,0,0,1,0,1,1,0
228.0,228.0,38.0,60.0,3.0,0,0,0,1,1,0,1,0
126.0,126.0,63.0,60.0,9.0,0,0,0,1,0,1,1,0
118.0,118.0,65.0,70.0,11.0,0,0,0,1,0,1,1,0
10.0,10.0,49.0,20.0,5.0,0,0,0,1,1,0,1,0
82.0,82.0,69.0,40.0,10.0,0,0,0,1,0,1,1,0
110.0,110.0,68.0,80.0,29.0,0,0,0,1,1,0,1,0
314.0,314.0,43.0,50.0,18.0,0,0,0,1,1,0,1,0
100.0,inf,70.0,70.0,6.0,0,0,0,1,1,0,1,0
42.0,42.0,81.0,60.0,4.0,0,0,0,1,1,0,1,0
8.0,8.0,63.0,40.0,58.0,0,0,0,1,0,1,1,0
144.0,144.0,63.0,30.0,4.0,0,0,0,1,1,0,1,0
25.0,inf,52.0,80.0,9.0,0,0,0,1,0,1,1,0
11.0,11.0,48.0,70.0,11.0,0,0,0,1,0,1,1,0
30.0,30.0,61.0,60.0,3.0,0,0,1,0,1,0,1,0
384.0,384.0,42.0,60.0,9.0,0,0,1,0,1,0,1,0
4.0,4.0,35.0,40.0,2.0,0,0,1,0,1,0,1,0
54.0,54.0,63.0,80.0,4.0,0,0,1,0,0,1,1,0
13.0,13.0,56.0,60.0,4.0,0,0,1,0,1,0,1,0
123.0,inf,55.0,40.0,3.0,0,0,1,0,1,0,1,0
97.0,inf,67.0,60.0,5.0,0,0,1,0,1,0,1,0
153.0,153.0,63.0,60.0,14.0,0,0,1,0,0,1,1,0
59.0,59.0,65.0,30.0,2.0,0,0,1,0,1,0,1,0
117.0,117.0,46.0,80.0,3.0,0,0,1,0,1,0,1,0
16.0,16.0,53.0,30.0,4.0,0,0,1,0,0,1,1,0
151.0,151.0,69.0,50.0,12.0,0,0,1,0,1,0,1,0
22.0,22.0,68.0,60.0,4.0,0,0,1,0,1,0,1,0
56.0,56.0,43.0,80.0,12.0,0,0,1,0,0,1,1,0
21.0,21.0,55.0,40.0,2.0,0,0,1,0,0,1,1,0
18.0,18.0,42.0,20.0,15.0,0,0,1,0,1,0,1,0
139.0,139.0,64.0,80.0,2.0,0,0,1,0,1,0,1,0
20.0,20.0,65.0,30.0,5.0,0,0,1,0,1,0,1,0
31.0,31.0,65.0,75.0,3.0,0,0,1,0,1,0,1,0
52.0,52.0,55.0,70.0,2.0,0,0,1,0,1,0,1,0
287.0,287.0,66.0,60.0,25.0,0,0,1,0,0,1,1,0
18.0,18.0,60.0,30.0,4.0,0,0,1,0,1,0,1,0
51.0,51.0,67.0,60.0,1.0,0,0,1,0,1,0,1,0
122.0,122.0,53.0,80.0,28.0,0,0,1,0,1,0,1,0
27.0,27.0,62.0,60.0,8.0,0,0,1,0,1,0,1,0
54.0,54.0,67.0,70.0,1.0,0,0,1,0,1,0,1,0
7.0,7.0,72.0,50.0,7.0,0,0,1,0,1,0,1,0
63.0,63.0,48.0,50.0,11.0,0,0,1,0,1,0,1,0
392.0,392.0,68.0,40.0,4.0,0,0,1,0,1,0,1,0
10.0,10.0,67.0,40.0,23.0,0,0,1,0,0,1,1,0
8.0,8.0,61.0,20.0,19.0,1,0,0,0,0,1,1,0
92.0,92.0,60.0,70.0,10.0,1,0,0,0,1,0,1,0
35.0,35.0,62.0,40.0,6.0,1,0,0,0,1,0,1,0
117.0,117.0,38.0,80.0,2.0,1,0,0,0,1,0,1,0
132.0,132.0,50.0,80.0,5.0,1,0,0,0,1,0,1,0
12.0,12.0,63.0,50.0,4.0,1,0,0,0,0,1,1,0
162.0,162.0,64.0,80.0,5.0,1,0,0,0,1,0,1,0
3.0,3.0,43.0,30.0,3.0,1,0,0,0,1,0,1,0
95.0,95.0,34.0,80.0,4.0,1,0,0,0,1,0,1,0
177.0,177.0,66.0,50.0,16.0,0,1,0,0,0,1,1,0
162.0,162.0,62.0,80.0,5.0,0,1,0,0,1,0,1,0
216.0,216.0,52.0,50.0,15.0,0,1,0,0,1,0,1,0
553.0,553.0,47.0,70.0,2.0,0,1,0,0,1,0,1,0
278.0,278.0,63.0,60.0,12.0,0,1,0,0,1,0,1,0
12.0,12.0,68.0,40.0,12.0,0,1,0,0,0,1,1,0
260.0,260.0,45.0,80.0,5.0,0,1,0,0,1,0,1,0
200.0,200.0,41.0,80.0,12.0,0,1,0,0,0,1,1,0
156.0,156.0,66.0,70.0,2.0,0,1,0,0,1,0,1,0
182.0,inf,62.0,90.0,2.0,0,1,0,0,1,0,1,0
143.0,143.0,60.0,90.0,8.0,0,1,0,0,1,0,1,0
105.0,105.0,66.0,80.0,11.0,0,1,0,0,1,0,1,0
103.0,103.0,38.0,80.0,5.0,0,1,0,0,1,0,1,0
250.0,250.0,53.0,70.0,8.0,0,1,0,0,0,1,1,0
100.0,100.0,37.0,60.0,13.0,0,1,0,0,0,1,1,0
999.0,999.0,54.0,90.0,12.0,0,0,0,1,0,1,0,1
112.0,112.0,60.0,80.0,6.0,0,0,0,1,1,0,0,1
87.0,inf,48.0,80.0,3.0,0,0,0,1,1,0,0,1
231.0,inf,52.0,50.0,8.0,0,0,0,1,0,1,0,1
242.0,242.0,70.0,50.0,1.0,0,0,0,1,1,0,0,1
991.0,991.0,50.0,70.0,7.0,0,0,0,1,0,1,0,1
111.0,111.0,62.0,70.0,3.0,0,0,0,1,1,0,0,1
1.0,1.0,65.0,20.0,21.0,0,0,0,1,0,1,0,1
587.0,587.0,58.0,60.0,3.0,0,0,0,1,1,0,0,1
389.0,389.0,62.0,90.0,2.0,0,0,0,1,1,0,0,1
33.0,33.0,64.0,30.0,6.0,0,0,0,1,1,0,0,1
25.0,25.0,63.0,20.0,36.0,0,0,0,1,1,0,0,1
357.0,357.0,58.0,70.0,13.0,0,0,0,1,1,0,0,1
467.0,467.0,64.0,90.0,2.0,0,0,0,1,1,0,0,1
201.0,201.0,52.0,80.0,28.0,0,0,0,1,0,1,0,1
1.0,1.0,35.0,50.0,7.0,0,0,0,1,1,0,0,1
30.0,30.0,63.0,70.0,11.0,0,0,0,1,1,0,0,1
44.0,44.0,70.0,60.0,13.0,0,0,0,1,0,1,0,1
283.0,283.0,51.0,90.0,2.0,0,0,0,1,1,0,0,1
15.0,15.0,40.0,50.0,13.0,0,0,0,1,0,1,0,1
25.0,25.0,69.0,30.0,2.0,0,0,1,0,1,0,0,1
103.0,inf,36.0,70.0,22.0,0,0,1,0,0,1,0,1
21.0,21.0,71.0,20.0,4.0,0,0,1,0,1,0,0,1
13.0,13.0,62.0,30.0,2.0,0,0,1,0,1,0,0,1
87.0,87.0,60.0,60.0,2.0,0,0,1,0,1,0,0,1
2.0,2.0,44.0,40.0,36.0,0,0,1,0,0,1,0,1
20.0,20.0,54.0,30.0,9.0,0,0,1,0,0,1,0,1
7.0,7.0,66.0,20.0,11.0,0,0,1,0,1,0,0,1
24.0,24.0,49.0,60.0,8.0,0,0,1,0,1,0,0,1
99.0,99.0,72.0,70.0,3.0,0,0,1,0,1,0,0,1
8.0,8.0,68.0,80.0,2.0,0,0,1,0,1,0,0,1
99.0,99.0,62.0,85.0,4.0,0,0,1,0,1,0,0,1
61.0,61.0,71.0,70.0,2.0,0,0,1,0,1,0,0,1
25.0,25.0,70.0,70.0,2.0,0,0,1,0,1,0,0,1
95.0,95.0,61.0,70.0,1.0,0,0,1,0,1,0,0,1
80.0,80.0,71.0,50.0,17.0,0,0,1,0,1,0,0,1
51.0,51.0,59.0,30.0,87.0,0,0,1,0,0,1,0,1
29.0,29.0,67.0,40.0,8.0,0,0,1,0,1,0,0,1
24.0,24.0,60.0,40.0,2.0,1,0,0,0,1,0,0,1
18.0,18.0,69.0,40.0,5.0,1,0,0,0,0,1,0,1
83.0,inf,57.0,99.0,3.0,1,0,0,0,1,0,0,1
31.0,31.0,39.0,80.0,3.0,1,0,0,0,1,0,0,1
51.0,51.0,62.0,60.0,5.0,1,0,0,0,1,0,0,1
90.0,90.0,50.0,60.0,22.0,1,0,0,0,0,1,0,1
52.0,52.0,43.0,60.0,3.0,1,0,0,0,1,0,0,1
73.0,73.0,70.0,60.0,3.0,1,0,0,0,1,0,0,1
8.0,8.0,66.0,50.0,5.0,1,0,0,0,1,0,0,1
36.0,36.0,61.0,70.0,8.0,1,0,0,0,1,0,0,1
48.0,48.0,81.0,10.0,4.0,1,0,0,0,1,0,0,1
7.0,7.0,58.0,40.0,4.0,1,0,0,0,1,0,0,1
140.0,140.0,63.0,70.0,3.0,1,0,0,0,1,0,0,1
186.0,186.0,60.0,90.0,3.0,1,0,0,0,1,0,0,1
84.0,84.0,62.0,80.0,4.0,1,0,0,0,0,1,0,1
19.0,19.0,42.0,50.0,10.0,1,0,0,0,1,0,0,1
45.0,45.0,69.0,40.0,3.0,1,0,0,0,1,0,0,1
80.0,80.0,63.0,40.0,4.0,1,0,0,0,1,0,0,1
52.0,52.0,45.0,60.0,4.0,0,1,0,0,1,0,0,1
164.0,164.0,68.0,70.0,15.0,0,1,0,0,0,1,0,1
19.0,19.0,39.0,30.0,4.0,0,1,0,0,0,1,0,1
53.0,53.0,66.0,60.0,12.0,0,1,0,0,1,0,0,1
15.0,15.0,63.0,30.0,5.0,0,1,0,0,1,0,0,1
43.0,43.0,49.0,60.0,11.0,0,1,0,0,0,1,0,1
340.0,340.0,64.0,80.0,10.0,0,1,0,0,0,1,0,1
133.0,133.0,65.0,75.0,1.0,0,1,0,0,1,0,0,1
111.0,111.0,64.0,60.0,5.0,0,1,0,0,1,0,0,1
231.0,231.0,67.0,70.0,18.0,0,1,0,0,0,1,0,1
378.0,378.0,65.0,80.0,4.0,0,1,0,0,1,0,0,1
49.0,49.0,37.0,30.0,3.0,0,1,0,0,1,0,0,1
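The first two columns of this file carry the lower and upper bounds of each survival label: rows where both bounds match look like exactly observed times, and rows with an upper bound of `inf` look like right-censored observations. A minimal sketch of reading the bounds back with the Python standard library follows; the helper names `censoring_type` and `read_labels` are hypothetical, and the sample keeps only three rows and three columns copied from the data above.

```python
import csv
import io
import math

# Three rows copied from demo/data/veterans_lung_cancer.csv, trimmed to the
# two label columns plus age to keep the example short.
SAMPLE = """Survival_label_lower_bound,Survival_label_upper_bound,Age_in_years
72.0,72.0,69.0
100.0,inf,70.0
25.0,inf,52.0
"""


def censoring_type(lower, upper):
    """Hypothetical helper: name the kind of label a (lower, upper) pair encodes."""
    if lower == upper:
        return "uncensored"
    if math.isinf(upper):
        return "right-censored"
    # Any other combination is lumped together here for simplicity.
    return "interval-censored"


def read_labels(text):
    """Parse CSV text into (lower, upper, censoring type) tuples."""
    rows = []
    for row in csv.DictReader(io.StringIO(text)):
        lower = float(row["Survival_label_lower_bound"])
        upper = float(row["Survival_label_upper_bound"])  # float() accepts "inf"
        rows.append((lower, upper, censoring_type(lower, upper)))
    return rows


labels = read_labels(SAMPLE)
```

Keeping both bounds as plain floats means an infinite upper bound needs no special sentinel value, since `float("inf")` parses directly from the file.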