Change default metric for gamma regression to deviance. (#9757)

* Change default metric for gamma regression to deviance.

- Cleanup the gamma implementation.
- Use deviance instead since the objective is derived from deviance.
This commit is contained in:
Jiaming Yuan 2023-11-22 21:17:48 +08:00 committed by GitHub
parent 0715ab3c10
commit 1877cb8e83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 86 deletions

View File

@@ -13,9 +13,7 @@
#include "xgboost/logging.h"
#include "xgboost/task.h" // ObjInfo
namespace xgboost {
namespace obj {
// common regressions
namespace xgboost::obj {
// linear regression
struct LinearSquareLoss {
XGBOOST_DEVICE static bst_float PredTransform(bst_float x) { return x; }
@@ -106,7 +104,21 @@ struct LogisticRaw : public LogisticRegression {
static ObjInfo Info() { return ObjInfo::kRegression; }
};
} // namespace obj
} // namespace xgboost
// gamma deviance loss.
class GammaDeviance {
 public:
  // Inverse log link: map the raw margin onto the positive half line.
  XGBOOST_DEVICE static float PredTransform(float margin) { return std::exp(margin); }
  // Log link: map a base score back into margin space.
  XGBOOST_DEVICE static float ProbToMargin(float score) { return std::log(score); }
  // First derivative w.r.t. the margin; `pred` is assumed to be the already
  // transformed prediction (exp of the margin) — confirm against RegLossObj.
  XGBOOST_DEVICE static float FirstOrderGradient(float pred, float label) {
    float const ratio = label / pred;
    return 1.0f - ratio;
  }
  // Second derivative w.r.t. the margin under the same convention as above.
  XGBOOST_DEVICE static float SecondOrderGradient(float pred, float label) { return label / pred; }
  static ObjInfo Info() { return ObjInfo::kRegression; }
  static const char* Name() { return "reg:gamma"; }
  // Deviance is the default metric since the objective itself is derived from it.
  static const char* DefaultEvalMetric() { return "gamma-deviance"; }
  // The gamma distribution is supported only on strictly positive labels.
  XGBOOST_DEVICE static bool CheckLabel(float label) { return label > 0.0f; }
  static const char* LabelErrorMsg() { return "label must be positive for gamma regression."; }
};
} // namespace xgboost::obj
#endif // XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_

View File

@@ -221,6 +221,10 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, LogisticRaw::Name())
"before logistic transformation.")
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
// Register "reg:gamma" backed by the shared RegLossObj driver with the
// GammaDeviance loss policy; defaults to the "gamma-deviance" eval metric.
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, GammaDeviance::Name())
.describe("Gamma regression using the gamma deviance loss with log link.")
.set_body([]() { return new RegLossObj<GammaDeviance>(); });
// Deprecated functions
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
.describe("Regression with squared error.")
@@ -501,87 +505,6 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
.describe("Cox regression for censored survival data (negative labels are considered censored).")
.set_body([]() { return new CoxRegression(); });
// gamma regression
// Legacy hand-rolled gamma objective (this commit removes it in favour of
// the RegLossObj<GammaDeviance> registration above). Uses a log link:
// gradient (1 - y/exp(p)) and hessian y/exp(p) w.r.t. the raw margin p.
class GammaRegression : public FitIntercept {
public:
// No hyper-parameters to configure.
void Configure(Args const&) override {}
[[nodiscard]] ObjInfo Task() const override { return ObjInfo::kRegression; }
// Compute per-row gradient pairs for the gamma negative log-likelihood.
// `preds` holds raw (untransformed) margins; weights are optional.
void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, std::int32_t,
linalg::Matrix<GradientPair>* out_gpair) override {
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size();
auto device = ctx_->Device();
out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(info.num_row_, this->Targets(info));
// Single device-visible flag, cleared by any worker that sees a bad label.
label_correct_.Resize(1);
label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
// Gamma labels must be strictly positive; flag instead of aborting
// inside device code.
if (y <= 0.0f) {
_label_correct[0] = 0;
}
// grad = (1 - y/exp(p)) * w, hess = y/exp(p) * w.
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
&label_correct_, out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "GammaRegression: label must be positive.";
}
}
}
// Inverse link: exponentiate margins in place to produce predictions.
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
},
common::Range{0, static_cast<int64_t>(io_preds->Size())}, this->ctx_->Threads(),
io_preds->Device())
.Eval(io_preds);
}
// Evaluation uses the same transform as prediction.
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
// Log link: base score is given in prediction space.
[[nodiscard]] float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
// Note: legacy default was negative log-likelihood, not deviance.
[[nodiscard]] const char* DefaultEvalMetric() const override {
return "gamma-nloglik";
}
// Only the objective name needs to be persisted; there are no parameters.
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("reg:gamma");
}
void LoadConfig(Json const&) override {}
private:
// Device-side flag: 1 while all labels seen so far are valid (positive).
HostDeviceVector<int> label_correct_;
};
// register the objective functions
// Legacy registration of "reg:gamma" with the hand-rolled GammaRegression
// class (removed by this commit).
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
.describe("Gamma regression for severity data.")
.set_body([]() { return new GammaRegression(); });
// declare parameter
struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam> {