Add Model and Configurable interface. (#4945)
* Apply Configurable to objective functions. * Apply Model to Learner and Regtree, gbm. * Add Load/SaveConfig to objs. * Refactor obj tests to use smart pointer. * Dummy methods for Save/Load Model.
This commit is contained in:
@@ -34,6 +34,8 @@ struct LinearSquareLoss {
|
||||
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
|
||||
static const char* LabelErrorMsg() { return ""; }
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:squarederror"; }
|
||||
};
|
||||
|
||||
struct SquaredLogError {
|
||||
@@ -57,6 +59,8 @@ struct SquaredLogError {
|
||||
return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmsle"; }
|
||||
|
||||
static const char* Name() { return "reg:squaredlogerror"; }
|
||||
};
|
||||
|
||||
// logistic loss for probability regression task
|
||||
@@ -83,18 +87,21 @@ struct LogisticRegression {
|
||||
}
|
||||
static bst_float ProbToMargin(bst_float base_score) {
|
||||
CHECK(base_score > 0.0f && base_score < 1.0f)
|
||||
<< "base_score must be in (0,1) for logistic loss";
|
||||
<< "base_score must be in (0,1) for logistic loss, got: " << base_score;
|
||||
return -logf(1.0f / base_score - 1.0f);
|
||||
}
|
||||
static const char* LabelErrorMsg() {
|
||||
return "label must be in [0,1] for logistic regression";
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "rmse"; }
|
||||
|
||||
static const char* Name() { return "reg:logistic"; }
|
||||
};
|
||||
|
||||
// logistic loss for binary classification task
|
||||
struct LogisticClassification : public LogisticRegression {
|
||||
static const char* DefaultEvalMetric() { return "error"; }
|
||||
static const char* Name() { return "binary:logistic"; }
|
||||
};
|
||||
|
||||
// logistic loss, but predict un-transformed margin
|
||||
@@ -125,6 +132,8 @@ struct LogisticRaw : public LogisticRegression {
|
||||
return std::max(predt * (T(1.0f) - predt), eps);
|
||||
}
|
||||
static const char* DefaultEvalMetric() { return "auc"; }
|
||||
|
||||
static const char* Name() { return "binary:logitraw"; }
|
||||
};
|
||||
|
||||
} // namespace obj
|
||||
|
||||
Reference in New Issue
Block a user