Extract fit intercept. (#8793)
This commit is contained in:
parent
594371e35b
commit
c7c485d052
@ -36,6 +36,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/objective/hinge.o \
|
$(PKGROOT)/src/objective/hinge.o \
|
||||||
$(PKGROOT)/src/objective/aft_obj.o \
|
$(PKGROOT)/src/objective/aft_obj.o \
|
||||||
$(PKGROOT)/src/objective/adaptive.o \
|
$(PKGROOT)/src/objective/adaptive.o \
|
||||||
|
$(PKGROOT)/src/objective/init_estimation.o \
|
||||||
$(PKGROOT)/src/gbm/gbm.o \
|
$(PKGROOT)/src/gbm/gbm.o \
|
||||||
$(PKGROOT)/src/gbm/gbtree.o \
|
$(PKGROOT)/src/gbm/gbtree.o \
|
||||||
$(PKGROOT)/src/gbm/gbtree_model.o \
|
$(PKGROOT)/src/gbm/gbtree_model.o \
|
||||||
|
|||||||
@ -36,6 +36,7 @@ OBJECTS= \
|
|||||||
$(PKGROOT)/src/objective/hinge.o \
|
$(PKGROOT)/src/objective/hinge.o \
|
||||||
$(PKGROOT)/src/objective/aft_obj.o \
|
$(PKGROOT)/src/objective/aft_obj.o \
|
||||||
$(PKGROOT)/src/objective/adaptive.o \
|
$(PKGROOT)/src/objective/adaptive.o \
|
||||||
|
$(PKGROOT)/src/objective/init_estimation.o \
|
||||||
$(PKGROOT)/src/gbm/gbm.o \
|
$(PKGROOT)/src/gbm/gbm.o \
|
||||||
$(PKGROOT)/src/gbm/gbtree.o \
|
$(PKGROOT)/src/gbm/gbtree.o \
|
||||||
$(PKGROOT)/src/gbm/gbtree_model.o \
|
$(PKGROOT)/src/gbm/gbtree_model.o \
|
||||||
|
|||||||
39
src/objective/init_estimation.cc
Normal file
39
src/objective/init_estimation.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
#include "init_estimation.h"
|
||||||
|
|
||||||
|
#include "../common/stats.h" // Mean
|
||||||
|
#include "../tree/fit_stump.h" // FitStump
|
||||||
|
#include "xgboost/base.h" // GradientPair
|
||||||
|
#include "xgboost/data.h" // MetaInfo
|
||||||
|
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||||
|
#include "xgboost/json.h" // Json
|
||||||
|
#include "xgboost/linalg.h" // Tensor,Vector
|
||||||
|
#include "xgboost/task.h" // ObjInfo
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace obj {
|
||||||
|
void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const {
|
||||||
|
if (this->Task().task == ObjInfo::kRegression) {
|
||||||
|
CheckInitInputs(info);
|
||||||
|
}
|
||||||
|
// Avoid altering any state in child objective.
|
||||||
|
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
|
||||||
|
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
|
||||||
|
|
||||||
|
Json config{Object{}};
|
||||||
|
this->SaveConfig(&config);
|
||||||
|
|
||||||
|
std::unique_ptr<ObjFunction> new_obj{
|
||||||
|
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
|
||||||
|
new_obj->LoadConfig(config);
|
||||||
|
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
|
||||||
|
bst_target_t n_targets = this->Targets(info);
|
||||||
|
linalg::Vector<float> leaf_weight;
|
||||||
|
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
|
||||||
|
|
||||||
|
// workaround, we don't support multi-target due to binary model serialization for
|
||||||
|
// base margin.
|
||||||
|
common::Mean(this->ctx_, leaf_weight, base_score);
|
||||||
|
this->PredTransform(base_score->Data());
|
||||||
|
}
|
||||||
|
} // namespace obj
|
||||||
|
} // namespace xgboost
|
||||||
19
src/objective/init_estimation.h
Normal file
19
src/objective/init_estimation.h
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
#include "xgboost/data.h" // MetaInfo
|
||||||
|
#include "xgboost/linalg.h" // Tensor
|
||||||
|
#include "xgboost/objective.h" // ObjFunction
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
namespace obj {
|
||||||
|
class FitIntercept : public ObjFunction {
|
||||||
|
void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
inline void CheckInitInputs(MetaInfo const& info) {
|
||||||
|
CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
|
||||||
|
if (!info.weights_.Empty()) {
|
||||||
|
CHECK_EQ(info.weights_.Size(), info.num_row_)
|
||||||
|
<< "Number of weights should be equal to number of data points.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace obj
|
||||||
|
} // namespace xgboost
|
||||||
@ -20,11 +20,11 @@
|
|||||||
#include "../common/stats.h"
|
#include "../common/stats.h"
|
||||||
#include "../common/threading_utils.h"
|
#include "../common/threading_utils.h"
|
||||||
#include "../common/transform.h"
|
#include "../common/transform.h"
|
||||||
#include "../tree/fit_stump.h" // FitStump
|
|
||||||
#include "./regression_loss.h"
|
#include "./regression_loss.h"
|
||||||
#include "adaptive.h"
|
#include "adaptive.h"
|
||||||
|
#include "init_estimation.h" // FitIntercept
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/context.h"
|
#include "xgboost/context.h" // Context
|
||||||
#include "xgboost/data.h" // MetaInfo
|
#include "xgboost/data.h" // MetaInfo
|
||||||
#include "xgboost/host_device_vector.h"
|
#include "xgboost/host_device_vector.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
@ -43,45 +43,12 @@
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace obj {
|
namespace obj {
|
||||||
namespace {
|
namespace {
|
||||||
void CheckInitInputs(MetaInfo const& info) {
|
|
||||||
CHECK_EQ(info.labels.Shape(0), info.num_row_) << "Invalid shape of labels.";
|
|
||||||
if (!info.weights_.Empty()) {
|
|
||||||
CHECK_EQ(info.weights_.Size(), info.num_row_)
|
|
||||||
<< "Number of weights should be equal to number of data points.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
|
void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& preds) {
|
||||||
CheckInitInputs(info);
|
CheckInitInputs(info);
|
||||||
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
|
CHECK_EQ(info.labels.Size(), preds.Size()) << "Invalid shape of labels.";
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
class RegInitEstimation : public ObjFunction {
|
|
||||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const override {
|
|
||||||
CheckInitInputs(info);
|
|
||||||
// Avoid altering any state in child objective.
|
|
||||||
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
|
|
||||||
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
|
|
||||||
|
|
||||||
Json config{Object{}};
|
|
||||||
this->SaveConfig(&config);
|
|
||||||
|
|
||||||
std::unique_ptr<ObjFunction> new_obj{
|
|
||||||
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
|
|
||||||
new_obj->LoadConfig(config);
|
|
||||||
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
|
|
||||||
bst_target_t n_targets = this->Targets(info);
|
|
||||||
linalg::Vector<float> leaf_weight;
|
|
||||||
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
|
|
||||||
|
|
||||||
// workaround, we don't support multi-target due to binary model serialization for
|
|
||||||
// base margin.
|
|
||||||
common::Mean(this->ctx_, leaf_weight, base_score);
|
|
||||||
this->PredTransform(base_score->Data());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
@ -96,7 +63,7 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template<typename Loss>
|
template<typename Loss>
|
||||||
class RegLossObj : public RegInitEstimation {
|
class RegLossObj : public FitIntercept {
|
||||||
protected:
|
protected:
|
||||||
HostDeviceVector<float> additional_input_;
|
HostDeviceVector<float> additional_input_;
|
||||||
|
|
||||||
@ -243,7 +210,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
|||||||
return new RegLossObj<LinearSquareLoss>(); });
|
return new RegLossObj<LinearSquareLoss>(); });
|
||||||
// End deprecated
|
// End deprecated
|
||||||
|
|
||||||
class PseudoHuberRegression : public RegInitEstimation {
|
class PseudoHuberRegression : public FitIntercept {
|
||||||
PesudoHuberParam param_;
|
PesudoHuberParam param_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -318,7 +285,7 @@ struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam>
|
|||||||
};
|
};
|
||||||
|
|
||||||
// poisson regression for count
|
// poisson regression for count
|
||||||
class PoissonRegression : public RegInitEstimation {
|
class PoissonRegression : public FitIntercept {
|
||||||
public:
|
public:
|
||||||
// declare functions
|
// declare functions
|
||||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
@ -413,7 +380,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
|||||||
|
|
||||||
|
|
||||||
// cox regression for survival data (negative values mean they are censored)
|
// cox regression for survival data (negative values mean they are censored)
|
||||||
class CoxRegression : public RegInitEstimation {
|
class CoxRegression : public FitIntercept {
|
||||||
public:
|
public:
|
||||||
void Configure(Args const&) override {}
|
void Configure(Args const&) override {}
|
||||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||||
@ -510,7 +477,7 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
|||||||
.set_body([]() { return new CoxRegression(); });
|
.set_body([]() { return new CoxRegression(); });
|
||||||
|
|
||||||
// gamma regression
|
// gamma regression
|
||||||
class GammaRegression : public RegInitEstimation {
|
class GammaRegression : public FitIntercept {
|
||||||
public:
|
public:
|
||||||
void Configure(Args const&) override {}
|
void Configure(Args const&) override {}
|
||||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||||
@ -601,7 +568,7 @@ struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam>
|
|||||||
};
|
};
|
||||||
|
|
||||||
// tweedie regression
|
// tweedie regression
|
||||||
class TweedieRegression : public RegInitEstimation {
|
class TweedieRegression : public FitIntercept {
|
||||||
public:
|
public:
|
||||||
// declare functions
|
// declare functions
|
||||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user