Fix gamma deviance (#6761)
This commit is contained in:
@@ -274,18 +274,27 @@ struct EvalPoissonNegLogLik {
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Gamma deviance
|
||||
*
|
||||
* Expected input:
|
||||
* label >= 0
|
||||
* predt >= 0
|
||||
*/
|
||||
struct EvalGammaDeviance {
|
||||
const char *Name() const {
|
||||
return "gamma-deviance";
|
||||
const char *Name() const { return "gamma-deviance"; }
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float predt) const {
|
||||
predt += kRtEps;
|
||||
label += kRtEps;
|
||||
return std::log(predt / label) + label / predt - 1;
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float epsilon = 1.0e-9;
|
||||
bst_float tmp = label / (pred + epsilon);
|
||||
return tmp - std::log(tmp) - 1;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return 2 * esum;
|
||||
if (wsum <= 0) {
|
||||
wsum = kRtEps;
|
||||
}
|
||||
return 2 * esum / wsum;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user