Add Model and Configurable interface. (#4945)

* Apply Configurable to objective functions.
* Apply Model to Learner and Regtree, gbm.
* Add Load/SaveConfig to objs.
* Refactor obj tests to use smart pointer.
* Dummy methods for Save/Load Model.
This commit is contained in:
Jiaming Yuan
2019-10-18 01:56:02 -04:00
committed by GitHub
parent 9fc681001a
commit ae536756ae
31 changed files with 521 additions and 187 deletions

View File

@@ -34,6 +34,8 @@ struct LinearSquareLoss {
static bst_float ProbToMargin(bst_float base_score) { return base_score; }
static const char* LabelErrorMsg() { return ""; }
static const char* DefaultEvalMetric() { return "rmse"; }
static const char* Name() { return "reg:squarederror"; }
};
struct SquaredLogError {
@@ -57,6 +59,8 @@ struct SquaredLogError {
return "label must be greater than -1 for rmsle so that log(label + 1) can be valid.";
}
static const char* DefaultEvalMetric() { return "rmsle"; }
static const char* Name() { return "reg:squaredlogerror"; }
};
// logistic loss for probability regression task
@@ -83,18 +87,21 @@ struct LogisticRegression {
}
static bst_float ProbToMargin(bst_float base_score) {
CHECK(base_score > 0.0f && base_score < 1.0f)
<< "base_score must be in (0,1) for logistic loss";
<< "base_score must be in (0,1) for logistic loss, got: " << base_score;
return -logf(1.0f / base_score - 1.0f);
}
static const char* LabelErrorMsg() {
return "label must be in [0,1] for logistic regression";
}
static const char* DefaultEvalMetric() { return "rmse"; }
static const char* Name() { return "reg:logistic"; }
};
// logistic loss for binary classification task
struct LogisticClassification : public LogisticRegression {
static const char* DefaultEvalMetric() { return "error"; }
static const char* Name() { return "binary:logistic"; }
};
// logistic loss, but predict un-transformed margin
@@ -125,6 +132,8 @@ struct LogisticRaw : public LogisticRegression {
return std::max(predt * (T(1.0f) - predt), eps);
}
static const char* DefaultEvalMetric() { return "auc"; }
static const char* Name() { return "binary:logitraw"; }
};
} // namespace obj