Tweedie Regression Post-Rebase (#1737)
* add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * rebased with upstream master and added R example * changed parameter name to tweedie_variance_power * linting error fix * refactored tweedie-nloglik metric to be more like the other parameterized metrics * added upper and lower bound check to tweedie metric * add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * added upper and lower bound check to tweedie metric * added back readme line that was accidentally deleted * rebased with upstream master and added R example * rebased again on top of upstream master * linting error fix * added upper and lower bound check to tweedie metric * rebased with master * lint fix * removed whitespace at end of line 186 - elementwise_metric.cc
This commit is contained in:
committed by
Tianqi Chen
parent
52b9867be5
commit
2ad0948444
@@ -163,6 +163,30 @@ struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
|
||||
}
|
||||
};
|
||||
|
||||
struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
|
||||
explicit EvalTweedieNLogLik(const char* param) {
|
||||
CHECK(param != nullptr)
|
||||
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
|
||||
rho_ = atof(param);
|
||||
CHECK(rho_ < 2 && rho_ >= 1)
|
||||
<< "tweedie variance power must be in interval [1, 2)";
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << rho_;
|
||||
name_ = os.str();
|
||||
}
|
||||
const char *Name() const override {
|
||||
return name_.c_str();
|
||||
}
|
||||
inline float EvalRow(float y, float p) const {
|
||||
float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
|
||||
float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
|
||||
return -a + b;
|
||||
}
|
||||
protected:
|
||||
std::string name_;
|
||||
float rho_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
|
||||
.describe("Rooted mean square error.")
|
||||
.set_body([](const char* param) { return new EvalRMSE(); });
|
||||
@@ -191,5 +215,11 @@ XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
|
||||
.describe("Negative log-likelihood for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalGammaNLogLik(); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
|
||||
.describe("tweedie-nloglik@rho for tweedie regression.")
|
||||
.set_body([](const char* param) {
|
||||
return new EvalTweedieNLogLik(param);
|
||||
});
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user