Add rmsle metric and reg:squaredlogerror objective (#4541)

This commit is contained in:
Jiaming Yuan
2019-06-11 05:48:27 +08:00
committed by GitHub
parent 9683fd433e
commit 2f1319f273
7 changed files with 92 additions and 9 deletions

View File

@@ -153,6 +153,19 @@ struct EvalRowRMSE {
}
};
struct EvalRowRMSLE {
char const* Name() const {
return "rmsle";
}
XGBOOST_DEVICE bst_float EvalRow(bst_float label, bst_float pred) const {
bst_float diff = std::log1p(label) - std::log1p(pred);
return diff * diff;
}
static bst_float GetFinal(bst_float esum, bst_float wsum) {
return std::sqrt(esum / wsum);
}
};
struct EvalRowMAE {
const char *Name() const {
return "mae";
@@ -349,6 +362,10 @@ XGBOOST_REGISTER_METRIC(RMSE, "rmse")
.describe("Rooted mean square error.")
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSE>(); });
XGBOOST_REGISTER_METRIC(RMSLE, "rmsle")
.describe("Rooted mean square log error.")
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowRMSLE>(); });
XGBOOST_REGISTER_METRIC(MAE, "mae")
.describe("Mean absolute error.")
.set_body([](const char* param) { return new EvalEWiseBase<EvalRowMAE>(); });