Tweedie Regression Post-Rebase (#1737)

* add support for tweedie regression

* added back readme line that was accidentally deleted

* fixed linting errors

* add support for tweedie regression

* added back readme line that was accidentally deleted

* fixed linting errors

* rebased with upstream master and added R example

* changed parameter name to tweedie_variance_power

* linting error fix

* refactored tweedie-nloglik metric to be more like the other parameterized metrics

* added upper and lower bound check to tweedie metric

* add support for tweedie regression

* added back readme line that was accidentally deleted

* fixed linting errors

* added upper and lower bound check to tweedie metric

* added back readme line that was accidentally deleted

* rebased with upstream master and added R example

* rebased again on top of upstream master

* linting error fix

* added upper and lower bound check to tweedie metric

* rebased with master

* lint fix

* removed whitespace at end of line 186 - elementwise_metric.cc
This commit is contained in:
Tony DiFranco
2016-11-05 20:02:32 -04:00
committed by Tianqi Chen
parent 52b9867be5
commit 2ad0948444
4 changed files with 156 additions and 0 deletions

View File

@@ -163,6 +163,30 @@ struct EvalGammaNLogLik: public EvalEWiseBase<EvalGammaNLogLik> {
}
};
struct EvalTweedieNLogLik: public EvalEWiseBase<EvalTweedieNLogLik> {
explicit EvalTweedieNLogLik(const char* param) {
CHECK(param != nullptr)
<< "tweedie-nloglik must be in format tweedie-nloglik@rho";
rho_ = atof(param);
CHECK(rho_ < 2 && rho_ >= 1)
<< "tweedie variance power must be in interval [1, 2)";
std::ostringstream os;
os << "tweedie-nloglik@" << rho_;
name_ = os.str();
}
const char *Name() const override {
return name_.c_str();
}
inline float EvalRow(float y, float p) const {
float a = y * std::exp((1 - rho_) * std::log(p)) / (1 - rho_);
float b = std::exp((2 - rho_) * std::log(p)) / (2 - rho_);
return -a + b;
}
protected:
std::string name_;
float rho_;
};
XGBOOST_REGISTER_METRIC(RMSE, "rmse")
.describe("Rooted mean square error.")
.set_body([](const char* param) { return new EvalRMSE(); });
@@ -191,5 +215,11 @@ XGBOOST_REGISTER_METRIC(GammaNLogLik, "gamma-nloglik")
.describe("Negative log-likelihood for gamma regression.")
.set_body([](const char* param) { return new EvalGammaNLogLik(); });
XGBOOST_REGISTER_METRIC(TweedieNLogLik, "tweedie-nloglik")
.describe("tweedie-nloglik@rho for tweedie regression.")
.set_body([](const char* param) {
return new EvalTweedieNLogLik(param);
});
} // namespace metric
} // namespace xgboost