diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index b8322d647..25492bf2c 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -304,13 +304,17 @@ struct EvalGammaNLogLik { } XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const { - const bst_float eps = 1e-16f; - if (y < eps) y = eps; - bst_float psi = 1.0; + py = std::max(py, 1e-6f); + // hardcoded dispersion. + float constexpr kPsi = 1.0; bst_float theta = -1. / py; - bst_float a = psi; - bst_float b = -std::log(-theta); - bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi); + bst_float a = kPsi; + // b = -std::log(-theta); + float b = 1.0f; + // c = 1. / kPsi * std::log(y/kPsi) - std::log(y) - common::LogGamma(1. / kPsi); + // = 1.0f * std::log(y) - std::log(y) - 0 = 0 + float c = 0; + // general form for exponential family. return -((y * theta - b) / a + c); } static bst_float GetFinal(bst_float esum, bst_float wsum) { diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index a5bdc47ac..89bbb5081 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -404,7 +404,7 @@ class GammaRegression : public ObjFunction { bst_float p = _preds[_idx]; bst_float w = is_null_weight ? 1.0f : _weights[_idx]; bst_float y = _labels[_idx]; - if (y < 0.0f) { + if (y <= 0.0f) { _label_correct[0] = 0; } _out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w); @@ -416,7 +416,7 @@ class GammaRegression : public ObjFunction { std::vector& label_correct_h = label_correct_.HostVector(); for (auto const flag : label_correct_h) { if (flag == 0) { - LOG(FATAL) << "GammaRegression: label must be nonnegative"; + LOG(FATAL) << "GammaRegression: label must be positive."; } } } diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 2043bcdde..112670269 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -205,16 +205,16 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) { obj->Configure(args); CheckObjFunction(obj, {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - {0, 0, 0, 0, 1, 1, 1, 1}, - {1, 1, 1, 1, 1, 1, 1, 1}, - {1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f}, - {0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f}); + {2, 2, 2, 2, 1, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, + {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); CheckObjFunction(obj, {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - {0, 0, 0, 0, 1, 1, 1, 1}, + {2, 2, 2, 2, 1, 1, 1, 1}, {}, // Empty weight - {1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f}, - {0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f}); + {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, + {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); } TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { @@ -228,7 +228,9 @@ TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { CheckConfigReload(obj, "reg:gamma"); // test label validation - EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {0}, {1}, {0}, {0})) + << "Expected error when label = 0 for GammaRegression"; + EXPECT_ANY_THROW(CheckObjFunction(obj, {-1}, {-1}, {1}, {-1}, {-3})) << "Expected error when label < 0 for GammaRegression"; // test ProbToMargin