Verify strictly positive labels for gamma regression. (#6778)
Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
794fd6a46b
commit
1d90577800
@ -304,13 +304,17 @@ struct EvalGammaNLogLik {
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
|
||||
const bst_float eps = 1e-16f;
|
||||
if (y < eps) y = eps;
|
||||
bst_float psi = 1.0;
|
||||
py = std::max(py, 1e-6f);
|
||||
// hardcoded dispersion.
|
||||
float constexpr kPsi = 1.0;
|
||||
bst_float theta = -1. / py;
|
||||
bst_float a = psi;
|
||||
bst_float b = -std::log(-theta);
|
||||
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
||||
bst_float a = kPsi;
|
||||
// b = -std::log(-theta);
|
||||
float b = 1.0f;
|
||||
// c = 1. / kPsi * std::log(y/kPsi) - std::log(y) - common::LogGamma(1. / kPsi);
|
||||
// = 1.0f * std::log(y) - std::log(y) - 0 = 0
|
||||
float c = 0;
|
||||
// general form for exponential family.
|
||||
return -((y * theta - b) / a + c);
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
|
||||
@ -404,7 +404,7 @@ class GammaRegression : public ObjFunction {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx];
|
||||
if (y < 0.0f) {
|
||||
if (y <= 0.0f) {
|
||||
_label_correct[0] = 0;
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
|
||||
@ -416,7 +416,7 @@ class GammaRegression : public ObjFunction {
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "GammaRegression: label must be nonnegative";
|
||||
LOG(FATAL) << "GammaRegression: label must be positive.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -205,16 +205,16 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) {
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{2, 2, 2, 2, 1, 1, 1, 1},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
|
||||
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
|
||||
{-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f},
|
||||
{2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f});
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{2, 2, 2, 2, 1, 1, 1, 1},
|
||||
{}, // Empty weight
|
||||
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
|
||||
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
|
||||
{-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f},
|
||||
{2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f});
|
||||
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
|
||||
@ -228,7 +228,9 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
|
||||
CheckConfigReload(obj, "reg:gamma");
|
||||
|
||||
// test label validation
|
||||
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0}))
|
||||
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {0}, {1}, {0}, {0}))
|
||||
<< "Expected error when label = 0 for GammaRegression";
|
||||
EXPECT_ANY_THROW(CheckObjFunction(obj, {-1}, {-1}, {1}, {-1}, {-3}))
|
||||
<< "Expected error when label < 0 for GammaRegression";
|
||||
|
||||
// test ProbToMargin
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user