Implement robust regularization in 'survival:aft' objective (#5473)
* Robust regularization of AFT gradient and hessian * Fix AFT doc; expose it to tutorial TOC * Apply robust regularization to uncensored case too * Revise unit test slightly * Fix lint * Update test_survival.py * Use GradientPairPrecise * Remove unused variables
This commit is contained in:
parent
939973630d
commit
5fc5ec539d
@ -51,4 +51,4 @@ print(df)
|
||||
print(df[np.isinf(df['Label (upper bound)'])])
|
||||
|
||||
# Save trained model
|
||||
bst.save_model('aft_model.json')
|
||||
bst.save_model('aft_model.json')
|
||||
|
||||
@ -75,4 +75,4 @@ print(df)
|
||||
print(df[np.isinf(df['Label (upper bound)'])])
|
||||
|
||||
# Save trained model
|
||||
bst.save_model('aft_best_model.json')
|
||||
bst.save_model('aft_best_model.json')
|
||||
|
||||
@ -68,7 +68,7 @@ Note that this model is a generalized form of a linear regression model :math:`Y
|
||||
|
||||
\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z
|
||||
|
||||
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathbf{x}`.
|
||||
where :math:`\mathcal{T}(\mathbf{x})` represents the output from a decision tree ensemble, given input :math:`\mathbf{x}`. Since :math:`Z` is a random variable, we have a likelihood defined for the expression :math:`\ln{Y} = \mathcal{T}(\mathbf{x}) + \sigma Z`. So the goal for XGBoost is to maximize the (log) likelihood by fitting a good tree ensemble :math:`\mathcal{T}(\mathbf{x})`.
|
||||
|
||||
**********
|
||||
How to use
|
||||
|
||||
@ -18,6 +18,7 @@ See `Awesome XGBoost <https://github.com/dmlc/xgboost/tree/master/demo>`_ for mo
|
||||
monotonic
|
||||
rf
|
||||
feature_interaction_constraint
|
||||
aft_survival_analysis
|
||||
input_format
|
||||
param_tuning
|
||||
external_memory
|
||||
|
||||
@ -18,6 +18,106 @@
|
||||
https://github.com/avinashbarnwal/GSOC-2019/blob/master/doc/Accelerated_Failure_Time.pdf
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// Allowable range for gradient and hessian. Used for regularization
|
||||
constexpr double kMinGradient = -15.0;
|
||||
constexpr double kMaxGradient = 15.0;
|
||||
constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
|
||||
constexpr double kMaxHessian = 15.0;
|
||||
|
||||
constexpr double kEps = 1e-12; // A denomitor in a fraction should not be too small
|
||||
|
||||
// Clip (limit) x to fit range [x_min, x_max].
|
||||
// If x < x_min, return x_min; if x > x_max, return x_max; if x_min <= x <= x_max, return x.
|
||||
// This function assumes x_min < x_max; behavior is undefined if this assumption does not hold.
|
||||
inline double Clip(double x, double x_min, double x_max) {
|
||||
if (x < x_min) {
|
||||
return x_min;
|
||||
}
|
||||
if (x > x_max) {
|
||||
return x_max;
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
using xgboost::common::ProbabilityDistributionType;
|
||||
|
||||
enum class CensoringType : uint8_t {
|
||||
kUncensored, kRightCensored, kLeftCensored, kIntervalCensored
|
||||
};
|
||||
|
||||
using xgboost::GradientPairPrecise;
|
||||
|
||||
inline GradientPairPrecise GetLimitAtInfPred(ProbabilityDistributionType dist_type,
|
||||
CensoringType censor_type,
|
||||
double sign, double sigma) {
|
||||
switch (censor_type) {
|
||||
case CensoringType::kUncensored:
|
||||
switch (dist_type) {
|
||||
case ProbabilityDistributionType::kNormal:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
|
||||
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
|
||||
case ProbabilityDistributionType::kLogistic:
|
||||
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
case ProbabilityDistributionType::kExtreme:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
default:
|
||||
LOG(FATAL) << "Unknown distribution type";
|
||||
}
|
||||
case CensoringType::kRightCensored:
|
||||
switch (dist_type) {
|
||||
case ProbabilityDistributionType::kNormal:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
|
||||
: GradientPairPrecise{ 0.0, kMinHessian };
|
||||
case ProbabilityDistributionType::kLogistic:
|
||||
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
|
||||
: GradientPairPrecise{ 0.0, kMinHessian };
|
||||
case ProbabilityDistributionType::kExtreme:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
|
||||
: GradientPairPrecise{ 0.0, kMinHessian };
|
||||
default:
|
||||
LOG(FATAL) << "Unknown distribution type";
|
||||
}
|
||||
case CensoringType::kLeftCensored:
|
||||
switch (dist_type) {
|
||||
case ProbabilityDistributionType::kNormal:
|
||||
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
|
||||
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
|
||||
case ProbabilityDistributionType::kLogistic:
|
||||
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
case ProbabilityDistributionType::kExtreme:
|
||||
return sign ? GradientPairPrecise{ 0.0, kMinHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
default:
|
||||
LOG(FATAL) << "Unknown distribution type";
|
||||
}
|
||||
case CensoringType::kIntervalCensored:
|
||||
switch (dist_type) {
|
||||
case ProbabilityDistributionType::kNormal:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, 1.0 / (sigma * sigma) }
|
||||
: GradientPairPrecise{ kMaxGradient, 1.0 / (sigma * sigma) };
|
||||
case ProbabilityDistributionType::kLogistic:
|
||||
return sign ? GradientPairPrecise{ -1.0 / sigma, kMinHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
case ProbabilityDistributionType::kExtreme:
|
||||
return sign ? GradientPairPrecise{ kMinGradient, kMaxHessian }
|
||||
: GradientPairPrecise{ 1.0 / sigma, kMinHessian };
|
||||
default:
|
||||
LOG(FATAL) << "Unknown distribution type";
|
||||
}
|
||||
default:
|
||||
LOG(FATAL) << "Unknown censoring type";
|
||||
}
|
||||
|
||||
return { 0.0, 0.0 };
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
@ -26,14 +126,14 @@ DMLC_REGISTER_PARAMETER(AFTParam);
|
||||
double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma) {
|
||||
const double log_y_lower = std::log(y_lower);
|
||||
const double log_y_upper = std::log(y_upper);
|
||||
const double eps = 1e-12;
|
||||
|
||||
double cost;
|
||||
|
||||
if (y_lower == y_upper) { // uncensored
|
||||
const double z = (log_y_lower - y_pred) / sigma;
|
||||
const double pdf = dist_->PDF(z);
|
||||
// Regularize the denominator with eps, to avoid INF or NAN
|
||||
cost = -std::log(std::max(pdf / (sigma * y_lower), eps));
|
||||
cost = -std::log(std::max(pdf / (sigma * y_lower), kEps));
|
||||
} else { // censored; now check what type of censorship we have
|
||||
double z_u, z_l, cdf_u, cdf_l;
|
||||
if (std::isinf(y_upper)) { // right-censored
|
||||
@ -49,7 +149,7 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
|
||||
cdf_l = dist_->CDF(z_l);
|
||||
}
|
||||
// Regularize the denominator with eps, to avoid INF or NAN
|
||||
cost = -std::log(std::max(cdf_u - cdf_l, eps));
|
||||
cost = -std::log(std::max(cdf_u - cdf_l, kEps));
|
||||
}
|
||||
|
||||
return cost;
|
||||
@ -58,20 +158,25 @@ double AFTLoss::Loss(double y_lower, double y_upper, double y_pred, double sigma
|
||||
double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
|
||||
const double log_y_lower = std::log(y_lower);
|
||||
const double log_y_upper = std::log(y_upper);
|
||||
double gradient;
|
||||
const double eps = 1e-12;
|
||||
double numerator, denominator, gradient; // numerator and denominator of gradient
|
||||
CensoringType censor_type;
|
||||
bool z_sign; // sign of z-score
|
||||
|
||||
if (y_lower == y_upper) { // uncensored
|
||||
const double z = (log_y_lower - y_pred) / sigma;
|
||||
const double pdf = dist_->PDF(z);
|
||||
const double grad_pdf = dist_->GradPDF(z);
|
||||
// Regularize the denominator with eps, so that gradient doesn't get too big
|
||||
gradient = grad_pdf / (sigma * std::max(pdf, eps));
|
||||
censor_type = CensoringType::kUncensored;
|
||||
numerator = grad_pdf;
|
||||
denominator = sigma * pdf;
|
||||
z_sign = (z > 0);
|
||||
} else { // censored; now check what type of censorship we have
|
||||
double z_u, z_l, pdf_u, pdf_l, cdf_u, cdf_l;
|
||||
double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
|
||||
censor_type = CensoringType::kIntervalCensored;
|
||||
if (std::isinf(y_upper)) { // right-censored
|
||||
pdf_u = 0;
|
||||
cdf_u = 1;
|
||||
censor_type = CensoringType::kRightCensored;
|
||||
} else { // interval-censored or left-censored
|
||||
z_u = (log_y_upper - y_pred) / sigma;
|
||||
pdf_u = dist_->PDF(z_u);
|
||||
@ -80,38 +185,48 @@ double AFTLoss::Gradient(double y_lower, double y_upper, double y_pred, double s
|
||||
if (std::isinf(y_lower)) { // left-censored
|
||||
pdf_l = 0;
|
||||
cdf_l = 0;
|
||||
censor_type = CensoringType::kLeftCensored;
|
||||
} else { // interval-censored or right-censored
|
||||
z_l = (log_y_lower - y_pred) / sigma;
|
||||
pdf_l = dist_->PDF(z_l);
|
||||
cdf_l = dist_->CDF(z_l);
|
||||
}
|
||||
// Regularize the denominator with eps, so that gradient doesn't get too big
|
||||
gradient = (pdf_u - pdf_l) / (sigma * std::max(cdf_u - cdf_l, eps));
|
||||
z_sign = (z_u > 0 || z_l > 0);
|
||||
numerator = pdf_u - pdf_l;
|
||||
denominator = sigma * (cdf_u - cdf_l);
|
||||
}
|
||||
gradient = numerator / denominator;
|
||||
if (denominator < kEps && (std::isnan(gradient) || std::isinf(gradient))) {
|
||||
gradient = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetGrad();
|
||||
}
|
||||
|
||||
return gradient;
|
||||
return Clip(gradient, kMinGradient, kMaxGradient);
|
||||
}
|
||||
|
||||
double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
|
||||
const double log_y_lower = std::log(y_lower);
|
||||
const double log_y_upper = std::log(y_upper);
|
||||
const double eps = 1e-12;
|
||||
double hessian;
|
||||
double numerator, denominator, hessian; // numerator and denominator of hessian
|
||||
CensoringType censor_type;
|
||||
bool z_sign; // sign of z-score
|
||||
|
||||
if (y_lower == y_upper) { // uncensored
|
||||
const double z = (log_y_lower - y_pred) / sigma;
|
||||
const double pdf = dist_->PDF(z);
|
||||
const double grad_pdf = dist_->GradPDF(z);
|
||||
const double hess_pdf = dist_->HessPDF(z);
|
||||
// Regularize the denominator with eps, so that gradient doesn't get too big
|
||||
hessian = -(pdf * hess_pdf - std::pow(grad_pdf, 2))
|
||||
/ (std::pow(sigma, 2) * std::pow(std::max(pdf, eps), 2));
|
||||
censor_type = CensoringType::kUncensored;
|
||||
numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
|
||||
denominator = sigma * sigma * pdf * pdf;
|
||||
z_sign = (z > 0);
|
||||
} else { // censored; now check what type of censorship we have
|
||||
double z_u, z_l, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
|
||||
double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
|
||||
censor_type = CensoringType::kIntervalCensored;
|
||||
if (std::isinf(y_upper)) { // right-censored
|
||||
pdf_u = 0;
|
||||
cdf_u = 1;
|
||||
grad_pdf_u = 0;
|
||||
censor_type = CensoringType::kRightCensored;
|
||||
} else { // interval-censored or left-censored
|
||||
z_u = (log_y_upper - y_pred) / sigma;
|
||||
pdf_u = dist_->PDF(z_u);
|
||||
@ -122,6 +237,7 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
|
||||
pdf_l = 0;
|
||||
cdf_l = 0;
|
||||
grad_pdf_l = 0;
|
||||
censor_type = CensoringType::kLeftCensored;
|
||||
} else { // interval-censored or right-censored
|
||||
z_l = (log_y_lower - y_pred) / sigma;
|
||||
pdf_l = dist_->PDF(z_l);
|
||||
@ -131,15 +247,17 @@ double AFTLoss::Hessian(double y_lower, double y_upper, double y_pred, double si
|
||||
const double cdf_diff = cdf_u - cdf_l;
|
||||
const double pdf_diff = pdf_u - pdf_l;
|
||||
const double grad_diff = grad_pdf_u - grad_pdf_l;
|
||||
// Regularize the denominator with eps, so that gradient doesn't get too big
|
||||
const double cdf_diff_thresh = std::max(cdf_diff, eps);
|
||||
const double numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
|
||||
const double sqrt_denominator = sigma * cdf_diff_thresh;
|
||||
const double denominator = sqrt_denominator * sqrt_denominator;
|
||||
hessian = numerator / denominator;
|
||||
const double sqrt_denominator = sigma * cdf_diff;
|
||||
z_sign = (z_u > 0 || z_l > 0);
|
||||
numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
|
||||
denominator = sqrt_denominator * sqrt_denominator;
|
||||
}
|
||||
hessian = numerator / denominator;
|
||||
if (denominator < kEps && (std::isnan(hessian) || std::isinf(hessian))) {
|
||||
hessian = GetLimitAtInfPred(dist_type_, censor_type, z_sign, sigma).GetHess();
|
||||
}
|
||||
|
||||
return hessian;
|
||||
return Clip(hessian, kMinHessian, kMaxHessian);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
|
||||
@ -42,15 +42,16 @@ struct AFTParam : public XGBoostParameter<AFTParam> {
|
||||
class AFTLoss {
|
||||
private:
|
||||
std::unique_ptr<ProbabilityDistribution> dist_;
|
||||
ProbabilityDistributionType dist_type_;
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \brief Constructor for AFT loss function
|
||||
* \param dist Choice of probability distribution for the noise term in AFT
|
||||
* \param dist_type Choice of probability distribution for the noise term in AFT
|
||||
*/
|
||||
explicit AFTLoss(ProbabilityDistributionType dist) {
|
||||
dist_.reset(ProbabilityDistribution::Create(dist));
|
||||
}
|
||||
explicit AFTLoss(ProbabilityDistributionType dist_type)
|
||||
: dist_(ProbabilityDistribution::Create(dist_type)),
|
||||
dist_type_(dist_type) {}
|
||||
|
||||
public:
|
||||
/*!
|
||||
|
||||
44
tests/cpp/common/test_survival_util.cc
Normal file
44
tests/cpp/common/test_survival_util.cc
Normal file
@ -0,0 +1,44 @@
|
||||
/*!
|
||||
* Copyright (c) by Contributors 2020
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "../../../src/common/survival_util.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
inline static void RobustTestSuite(ProbabilityDistributionType dist_type,
|
||||
double y_lower, double y_upper, double sigma) {
|
||||
AFTLoss loss(dist_type);
|
||||
for (int i = 50; i >= -50; --i) {
|
||||
const double y_pred = std::pow(10.0, static_cast<double>(i));
|
||||
const double z = (std::log(y_lower) - std::log(y_pred)) / sigma;
|
||||
const double gradient = loss.Gradient(y_lower, y_upper, std::log(y_pred), sigma);
|
||||
const double hessian = loss.Hessian(y_lower, y_upper, std::log(y_pred), sigma);
|
||||
ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y \\in ["
|
||||
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
|
||||
<< ", dist = " << static_cast<int>(dist_type);
|
||||
ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y \\in ["
|
||||
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
|
||||
<< ", dist = " << static_cast<int>(dist_type);
|
||||
ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y \\in ["
|
||||
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
|
||||
<< ", dist = " << static_cast<int>(dist_type);
|
||||
ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y \\in ["
|
||||
<< y_lower << ", " << y_upper << "], y_pred = " << y_pred
|
||||
<< ", dist = " << static_cast<int>(dist_type);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AFTLoss, RobustGradientPair) { // Ensure that INF and NAN don't show up in gradient pair
|
||||
RobustTestSuite(ProbabilityDistributionType::kNormal, 16.0, 200.0, 2.0);
|
||||
RobustTestSuite(ProbabilityDistributionType::kLogistic, 16.0, 200.0, 2.0);
|
||||
RobustTestSuite(ProbabilityDistributionType::kExtreme, 16.0, 200.0, 2.0);
|
||||
RobustTestSuite(ProbabilityDistributionType::kNormal, 100.0, 100.0, 2.0);
|
||||
RobustTestSuite(ProbabilityDistributionType::kLogistic, 100.0, 100.0, 2.0);
|
||||
RobustTestSuite(ProbabilityDistributionType::kExtreme, 100.0, 100.0, 2.0);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
@ -93,10 +93,10 @@ TEST(Objective, AFTObjGPairUncensoredLabels) {
|
||||
{ 0.0384f, 0.0624f, 0.0997f, 0.1551f, 0.2316f, 0.3254f, 0.4200f, 0.4861f, 0.4962f, 0.4457f,
|
||||
0.3567f, 0.2601f, 0.1772f, 0.1152f, 0.0726f, 0.0449f, 0.0275f, 0.0167f, 0.0101f, 0.0061f });
|
||||
CheckGPairOverGridPoints(obj.get(), 100.0f, 100.0f, "extreme",
|
||||
{ -0.0000f, -29.0026f, -17.0031f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
|
||||
{ -15.0000f, -15.0000f, -15.0000f, -9.8028f, -5.4822f, -2.8897f, -1.3340f, -0.4005f, 0.1596f,
|
||||
0.4957f, 0.6974f, 0.8184f, 0.8910f, 0.9346f, 0.9608f, 0.9765f, 0.9859f, 0.9915f, 0.9949f,
|
||||
0.9969f },
|
||||
{ 0.0000f, 30.0026f, 18.0031f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
|
||||
{ 15.0000f, 15.0000f, 15.0000f, 10.8028f, 6.4822f, 3.8897f, 2.3340f, 1.4005f, 0.8404f, 0.5043f,
|
||||
0.3026f, 0.1816f, 0.1090f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f });
|
||||
}
|
||||
|
||||
@ -106,10 +106,9 @@ TEST(Objective, AFTObjGPairLeftCensoredLabels) {
|
||||
|
||||
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "normal",
|
||||
{ 0.0285f, 0.0832f, 0.1951f, 0.3804f, 0.6403f, 0.9643f, 1.3379f, 1.7475f, 2.1828f, 2.6361f,
|
||||
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 0.5072f },
|
||||
3.1023f, 3.5779f, 4.0603f, 4.5479f, 5.0394f, 5.5340f, 6.0309f, 6.5298f, 7.0303f, 7.5326f },
|
||||
{ 0.0663f, 0.1559f, 0.2881f, 0.4378f, 0.5762f, 0.6878f, 0.7707f, 0.8300f, 0.8719f, 0.9016f,
|
||||
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9812f, 0.0045f },
|
||||
2e-4);
|
||||
0.9229f, 0.9385f, 0.9501f, 0.9588f, 0.9656f, 0.9709f, 0.9751f, 0.9785f, 0.9813f, 0.9877f });
|
||||
CheckGPairOverGridPoints(obj.get(), -std::numeric_limits<float>::infinity(), 20.0f, "logistic",
|
||||
{ 0.0909f, 0.1428f, 0.2174f, 0.3164f, 0.4355f, 0.5625f, 0.6818f, 0.7812f, 0.8561f, 0.9084f,
|
||||
0.9429f, 0.9650f, 0.9787f, 0.9871f, 0.9922f, 0.9953f, 0.9972f, 0.9983f, 0.9990f, 0.9994f },
|
||||
@ -139,10 +138,10 @@ TEST(Objective, AFTObjGPairRightCensoredLabels) {
|
||||
{ 0.0312f, 0.0499f, 0.0776f, 0.1158f, 0.1627f, 0.2100f, 0.2430f, 0.2481f, 0.2228f, 0.1783f,
|
||||
0.1300f, 0.0886f, 0.0576f, 0.0363f, 0.0225f, 0.0137f, 0.0083f, 0.0050f, 0.0030f, 0.0018f });
|
||||
CheckGPairOverGridPoints(obj.get(), 60.0f, std::numeric_limits<float>::infinity(), "extreme",
|
||||
{ -2.8073f, -18.0015f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
|
||||
{ -15.0000f, -15.0000f, -10.8018f, -6.4817f, -3.8893f, -2.3338f, -1.4004f, -0.8403f, -0.5042f,
|
||||
-0.3026f, -0.1816f, -0.1089f, -0.0654f, -0.0392f, -0.0235f, -0.0141f, -0.0085f, -0.0051f,
|
||||
-0.0031f, -0.0018f },
|
||||
{ 0.2614f, 18.0015f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
|
||||
-0.0031f, -0.0018f },
|
||||
{ 15.0000f, 15.0000f, 10.8018f, 6.4817f, 3.8893f, 2.3338f, 1.4004f, 0.8403f, 0.5042f, 0.3026f,
|
||||
0.1816f, 0.1089f, 0.0654f, 0.0392f, 0.0235f, 0.0141f, 0.0085f, 0.0051f, 0.0031f, 0.0018f });
|
||||
}
|
||||
|
||||
|
||||
@ -85,6 +85,6 @@ def test_aft_survival_demo_data():
|
||||
# AFT metric (negative log likelihood) improve monotonically
|
||||
assert all(p >= q for p, q in zip(nloglik_rec[dist], nloglik_rec[dist][:1]))
|
||||
# For this data, normal distribution works the best
|
||||
assert nloglik_rec['normal'][-1] < 5.0
|
||||
assert nloglik_rec['logistic'][-1] > 5.0
|
||||
assert nloglik_rec['extreme'][-1] > 5.0
|
||||
assert nloglik_rec['normal'][-1] < 4.9
|
||||
assert nloglik_rec['logistic'][-1] > 4.9
|
||||
assert nloglik_rec['extreme'][-1] > 4.9
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user