Fix gamma deviance (#6761)

This commit is contained in:
Jiaming Yuan
2021-03-20 01:56:17 +08:00
committed by GitHub
parent c2b6b80600
commit 23b4165a6b
2 changed files with 37 additions and 8 deletions

View File

@@ -274,18 +274,27 @@ struct EvalPoissonNegLogLik {
}
};
/**
* Gamma deviance
*
* Expected input:
* label >= 0
* predt >= 0
*/
struct EvalGammaDeviance {
const char *Name() const {
return "gamma-deviance";
const char *Name() const { return "gamma-deviance"; }
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
predt += kRtEps;
label += kRtEps;
return std::log(predt / label) + label / predt - 1;
}
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float epsilon = 1.0e-9;
bst_float tmp = label / (pred + epsilon);
return tmp - std::log(tmp) - 1;
}
static bst_float GetFinal(bst_float esum, bst_float wsum) {
return 2 * esum;
if (wsum <= 0) {
wsum = kRtEps;
}
return 2 * esum / wsum;
}
};