diff --git a/demo/data/gen_autoclaims.R b/demo/data/gen_autoclaims.R
new file mode 100644
index 000000000..5465db09c
--- /dev/null
+++ b/demo/data/gen_autoclaims.R
@@ -0,0 +1,18 @@
+site <- 'http://cran.r-project.org'
+if (!require('dummies'))
+  install.packages('dummies', repos=site)
+if (!require('insuranceData'))
+  install.packages('insuranceData', repos=site)
+
+library(dummies)
+library(insuranceData)
+
+data(AutoClaims)
+data = AutoClaims
+
+data$STATE = as.factor(data$STATE)
+data$CLASS = as.factor(data$CLASS)
+data$GENDER = as.factor(data$GENDER)
+
+data.dummy <- dummy.data.frame(data, dummy.classes='factor', omit.constants=TRUE)
+write.table(data.dummy, 'autoclaims.csv', sep=',', row.names=FALSE, col.names=FALSE, quote=FALSE)
diff --git a/demo/guide-python/gamma_regression.py b/demo/guide-python/gamma_regression.py
new file mode 100755
index 000000000..faf58c2ad
--- /dev/null
+++ b/demo/guide-python/gamma_regression.py
@@ -0,0 +1,25 @@
+#!/usr/bin/python
+import xgboost as xgb
+import numpy as np
+
+# This script demonstrates how to fit a gamma regression model (with log link function)
+# in xgboost. Before running the demo, generate the autoclaims dataset
+# by running gen_autoclaims.R located in xgboost/demo/data.
+
+data = np.genfromtxt('../data/autoclaims.csv', delimiter=',')
+dtrain = xgb.DMatrix(data[0:4741, 0:34], data[0:4741, 34])
+dtest = xgb.DMatrix(data[4741:6773, 0:34], data[4741:6773, 34])
+
+# For gamma regression, set the objective to 'reg:gamma'. It is also suggested
+# to set base_score to a value between 1 and 5 if the number of iterations is small.
+param = {'silent':1, 'objective':'reg:gamma', 'booster':'gbtree', 'base_score':3}
+
+# the rest of the settings are the same
+watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+num_round = 30
+
+# training and evaluation
+bst = xgb.train(param, dtrain, num_round, watchlist)
+preds = bst.predict(dtest)
+labels = dtest.get_label()
+print('test deviance=%f' % (2 * np.sum((labels - preds) / preds - np.log(labels) + np.log(preds))))
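A note on the demo's last line: the printed `test deviance` is the gamma deviance of the test predictions. Writing \hat\mu_i for `preds[i]` and y_i for `labels[i]`, the standard form is

    D = 2 \sum_i \left[ \frac{y_i - \hat\mu_i}{\hat\mu_i} - \log\frac{y_i}{\hat\mu_i} \right],

and expanding -\log(y_i / \hat\mu_i) = -\log y_i + \log\hat\mu_i gives exactly the expression inside `np.sum` above.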
diff --git a/doc/parameter.md b/doc/parameter.md
index 644de6076..f3bccd001 100644
--- a/doc/parameter.md
+++ b/doc/parameter.md
@@ -119,6 +119,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
   - "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
   - "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
   - "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
+  - "reg:gamma" --gamma regression for severity data, output is the mean of the gamma distribution
 * base_score [ default=0.5 ]
   - the initial prediction score of all instances, global bias
   - for sufficent number of iterations, changing this value will not have too much effect.
diff --git a/src/metric/elementwise_metric.cc b/src/metric/elementwise_metric.cc
index 84baf671b..db2869efb 100644
--- a/src/metric/elementwise_metric.cc
+++ b/src/metric/elementwise_metric.cc
@@ -155,5 +155,87 @@ XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
 .describe("Negative loglikelihood for poisson regression.")
 .set_body([](const char* param) { return new EvalPoissionNegLogLik(); });
 
+/*!
+ * \brief base class of element-wise evaluation
+ *  with an additional dispersion parameter
+ * \tparam Derived the name of the subclass
+ */
+template<typename Derived>
+struct EvalEWiseBase2 : public Metric {
+  float Eval(const std::vector<float>& preds,
+             const MetaInfo& info,
+             bool distributed) const override {
+    CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
+    CHECK_EQ(preds.size(), info.labels.size())
+        << "label and prediction size not match, "
+        << "hint: use merror or mlogloss for multi-class classification";
+    const omp_ulong ndata = static_cast<omp_ulong>(info.labels.size());
+
+    // compute the dispersion estimate
+    double sum = 0.0, wsum = 0.0;
+    #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
+    for (omp_ulong i = 0; i < ndata; ++i) {
+      const float wt = info.GetWeight(i);
+      sum += static_cast<const Derived*>(this)->EvalDispersion(info.labels[i], preds[i]) * wt;
+      wsum += wt;
+    }
+    double dat[2]; dat[0] = sum, dat[1] = wsum;
+    if (distributed) {
+      rabit::Allreduce<rabit::op::Sum>(dat, 2);
+    }
+    double dispersion = dat[0] / (dat[1] - info.num_col);
+
+    // compute the metric, plugging in the estimated dispersion
+    sum = 0.0, wsum = 0.0;
+    #pragma omp parallel for reduction(+: sum, wsum) schedule(static)
+    for (omp_ulong i = 0; i < ndata; ++i) {
+      const float wt = info.GetWeight(i);
+      sum += static_cast<const Derived*>(this)->EvalRow(info.labels[i], preds[i], dispersion) * wt;
+      wsum += wt;
+    }
+    dat[0] = sum, dat[1] = wsum;
+    if (distributed) {
+      rabit::Allreduce<rabit::op::Sum>(dat, 2);
+    }
+    return Derived::GetFinal(dat[0], dat[1]);
+  }
+  /*!
+   * \brief to be implemented by the subclass,
+   *   get the evaluation result from one row
+   * \param label label of the current instance
+   * \param pred prediction value of the current instance
+   */
+  inline float EvalRow(float label, float pred, float dispersion) const;
+  /*!
+   * \brief to be overridden by the subclass, final transformation
+   * \param esum the sum statistic returned by EvalRow
+   * \param wsum sum of weights
+   */
+  inline static float GetFinal(float esum, float wsum) {
+    return esum / wsum;
+  }
+  inline float EvalDispersion(float label, float pred) const;
+};
+
+struct EvalGammaNegLogLik : public EvalEWiseBase2<EvalGammaNegLogLik> {
+  const char *Name() const override {
+    return "gamma-nloglik";
+  }
+  inline float EvalRow(float y, float py, float psi) const {
+    double theta = -1. / py;
+    double a = psi;
+    double b = -std::log(-theta);
+    double c = 1. / psi * std::log(y / psi) - std::log(y) - common::LogGamma(1. / psi);
+    return -((y * theta - b) / a + c);
+  }
+  inline float EvalDispersion(float y, float py) const {
+    return ((y - py) * (y - py)) / (py * py);
+  }
+};
+
+XGBOOST_REGISTER_METRIC(GammaNegLoglik, "gamma-nloglik")
+.describe("Negative loglikelihood for gamma regression.")
+.set_body([](const char* param) { return new EvalGammaNegLogLik(); });
+
 } // namespace metric
 } // namespace xgboost
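For reference, `EvalGammaNegLogLik` plugs the gamma density, written in exponential-dispersion form, into the two-pass template above:

    \log f(y; \theta, \psi) = \frac{y\theta - b(\theta)}{a(\psi)} + c(y, \psi),
    \qquad \theta = -\tfrac{1}{\mu}, \quad b(\theta) = -\log(-\theta), \quad a(\psi) = \psi,
    \qquad c(y, \psi) = \tfrac{1}{\psi}\log\tfrac{y}{\psi} - \log y - \log\Gamma(\tfrac{1}{\psi}),

and `EvalRow` returns the negative of this log-density. The dispersion \psi is not fitted by boosting; the first pass estimates it from the weighted Pearson residuals accumulated by `EvalDispersion`,

    \hat\psi = \frac{\sum_i w_i \, (y_i - \hat\mu_i)^2 / \hat\mu_i^2}{\sum_i w_i - p},

where p = `info.num_col` stands in for the number of fitted parameters, mirroring the usual n - p degrees-of-freedom correction.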
diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc
index 6eb0f0a78..743ff083f 100644
--- a/src/objective/regression_obj.cc
+++ b/src/objective/regression_obj.cc
@@ -217,5 +217,60 @@ DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
 XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
 .describe("Possion regression for count data.")
 .set_body([]() { return new PoissonRegression(); });
+
+// gamma regression with log link
+class GammaRegression : public ObjFunction {
+ public:
+  // no extra parameters to configure
+  void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
+  }
+
+  void GetGradient(const std::vector<float> &preds,
+                   const MetaInfo &info,
+                   int iter,
+                   std::vector<bst_gpair> *out_gpair) override {
+    CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
+    CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
+    out_gpair->resize(preds.size());
+    // check whether the labels are in range
+    bool label_correct = true;
+    // start calculating the gradient
+    const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
+    #pragma omp parallel for schedule(static)
+    for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
+      float p = preds[i];
+      float w = info.GetWeight(i);
+      float y = info.labels[i];
+      if (y > 0.0f) {
+        out_gpair->at(i) = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
+      } else {
+        label_correct = false;
+      }
+    }
+    CHECK(label_correct) << "GammaRegression: label must be positive";
+  }
+  void PredTransform(std::vector<float> *io_preds) override {
+    std::vector<float> &preds = *io_preds;
+    const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
+    #pragma omp parallel for schedule(static)
+    for (long j = 0; j < ndata; ++j) { // NOLINT(*)
+      preds[j] = std::exp(preds[j]);
+    }
+  }
+  void EvalTransform(std::vector<float> *io_preds) override {
+    PredTransform(io_preds);
+  }
+  float ProbToMargin(float base_score) const override {
+    return std::log(base_score);
+  }
+  const char* DefaultEvalMetric(void) const override {
+    return "gamma-nloglik";
+  }
+};
+
+// register the objective function
+XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
+.describe("Gamma regression for severity data.")
+.set_body([]() { return new GammaRegression(); });
 } // namespace obj
 } // namespace xgboost
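The gradient/hessian pair in `GetGradient` follows from the log link: with \mu = e^p and the dispersion treated as constant, minimizing the gamma negative log-likelihood is equivalent to minimizing \ell(p) = y e^{-p} + p per instance, so \partial_p \ell = 1 - y e^{-p} and \partial^2_p \ell = y e^{-p}, which are exactly the two arguments passed to `bst_gpair` (each scaled by the instance weight).

To exercise the two names this patch registers (`reg:gamma` and `gamma-nloglik`) without generating the autoclaims data, here is a minimal sketch; the synthetic dataset and all hyperparameter values are illustrative only, not part of the patch:

```python
import numpy as np
import xgboost as xgb

# Synthetic severity-like data: labels must be strictly positive for 'reg:gamma'.
rng = np.random.RandomState(0)
X = rng.rand(200, 3)
y = rng.gamma(shape=2.0, scale=1.0 + X[:, 0])

dtrain = xgb.DMatrix(X, label=y)
param = {'objective': 'reg:gamma',        # objective registered in regression_obj.cc
         'eval_metric': 'gamma-nloglik',  # metric registered in elementwise_metric.cc
         'base_score': 2}                 # an illustrative value in the suggested 1-5 range
bst = xgb.train(param, dtrain, num_boost_round=10, evals=[(dtrain, 'train')])
```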