Tweedie Regression Post-Rebase (#1737)
* add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * rebased with upstream master and added R example * changed parameter name to tweedie_variance_power * linting error fix * refactored tweedie-nloglik metric to be more like the other parameterized metrics * added upper and lower bound check to tweedie metric * add support for tweedie regression * added back readme line that was accidentally deleted * fixed linting errors * added upper and lower bound check to tweedie metric * added back readme line that was accidentally deleted * rebased with upstream master and added R example * rebased again on top of upstream master * linting error fix * added upper and lower bound check to tweedie metric * rebased with master * lint fix * removed whitespace at end of line 186 - elementwise_metric.cc
This commit is contained in:
committed by
Tianqi Chen
parent
52b9867be5
commit
2ad0948444
@@ -272,5 +272,75 @@ class GammaRegression : public ObjFunction {
|
||||
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
|
||||
.describe("Gamma regression for severity data.")
|
||||
.set_body([]() { return new GammaRegression(); });
|
||||
|
||||
// declare parameter
|
||||
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
|
||||
float tweedie_variance_power;
|
||||
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
|
||||
.describe("Tweedie variance power. Must be between in range [1, 2).");
|
||||
}
|
||||
};
|
||||
|
||||
// tweedie regression
|
||||
class TweedieRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
|
||||
out_gpair->resize(preds.size());
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
float p = preds[i];
|
||||
float w = info.GetWeight(i);
|
||||
float y = info.labels[i];
|
||||
float rho = param_.tweedie_variance_power;
|
||||
if (y >= 0.0f) {
|
||||
float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
|
||||
float hess = -y * (1 - rho) * std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
|
||||
out_gpair->at(i) = bst_gpair(grad * w, hess * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
CHECK(label_correct) << "TweedieRegression: label must be nonnegative";
|
||||
}
|
||||
void PredTransform(std::vector<float> *io_preds) override {
|
||||
std::vector<float> &preds = *io_preds;
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
const char* DefaultEvalMetric(void) const override {
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
std::string metric = os.str();
|
||||
return metric.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
TweedieRegressionParam param_;
|
||||
};
|
||||
|
||||
// register the ojective functions
|
||||
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
|
||||
.describe("Tweedie regression for insurance data.")
|
||||
.set_body([]() { return new TweedieRegression(); });
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user