Verify strictly positive labels for gamma regression. (#6778)

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan
2021-03-25 11:46:52 +08:00
committed by GitHub
parent 794fd6a46b
commit 1d90577800
3 changed files with 22 additions and 16 deletions

View File

@@ -304,13 +304,17 @@ struct EvalGammaNLogLik {
}
XGBOOST_DEVICE bst_float EvalRow(bst_float y, bst_float py) const {
const bst_float eps = 1e-16f;
if (y < eps) y = eps;
bst_float psi = 1.0;
py = std::max(py, 1e-6f);
// hardcoded dispersion.
float constexpr kPsi = 1.0;
bst_float theta = -1. / py;
bst_float a = psi;
bst_float b = -std::log(-theta);
bst_float c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
bst_float a = kPsi;
// b = -std::log(-theta);
float b = 1.0f;
// c = 1. / kPsi * std::log(y/kPsi) - std::log(y) - common::LogGamma(1. / kPsi);
// = 1.0f * std::log(y) - std::log(y) - 0 = 0
float c = 0;
// general form for exponential family.
return -((y * theta - b) / a + c);
}
static bst_float GetFinal(bst_float esum, bst_float wsum) {

View File

@@ -404,7 +404,7 @@ class GammaRegression : public ObjFunction {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
if (y < 0.0f) {
if (y <= 0.0f) {
_label_correct[0] = 0;
}
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
@@ -416,7 +416,7 @@ class GammaRegression : public ObjFunction {
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "GammaRegression: label must be nonnegative";
LOG(FATAL) << "GammaRegression: label must be positive.";
}
}
}