Pseudo-huber loss metric added (#5647)

- Add pseudo huber loss objective.
- Add pseudo huber loss metric.

Co-authored-by: Reetz <s02reetz@iavgroup.local>
This commit is contained in:
LionOrCatThatIsTheQuestion
2020-05-18 15:08:07 +02:00
committed by GitHub
parent 535479e69f
commit 83981a9ce3
6 changed files with 89 additions and 0 deletions

View File

@@ -98,6 +98,37 @@ struct LogisticRegression {
static const char* Name() { return "reg:logistic"; }
};
struct PseudoHuberError {
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) {
return x;
}
XGBOOST_DEVICE static bool CheckLabel(bst_float label) {
return true;
}
XGBOOST_DEVICE static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
const float z = predt - label;
const float scale_sqrt = std::sqrt(1 + std::pow(z, 2));
return z/scale_sqrt;
}
XGBOOST_DEVICE static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
const float scale = 1 + std::pow(predt - label, 2);
const float scale_sqrt = std::sqrt(scale);
return 1/(scale*scale_sqrt);
}
static bst_float ProbToMargin(bst_float base_score) {
return base_score;
}
static const char* LabelErrorMsg() {
return "";
}
static const char* DefaultEvalMetric() {
return "mphe";
}
static const char* Name() {
return "reg:pseudohubererror";
}
};
// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "error"; }

View File

@@ -152,6 +152,10 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, LogisticRegression::Name())
.describe("Logistic regression for probability regression task.")
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(PseudoHuberError, PseudoHuberError::Name())
.describe("Regression Pseudo Huber error.")
.set_body([]() { return new RegLossObj<PseudoHuberError>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, LogisticClassification::Name())
.describe("Logistic regression for binary classification task.")
.set_body([]() { return new RegLossObj<LogisticClassification>(); });