add poisson regression
This commit is contained in:
@@ -114,7 +114,7 @@ struct LossType {
|
||||
};
|
||||
|
||||
/*! \brief objective function that only need to */
|
||||
class RegLossObj : public IObjFunction{
|
||||
class RegLossObj : public IObjFunction {
|
||||
public:
|
||||
explicit RegLossObj(int loss_type) {
|
||||
loss.loss_type = loss_type;
|
||||
@@ -173,6 +173,72 @@ class RegLossObj : public IObjFunction{
|
||||
LossType loss;
|
||||
};
|
||||
|
||||
// poisson regression for count
|
||||
class PoissonRegression : public IObjFunction {
|
||||
public:
|
||||
explicit PoissonRegression(void) {
|
||||
max_delta_step = 0.0f;
|
||||
}
|
||||
virtual ~PoissonRegression(void) {}
|
||||
|
||||
virtual void SetParam(const char *name, const char *val) {
|
||||
using namespace std;
|
||||
if (!strcmp( "max_delta_step", name )) {
|
||||
max_delta_step = static_cast<float>(atof(val));
|
||||
}
|
||||
}
|
||||
virtual void GetGradient(const std::vector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) {
|
||||
utils::Check(max_delta_step != 0.0f,
|
||||
"PoissonRegression: need to set max_delta_step");
|
||||
utils::Check(info.labels.size() != 0, "label set cannot be empty");
|
||||
utils::Check(preds.size() == info.labels.size(),
|
||||
"labels are not correctly provided");
|
||||
std::vector<bst_gpair> &gpair = *out_gpair;
|
||||
gpair.resize(preds.size());
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const long ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < ndata; ++i) {
|
||||
float p = preds[i];
|
||||
float w = info.GetWeight(i);
|
||||
float y = info.labels[i];
|
||||
if (y >= 0.0f) {
|
||||
gpair[i] = bst_gpair((std::exp(p) - y) * w,
|
||||
std::exp(p + max_delta_step) * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
utils::Check(label_correct,
|
||||
"PoissonRegression: label must be nonnegative");
|
||||
}
|
||||
virtual void PredTransform(std::vector<float> *io_preds) {
|
||||
std::vector<float> &preds = *io_preds;
|
||||
const long ndata = static_cast<long>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) {
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
virtual void EvalTransform(std::vector<float> *io_preds) {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
virtual float ProbToMargin(float base_score) const {
|
||||
return std::log(base_score);
|
||||
}
|
||||
virtual const char* DefaultEvalMetric(void) const {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
|
||||
private:
|
||||
float max_delta_step;
|
||||
};
|
||||
|
||||
// softmax multi-class classification
|
||||
class SoftmaxMultiClassObj : public IObjFunction {
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user