Init estimation for regression. (#8272)

This commit is contained in:
Jiaming Yuan
2023-01-11 02:04:56 +08:00
committed by GitHub
parent 1b58d81315
commit badeff1d74
29 changed files with 466 additions and 132 deletions

View File

@@ -20,6 +20,7 @@
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "../common/transform.h"
#include "../tree/fit_stump.h" // FitStump
#include "./regression_loss.h"
#include "adaptive.h"
#include "xgboost/base.h"
@@ -53,6 +54,31 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& pre
}
} // anonymous namespace
class RegInitEstimation : public ObjFunction {
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const override {
CheckInitInputs(info);
// Avoid altering any state in child objective.
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
Json config{Object{}};
this->SaveConfig(&config);
std::unique_ptr<ObjFunction> new_obj{
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
new_obj->LoadConfig(config);
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
bst_target_t n_targets = this->Targets(info);
linalg::Vector<float> leaf_weight;
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
// workaround, we don't support multi-target due to binary model serialization for
// base margin.
common::Mean(this->ctx_, leaf_weight, base_score);
this->PredTransform(base_score->Data());
}
};
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)
@@ -67,7 +93,7 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
};
template<typename Loss>
class RegLossObj : public ObjFunction {
class RegLossObj : public RegInitEstimation {
protected:
HostDeviceVector<float> additional_input_;
@@ -214,7 +240,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
return new RegLossObj<LinearSquareLoss>(); });
// End deprecated
class PseudoHuberRegression : public ObjFunction {
class PseudoHuberRegression : public RegInitEstimation {
PesudoHuberParam param_;
public:
@@ -289,7 +315,7 @@ struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam>
};
// poisson regression for count
class PoissonRegression : public ObjFunction {
class PoissonRegression : public RegInitEstimation {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
@@ -384,7 +410,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
// cox regression for survival data (negative values mean they are censored)
class CoxRegression : public ObjFunction {
class CoxRegression : public RegInitEstimation {
public:
void Configure(Args const&) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; }
@@ -481,7 +507,7 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
.set_body([]() { return new CoxRegression(); });
// gamma regression
class GammaRegression : public ObjFunction {
class GammaRegression : public RegInitEstimation {
public:
void Configure(Args const&) override {}
ObjInfo Task() const override { return ObjInfo::kRegression; }
@@ -572,7 +598,7 @@ struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam>
};
// tweedie regression
class TweedieRegression : public ObjFunction {
class TweedieRegression : public RegInitEstimation {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {