Add support for Gamma regression (#1258)
* Add support for Gamma regression * Use base_score to replace the lp_bias * Remove the lp_bias config block * Add a demo for running gamma regression in Python * Typo fix * Revise the description for objective * Add a script to generate the autoclaims dataset
This commit is contained in:
parent
f74e2439e0
commit
77d17f6264
18
demo/data/gen_autoclaims.R
Normal file
18
demo/data/gen_autoclaims.R
Normal file
@ -0,0 +1,18 @@
|
||||
site <- 'http://cran.r-project.org'
|
||||
if (!require('dummies'))
|
||||
install.packages('dummies', repos=site)
|
||||
if (!require('insuranceData'))
|
||||
install.packages('insuranceData', repos=site)
|
||||
|
||||
library(dummies)
|
||||
library(insuranceData)
|
||||
|
||||
data(AutoClaims)
|
||||
data = AutoClaims
|
||||
|
||||
data$STATE = as.factor(data$STATE)
|
||||
data$CLASS = as.factor(data$CLASS)
|
||||
data$GENDER = as.factor(data$GENDER)
|
||||
|
||||
data.dummy <- dummy.data.frame(data, dummy.class='factor', omit.constants=T);
|
||||
write.table(data.dummy, 'autoclaims.csv', sep=',', row.names=F, col.names=F, quote=F)
|
||||
25
demo/guide-python/gamma_regression.py
Executable file
25
demo/guide-python/gamma_regression.py
Executable file
@ -0,0 +1,25 @@
|
||||
#!/usr/bin/python
|
||||
import xgboost as xgb
|
||||
import numpy as np
|
||||
|
||||
# this script demonstrates how to fit gamma regression model (with log link function)
|
||||
# in xgboost, before running the demo you need to generate the autoclaims dataset
|
||||
# by running gen_autoclaims.R located in xgboost/demo/data.
|
||||
|
||||
data = np.genfromtxt('../data/autoclaims.csv', delimiter=',')
|
||||
dtrain = xgb.DMatrix(data[0:4741, 0:34], data[0:4741, 34])
|
||||
dtest = xgb.DMatrix(data[4741:6773, 0:34], data[4741:6773, 34])
|
||||
|
||||
# for gamma regression, we need to set the objective to 'reg:gamma', it also suggests
|
||||
# to set the base_score to a value between 1 to 5 if the number of iteration is small
|
||||
param = {'silent':1, 'objective':'reg:gamma', 'booster':'gbtree', 'base_score':3}
|
||||
|
||||
# the rest of settings are the same
|
||||
watchlist = [(dtest,'eval'), (dtrain,'train')]
|
||||
num_round = 30
|
||||
|
||||
# training and evaluation
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
preds = bst.predict(dtest)
|
||||
labels = dtest.get_label()
|
||||
print ('test deviance=%f' % (2 * np.sum((labels - preds) / preds - np.log(labels) + np.log(preds))))
|
||||
@ -119,6 +119,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
- "multi:softmax" --set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
|
||||
- "multi:softprob" --same as softmax, but output a vector of ndata * nclass, which can be further reshaped to ndata, nclass matrix. The result contains predicted probability of each data point belonging to each class.
|
||||
- "rank:pairwise" --set XGBoost to do ranking task by minimizing the pairwise loss
|
||||
- "reg:gamma" --gamma regression for severity data, output mean of gamma distribution
|
||||
* base_score [ default=0.5 ]
|
||||
- the initial prediction score of all instances, global bias
|
||||
- for sufficent number of iterations, changing this value will not have too much effect.
|
||||
|
||||
@ -155,5 +155,87 @@ XGBOOST_REGISTER_METRIC(PossionNegLoglik, "poisson-nloglik")
|
||||
.describe("Negative loglikelihood for poisson regression.")
|
||||
.set_body([](const char* param) { return new EvalPoissionNegLogLik(); });
|
||||
|
||||
/*!
|
||||
* \brief base class of element-wise evaluation
|
||||
* with additonal dispersion parameter
|
||||
* \tparam Derived the name of subclass
|
||||
*/
|
||||
template<typename Derived>
|
||||
struct EvalEWiseBase2 : public Metric {
|
||||
float Eval(const std::vector<float>& preds,
|
||||
const MetaInfo& info,
|
||||
bool distributed) const override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "label and prediction size not match, "
|
||||
<< "hint: use merror or mlogloss for multi-class classification";
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(info.labels.size());
|
||||
|
||||
// Computer dispersion
|
||||
double sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += static_cast<const Derived*>(this)->EvalDispersion(info.labels[i], preds[i]) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
double dat[2]; dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
double dispersion = dat[0] / (dat[1] - info.num_col);
|
||||
|
||||
// Computer metric
|
||||
sum = 0.0, wsum = 0.0;
|
||||
#pragma omp parallel for reduction(+: sum, wsum) schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
const float wt = info.GetWeight(i);
|
||||
sum += static_cast<const Derived*>(this)->EvalRow(info.labels[i], preds[i], dispersion) * wt;
|
||||
wsum += wt;
|
||||
}
|
||||
dat[0] = sum, dat[1] = wsum;
|
||||
if (distributed) {
|
||||
rabit::Allreduce<rabit::op::Sum>(dat, 2);
|
||||
}
|
||||
return Derived::GetFinal(dat[0], dat[1]);
|
||||
}
|
||||
/*!
|
||||
* \brief to be implemented by subclass,
|
||||
* get evaluation result from one row
|
||||
* \param label label of current instance
|
||||
* \param pred prediction value of current instance
|
||||
*/
|
||||
inline float EvalRow(float label, float pred, float dispersion) const;
|
||||
/*!
|
||||
* \brief to be overridden by subclass, final transformation
|
||||
* \param esum the sum statistics returned by EvalRow
|
||||
* \param wsum sum of weight
|
||||
*/
|
||||
inline static float GetFinal(float esum, float wsum) {
|
||||
return esum / wsum;
|
||||
}
|
||||
inline float EvalDispersion(float label, float pred) const;
|
||||
};
|
||||
|
||||
struct EvalGammaNegLogLik : public EvalEWiseBase2<EvalGammaNegLogLik> {
|
||||
const char *Name() const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
inline float EvalRow(float y, float py, float psi) const {
|
||||
double theta = -1. / py;
|
||||
double a = psi;
|
||||
double b = -std::log(-theta);
|
||||
double c = 1. / psi * std::log(y/psi) - std::log(y) - common::LogGamma(1. / psi);
|
||||
return -((y * theta - b) / a + c);
|
||||
}
|
||||
inline float EvalDispersion(float y, float py) const {
|
||||
return ((y - py) * (y - py)) / (py * py);
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(GammaNegLoglik, "gamma-nloglik")
|
||||
.describe("Negative loglikelihood for gamma regression.")
|
||||
.set_body([](const char* param) { return new EvalGammaNegLogLik(); });
|
||||
|
||||
} // namespace metric
|
||||
} // namespace xgboost
|
||||
|
||||
@ -217,5 +217,60 @@ DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
|
||||
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
.describe("Possion regression for count data.")
|
||||
.set_body([]() { return new PoissonRegression(); });
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
}
|
||||
|
||||
void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
|
||||
out_gpair->resize(preds.size());
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
float p = preds[i];
|
||||
float w = info.GetWeight(i);
|
||||
float y = info.labels[i];
|
||||
if (y >= 0.0f) {
|
||||
out_gpair->at(i) = bst_gpair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
CHECK(label_correct) << "GammaRegression: label must be positive";
|
||||
}
|
||||
void PredTransform(std::vector<float> *io_preds) override {
|
||||
std::vector<float> &preds = *io_preds;
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(std::vector<float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
float ProbToMargin(float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric(void) const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
};
|
||||
|
||||
// register the ojective functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
|
||||
.describe("Gamma regression for severity data.")
|
||||
.set_body([]() { return new GammaRegression(); });
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user