Extract fit intercept. (#8793)

Jiaming Yuan 2023-02-15 22:41:31 +08:00 committed by GitHub
parent 594371e35b
commit c7c485d052
5 changed files with 69 additions and 42 deletions

R-package/src/Makevars.in

@@ -36,6 +36,7 @@ OBJECTS= \
   $(PKGROOT)/src/objective/hinge.o \
   $(PKGROOT)/src/objective/aft_obj.o \
   $(PKGROOT)/src/objective/adaptive.o \
+  $(PKGROOT)/src/objective/init_estimation.o \
   $(PKGROOT)/src/gbm/gbm.o \
   $(PKGROOT)/src/gbm/gbtree.o \
   $(PKGROOT)/src/gbm/gbtree_model.o \

R-package/src/Makevars.win

@@ -36,6 +36,7 @@ OBJECTS= \
   $(PKGROOT)/src/objective/hinge.o \
   $(PKGROOT)/src/objective/aft_obj.o \
   $(PKGROOT)/src/objective/adaptive.o \
+  $(PKGROOT)/src/objective/init_estimation.o \
   $(PKGROOT)/src/gbm/gbm.o \
   $(PKGROOT)/src/gbm/gbtree.o \
   $(PKGROOT)/src/gbm/gbtree_model.o \

src/objective/init_estimation.cc (new file)

@@ -0,0 +1,39 @@
#include "init_estimation.h"

#include "../common/stats.h"             // Mean
#include "../tree/fit_stump.h"           // FitStump
#include "xgboost/base.h"                // GradientPair
#include "xgboost/data.h"                // MetaInfo
#include "xgboost/host_device_vector.h"  // HostDeviceVector
#include "xgboost/json.h"                // Json
#include "xgboost/linalg.h"              // Tensor,Vector
#include "xgboost/task.h"                // ObjInfo

namespace xgboost {
namespace obj {
void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const {
  if (this->Task().task == ObjInfo::kRegression) {
    CheckInitInputs(info);
  }
  // Avoid altering any state in child objective.
  HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
  HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);

  Json config{Object{}};
  this->SaveConfig(&config);

  std::unique_ptr<ObjFunction> new_obj{
      ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
  new_obj->LoadConfig(config);
  new_obj->GetGradient(dummy_predt, info, 0, &gpair);
  bst_target_t n_targets = this->Targets(info);
  linalg::Vector<float> leaf_weight;
  tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);

  // workaround, we don't support multi-target due to binary model serialization for
  // base margin.
  common::Mean(this->ctx_, leaf_weight, base_score);
  this->PredTransform(base_score->Data());
}
}  // namespace obj
}  // namespace xgboost
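
For context on what this new file computes: with all predictions fixed at zero, the gradients and Hessians of a freshly cloned copy of the objective determine one Newton-step leaf weight per target, roughly -sum(grad)/sum(hess) (which is essentially what tree::FitStump produces), and the result is then averaged across targets and passed through PredTransform. The standalone sketch below is not xgboost code; GradientPairLike, SquaredErrorGradAtZero, and FitStumpLike are simplified stand-ins used only to show the arithmetic. For squared error at zero predictions the estimate reduces to the label mean.

// Standalone illustration (not xgboost code) of the intercept estimate above:
// a single Newton step -sum(grad) / sum(hess) with all predictions fixed at zero.
// For squared error (grad = pred - label, hess = 1) this is simply the label mean.
#include <cstddef>
#include <iostream>
#include <vector>

struct GradientPairLike {  // simplified stand-in for xgboost::GradientPair
  float grad{0.0f};
  float hess{0.0f};
};

// Stand-in for the objective's GetGradient() for squared error at zero predictions.
std::vector<GradientPairLike> SquaredErrorGradAtZero(std::vector<float> const& labels) {
  std::vector<GradientPairLike> gpair(labels.size());
  for (std::size_t i = 0; i < labels.size(); ++i) {
    gpair[i] = GradientPairLike{0.0f - labels[i], 1.0f};  // grad = pred - label, hess = 1
  }
  return gpair;
}

// Stand-in for tree::FitStump with a single target: one Newton step on the sums.
float FitStumpLike(std::vector<GradientPairLike> const& gpair) {
  double sum_grad = 0.0, sum_hess = 0.0;
  for (auto const& gp : gpair) {
    sum_grad += gp.grad;
    sum_hess += gp.hess;
  }
  return static_cast<float>(-sum_grad / sum_hess);
}

int main() {
  std::vector<float> labels{1.0f, 2.0f, 3.0f, 6.0f};
  float base_score = FitStumpLike(SquaredErrorGradAtZero(labels));
  std::cout << "estimated intercept: " << base_score << "\n";  // prints 3, the label mean
  return 0;
}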

src/objective/init_estimation.h (new file)

@@ -0,0 +1,19 @@
#include "xgboost/data.h"       // MetaInfo
#include "xgboost/linalg.h"     // Tensor
#include "xgboost/objective.h"  // ObjFunction

namespace xgboost {
namespace obj {
class FitIntercept : public ObjFunction {
  void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override;
};

inline void CheckInitInputs(MetaInfo const& info) {
  CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
  if (!info.weights_.Empty()) {
    CHECK_EQ(info.weights_.Size(), info.num_row_)
        << "Number of weights should be equal to number of data points.";
  }
}
}  // namespace obj
}  // namespace xgboost

src/objective/regression_obj.cu

@@ -20,12 +20,12 @@
 #include "../common/stats.h"
 #include "../common/threading_utils.h"
 #include "../common/transform.h"
-#include "../tree/fit_stump.h"  // FitStump
 #include "./regression_loss.h"
 #include "adaptive.h"
+#include "init_estimation.h"  // FitIntercept
 #include "xgboost/base.h"
-#include "xgboost/context.h"
+#include "xgboost/context.h"  // Context
 #include "xgboost/data.h"  // MetaInfo
 #include "xgboost/host_device_vector.h"
 #include "xgboost/json.h"
 #include "xgboost/linalg.h"
@@ -43,45 +43,12 @@
 namespace xgboost {
 namespace obj {
 namespace {
-void CheckInitInputs(MetaInfo const& info) {
-  CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
-  if (!info.weights_.Empty()) {
-    CHECK_EQ(info.weights_.Size(), info.num_row_)
-        << "Number of weights should be equal to number of data points.";
-  }
-}
 void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
   CheckInitInputs(info);
   CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
 }
 }  // anonymous namespace
-
-class RegInitEstimation : public ObjFunction {
-  void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const override {
-    CheckInitInputs(info);
-    // Avoid altering any state in child objective.
-    HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
-    HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
-
-    Json config{Object{}};
-    this->SaveConfig(&config);
-
-    std::unique_ptr<ObjFunction> new_obj{
-        ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
-    new_obj->LoadConfig(config);
-    new_obj->GetGradient(dummy_predt, info, 0, &gpair);
-    bst_target_t n_targets = this->Targets(info);
-    linalg::Vector<float> leaf_weight;
-    tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
-
-    // workaround, we don't support multi-target due to binary model serialization for
-    // base margin.
-    common::Mean(this->ctx_, leaf_weight, base_score);
-    this->PredTransform(base_score->Data());
-  }
-};
-
 #if defined(XGBOOST_USE_CUDA)
 DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
 #endif  // defined(XGBOOST_USE_CUDA)
@@ -96,7 +63,7 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
 };
 
 template<typename Loss>
-class RegLossObj : public RegInitEstimation {
+class RegLossObj : public FitIntercept {
  protected:
   HostDeviceVector<float> additional_input_;
@@ -243,7 +210,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
       return new RegLossObj<LinearSquareLoss>(); });
 // End deprecated
 
-class PseudoHuberRegression : public RegInitEstimation {
+class PseudoHuberRegression : public FitIntercept {
   PesudoHuberParam param_;
 
  public:
@@ -318,7 +285,7 @@ struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam>
 };
 
 // poisson regression for count
-class PoissonRegression : public RegInitEstimation {
+class PoissonRegression : public FitIntercept {
  public:
   // declare functions
   void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
@@ -413,7 +380,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
 
 // cox regression for survival data (negative values mean they are censored)
-class CoxRegression : public RegInitEstimation {
+class CoxRegression : public FitIntercept {
  public:
   void Configure(Args const&) override {}
   ObjInfo Task() const override { return ObjInfo::kRegression; }
@@ -510,7 +477,7 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
     .set_body([]() { return new CoxRegression(); });
 
 // gamma regression
-class GammaRegression : public RegInitEstimation {
+class GammaRegression : public FitIntercept {
  public:
   void Configure(Args const&) override {}
   ObjInfo Task() const override { return ObjInfo::kRegression; }
@@ -601,7 +568,7 @@ struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam>
 };
 
 // tweedie regression
-class TweedieRegression : public RegInitEstimation {
+class TweedieRegression : public FitIntercept {
  public:
   // declare functions
   void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
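
The hunks above only swap each objective's base class from the removed RegInitEstimation to the shared FitIntercept, so RegLossObj, PseudoHuberRegression, PoissonRegression, CoxRegression, GammaRegression, and TweedieRegression all inherit the same default base_score estimation instead of the regression translation unit carrying its own copy. The mock below is a self-contained sketch of that mixin pattern under simplified, hypothetical names and signatures (MockObjFunction, MockFitIntercept, MockSquaredError); the real FitIntercept additionally averages across targets and applies PredTransform, as shown in init_estimation.cc above.

// Self-contained mock (hypothetical names, simplified signatures) of the mixin
// pattern introduced by this commit: a shared base class supplies the default
// InitEstimation(), and concrete objectives only implement their loss.
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

class MockObjFunction {  // stand-in for xgboost::ObjFunction
 public:
  virtual ~MockObjFunction() = default;
  // Gradient and Hessian of the loss at the given predictions.
  virtual void GetGradient(std::vector<float> const& predt, std::vector<float> const& labels,
                           std::vector<float>* grad, std::vector<float>* hess) const = 0;
  virtual float InitEstimation(std::vector<float> const& labels) const = 0;
};

// Stand-in for FitIntercept: one Newton step with all predictions fixed at zero.
class MockFitIntercept : public MockObjFunction {
 public:
  float InitEstimation(std::vector<float> const& labels) const override {
    std::vector<float> predt(labels.size(), 0.0f);
    std::vector<float> grad, hess;
    this->GetGradient(predt, labels, &grad, &hess);
    double g = std::accumulate(grad.begin(), grad.end(), 0.0);
    double h = std::accumulate(hess.begin(), hess.end(), 0.0);
    return static_cast<float>(-g / h);
  }
};

// A concrete objective only implements its loss; the intercept logic is inherited,
// mirroring how the regression objectives above now derive from FitIntercept.
class MockSquaredError : public MockFitIntercept {
 public:
  void GetGradient(std::vector<float> const& predt, std::vector<float> const& labels,
                   std::vector<float>* grad, std::vector<float>* hess) const override {
    grad->resize(labels.size());
    hess->assign(labels.size(), 1.0f);
    for (std::size_t i = 0; i < labels.size(); ++i) {
      (*grad)[i] = predt[i] - labels[i];
    }
  }
};

int main() {
  MockSquaredError obj;
  std::cout << obj.InitEstimation({0.5f, 1.5f, 2.5f, 3.5f}) << "\n";  // prints 2, the label mean
  return 0;
}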