xgboost/tests/cpp/common/test_probability_distribution.cc
Avinash Barnwal dcf439932a
Add Accelerated Failure Time loss for survival analysis task (#4763)
* [WIP] Add lower and upper bounds on the label for survival analysis

* Update test MetaInfo.SaveLoadBinary to account for extra two fields

* Don't clear qids_ for version 2 of MetaInfo

* Add SetInfo() and GetInfo() method for lower and upper bounds

* changes to aft

* Add parameter class for AFT; use enums to represent distribution and event type

* Add AFT metric

* changes to neg grad to grad

* changes to binomial loss

* changes to overflow

* changes to eps

* changes to code refactoring

* changes to code refactoring

* changes to code refactoring

* Re-factor survival analysis

* Remove aft namespace

* Move function bodies out of AFTNormal and AFTLogistic, to reduce clutter

* Move function bodies out of AFTLoss, to reduce clutter

* Use smart pointer to store AFTDistribution and AFTLoss

* Rename AFTNoiseDistribution enum to AFTDistributionType for clarity

The enum class was not a distribution itself but a distribution type.

* Add AFTDistribution::Create() method for convenience

* changes to extreme distribution

* changes to extreme distribution

* changes to extreme

* changes to extreme distribution

* changes to left censored

* deleted cout

* changes to x,mu and sd and code refactoring

* changes to print

* changes to hessian formula in censored and uncensored

* changes to variable names and pow

* changes to Logistic Pdf

* changes to parameter

* Expose lower and upper bound labels to R package

* Use example weights; normalize log likelihood metric

* changes to CHECK

* changes to logistic hessian to standard formula

* changes to logistic formula

* Comply with coding style guideline

* Revert back Rabit submodule

* Revert dmlc-core submodule

* Comply with coding style guideline (clang-tidy)

* Fix an error in AFTLoss::Gradient()

* Add missing files to amalgamation

* Address @RAMitchell's comment: minimize future change in MetaInfo interface

* Fix lint

* Fix compilation error on 32-bit targets, where size_t == bst_uint

* Allocate sufficient memory to hold extra label info

* Use OpenMP to speed up

* Fix compilation on Windows

* Address reviewer's feedback

* Add unit tests for probability distributions

* Make Metric subclass of Configurable

* Address reviewer's feedback: Configure() AFT metric

* Add a dummy test for AFT metric configuration

* Complete AFT configuration test; remove debugging print

* Rename AFT parameters

* Clarify test comment

* Add a dummy test for AFT loss for uncensored case

* Fix a bug in AFT loss for uncensored labels

* Complete unit test for AFT loss metric

* Simplify unit tests for AFT metric

* Add unit test to verify aggregate output from AFT metric

* Use EXPECT_* instead of ASSERT_*, so that we run all unit tests

* Use aft_loss_param when serializing AFTObj

This is to be consistent with AFT metric

* Add unit tests for AFT Objective

* Fix OpenMP bug; clarify semantics for shared variables used in OpenMP loops

* Add comments

* Remove AFT prefix from probability distribution; put probability distribution in separate source file

* Add comments

* Define kPI and kEulerMascheroni in probability_distribution.h

* Add probability_distribution.cc to amalgamation

* Remove unnecessary diff

* Address reviewer's feedback: define variables where they're used

* Eliminate all INFs and NANs from AFT loss and gradient

* Add demo

* Add tutorial

* Fix lint

* Use 'survival:aft' to be consistent with 'survival:cox'

* Move sample data to demo/data

* Add visual demo with 1D toy data

* Add Python tests

Co-authored-by: Philip Cho <chohyu01@cs.washington.edu>
2020-03-25 13:52:51 -07:00

122 lines
5.4 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*!
* Copyright (c) by Contributors 2020
*/
#include <gtest/gtest.h>
#include <memory>
#include <cmath>
#include "xgboost/logging.h"
#include "../../../src/common/probability_distribution.h"
namespace xgboost {
namespace common {
TEST(ProbabilityDistribution, DistributionGeneric) {
  // For every distribution type, check the chain CDF -> PDF -> GradPDF -> HessPDF both ways:
  //   * d/dx CDF = PDF, d/dx PDF = GradPDF, d/dx GradPDF = HessPDF (central differences), and
  //   * integrating PDF / GradPDF / HessPDF from -2 up to x reproduces CDF / PDF / GradPDF at x
  //     (composite trapezoid rule).
  constexpr double kLower = -2.0;      // left end of the grid
  constexpr int kNumIntervals = 4000;  // 4001 grid points x_i = -2 + i * 1e-3 cover [-2, 2]
  constexpr double kGridStep = 1e-3;   // spacing between consecutive grid points
  constexpr double kDiffStep = 1e-5;   // h in the central difference (f(x+h) - f(x-h)) / (2h)
  for (auto type : {ProbabilityDistributionType::kNormal, ProbabilityDistributionType::kLogistic,
                    ProbabilityDistributionType::kExtreme}) {
    // Label any failure below with the distribution type under test.
    SCOPED_TRACE(::testing::Message() << "distribution type = " << static_cast<int>(type));
    std::unique_ptr<ProbabilityDistribution> dist{ ProbabilityDistribution::Create(type) };
    // Each running integral starts at its antiderivative's value at the left end of the grid,
    // so after integrating over [kLower, x] it should equal that antiderivative evaluated at x.
    double integral_of_pdf = dist->CDF(kLower);
    double integral_of_grad_pdf = dist->PDF(kLower);
    double integral_of_hess_pdf = dist->GradPDF(kLower);
    for (int i = 0; i <= kNumIntervals; ++i) {
      const double x = kLower + static_cast<double>(i) * kGridStep;
      // Numerical differentiation (p. 246, Numerical Analysis 2nd ed. by Timothy Sauer)
      EXPECT_NEAR((dist->CDF(x + kDiffStep) - dist->CDF(x - kDiffStep)) / (2.0 * kDiffStep),
                  dist->PDF(x), 6e-11);
      EXPECT_NEAR((dist->PDF(x + kDiffStep) - dist->PDF(x - kDiffStep)) / (2.0 * kDiffStep),
                  dist->GradPDF(x), 6e-11);
      EXPECT_NEAR((dist->GradPDF(x + kDiffStep) - dist->GradPDF(x - kDiffStep)) / (2.0 * kDiffStep),
                  dist->HessPDF(x), 6e-11);
      // Numerical integration using the trapezoid rule (p. 257, Sauer).
      // At i == 0 the accumulators already hold the exact value at kLower; adding a panel there
      // would integrate over [kLower - kGridStep, kLower], which lies outside the intended range.
      if (i > 0) {
        const double x_prev = x - kGridStep;
        integral_of_pdf += 0.5 * kGridStep * (dist->PDF(x_prev) + dist->PDF(x));
        integral_of_grad_pdf += 0.5 * kGridStep * (dist->GradPDF(x_prev) + dist->GradPDF(x));
        integral_of_hess_pdf += 0.5 * kGridStep * (dist->HessPDF(x_prev) + dist->HessPDF(x));
      }
      EXPECT_NEAR(integral_of_pdf, dist->CDF(x), 2e-4);
      EXPECT_NEAR(integral_of_grad_pdf, dist->PDF(x), 2e-4);
      EXPECT_NEAR(integral_of_hess_pdf, dist->GradPDF(x), 2e-4);
    }
  }
}
TEST(ProbabilityDistribution, NormalDist) {
  std::unique_ptr<ProbabilityDistribution> dist{
    ProbabilityDistribution::Create(ProbabilityDistributionType::kNormal)
  };
  // Check the probability mass of the standard normal inside [-k, k], i.e. CDF(k) - CDF(-k),
  // against reference values given to four decimal places, for k = 0.5, 1.0, ..., 4.0.
  // The k = 1, 2, 3 rows are the "68-95-99.7 rule"
  // (https://en.wikipedia.org/wiki/68%E2%80%9395%E2%80%9399.7_rule): about 68%, 95% and 99.7%
  // of values lie within 1, 2 and 3 standard deviations of the mean, respectively.
  struct Interval {
    double half_width;     // k
    double expected_mass;  // reference value of CDF(k) - CDF(-k)
  };
  const Interval kIntervals[] = {
    {0.5, 0.3829}, {1.0, 0.6827}, {1.5, 0.8664}, {2.0, 0.9545},
    {2.5, 0.9876}, {3.0, 0.9973}, {3.5, 0.9995}, {4.0, 0.9999},
  };
  for (const Interval& interval : kIntervals) {
    EXPECT_NEAR(dist->CDF(interval.half_width) - dist->CDF(-interval.half_width),
                interval.expected_mass, 0.00005)
        << "half_width = " << interval.half_width;
  }
}
TEST(ProbabilityDistribution, LogisticDist) {
  std::unique_ptr<ProbabilityDistribution> dist{
    ProbabilityDistribution::Create(ProbabilityDistributionType::kLogistic)
  };
  /**
   * Compare the logistic PDF and CDF against closed forms written with hyperbolic functions
   * (https://en.wikipedia.org/wiki/Logistic_distribution):
   *   PDF(x) = 1/4 * sech(x/2)**2
   *   CDF(x) = 1/2 + 1/2 * tanh(x/2)
   * on the 4001 grid points x = -2.000, -1.999, ..., 2.000.
   **/
  int step = 0;
  while (step <= 4000) {
    const double x = static_cast<double>(step) / 1000.0 - 2.0;
    const double half_x = x * 0.5;
    const double sech_half_x = 1.0 / std::cosh(half_x);  // sech(x/2)
    const double closed_form_pdf = 0.25 * sech_half_x * sech_half_x;
    const double closed_form_cdf = 0.5 + 0.5 * std::tanh(half_x);
    EXPECT_NEAR(closed_form_pdf, dist->PDF(x), 1e-15);
    EXPECT_NEAR(closed_form_cdf, dist->CDF(x), 1e-15);
    ++step;
  }
}
TEST(ProbabilityDistribution, ExtremeDist) {
  std::unique_ptr<ProbabilityDistribution> dist{
    ProbabilityDistribution::Create(ProbabilityDistributionType::kExtreme)
  };
  /**
   * Enforce known properties of the extreme distribution (also known as Gumbel distribution).
   * The mean is the negative of the Euler-Mascheroni constant.
   * The variance is 1/6 * pi**2. (https://mathworld.wolfram.com/GumbelDistribution.html)
   *
   * Nearly all mass of the extreme distribution is concentrated between -20 and 5, so both
   * moments are estimated by numerical integration over [-20, 5]:
   *   mean     = integral of x * PDF(x) dx
   *   variance = integral of (x - mean)**2 * PDF(x) dx
   **/
  constexpr double kLower = -20.0;   // left end of the integration range
  constexpr double kStep = 1e-3;     // width of one trapezoid panel
  constexpr int kNumPanels = 25000;  // kLower + kNumPanels * kStep == 5
  // Composite trapezoid rule (p. 257, Sauer) for the integral of (x - center)**order * PDF(x)
  // over [kLower, kLower + kNumPanels * kStep]. Panels start at kLower itself, so nothing
  // left of -20 is integrated.
  auto moment = [&](int order, double center) -> double {
    auto integrand = [&](double x) -> double {
      return std::pow(x - center, order) * dist->PDF(x);
    };
    double total = 0.0;
    for (int i = 1; i <= kNumPanels; ++i) {
      const double left = kLower + static_cast<double>(i - 1) * kStep;
      const double right = kLower + static_cast<double>(i) * kStep;
      total += 0.5 * kStep * (integrand(left) + integrand(right));
    }
    return total;
  };
  const double mean = moment(1, 0.0);
  EXPECT_NEAR(mean, -probability_constant::kEulerMascheroni, 1e-7);
  // Center the second moment on the numerically estimated mean.
  const double variance = moment(2, mean);
  EXPECT_NEAR(variance, probability_constant::kPI * probability_constant::kPI / 6.0, 1e-6);
}
} // namespace common
} // namespace xgboost