Add rmsle metric and reg:squaredlogerror objective (#4541)
This commit is contained in:
@@ -153,6 +153,19 @@ struct EvalRowRMSE {
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalRowRMSLE {
|
||||
char const* Name() const {
|
||||
return "rmsle";
|
||||
}
|
||||
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
|
||||
bst_float diff = std::log1p(label) - std::log1p(pred);
|
||||
return diff * diff;
|
||||
}
|
||||
static bst_float GetFinal(bst_float esum, bst_float wsum) {
|
||||
return std::sqrt(esum / wsum);
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalRowMAE {
|
||||
const char *Name() const {
|
||||
return "mae";
|
||||
@@ -349,6 +362,10 @@ XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
.describe("Rooted mean square error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSE>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(RMSLE, "rmsle")
|
||||
.describe("Rooted mean square log error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSLE>(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MAE, "mae")
|
||||
.describe("Mean absolute error.")
|
||||
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAE>(); });
|
||||
|
||||
Reference in New Issue
Block a user