Init estimation for regression. (#8272)
This commit is contained in:
@@ -119,7 +119,7 @@ class RabitCommunicator : public Communicator {
|
||||
}
|
||||
|
||||
// Overload selected (via enable_if) when DType is a floating point type. The body never
// touches its arguments: it unconditionally reports a fatal error because bitwise
// reductions are not supported for floating point values.
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
|
||||
@@ -684,7 +684,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
void MetaInfo::Validate(std::int32_t device) const {
|
||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||
<< "Size of weights must equal to number of groups when ranking "
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // common::MakeIndexTransformIter
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
@@ -190,6 +190,32 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
}
|
||||
return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
|
||||
}
|
||||
// sanity check
|
||||
void Validate() {
|
||||
if (!collective::IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::array<std::int32_t, 6> data;
|
||||
std::size_t pos{0};
|
||||
std::memcpy(data.data() + pos, &base_score, sizeof(base_score));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_feature, sizeof(num_feature));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_class, sizeof(num_class));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_target, sizeof(num_target));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &major_version, sizeof(major_version));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &minor_version, sizeof(minor_version));
|
||||
|
||||
std::array<std::int32_t, 6> sync;
|
||||
std::copy(data.cbegin(), data.cend(), sync.begin());
|
||||
collective::Broadcast(sync.data(), sync.size(), 0);
|
||||
CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin()))
|
||||
<< "Different model parameter across workers.";
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) {
|
||||
@@ -391,6 +417,7 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
// Update the shared model parameter
|
||||
this->ConfigureModelParamWithoutBaseScore();
|
||||
mparam_.Validate();
|
||||
}
|
||||
CHECK(!std::isnan(mparam_.base_score));
|
||||
CHECK(!std::isinf(mparam_.base_score));
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform.h"
|
||||
#include "../tree/fit_stump.h" // FitStump
|
||||
#include "./regression_loss.h"
|
||||
#include "adaptive.h"
|
||||
#include "xgboost/base.h"
|
||||
@@ -53,6 +54,31 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& pre
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class RegInitEstimation : public ObjFunction {
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const override {
|
||||
CheckInitInputs(info);
|
||||
// Avoid altering any state in child objective.
|
||||
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
|
||||
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
|
||||
|
||||
Json config{Object{}};
|
||||
this->SaveConfig(&config);
|
||||
|
||||
std::unique_ptr<ObjFunction> new_obj{
|
||||
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
|
||||
new_obj->LoadConfig(config);
|
||||
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
|
||||
bst_target_t n_targets = this->Targets(info);
|
||||
linalg::Vector<float> leaf_weight;
|
||||
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
|
||||
|
||||
// workaround, we don't support multi-target due to binary model serialization for
|
||||
// base margin.
|
||||
common::Mean(this->ctx_, leaf_weight, base_score);
|
||||
this->PredTransform(base_score->Data());
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
@@ -67,7 +93,7 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
|
||||
};
|
||||
|
||||
template<typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
class RegLossObj : public RegInitEstimation {
|
||||
protected:
|
||||
HostDeviceVector<float> additional_input_;
|
||||
|
||||
@@ -214,7 +240,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
||||
return new RegLossObj<LinearSquareLoss>(); });
|
||||
// End deprecated
|
||||
|
||||
class PseudoHuberRegression : public ObjFunction {
|
||||
class PseudoHuberRegression : public RegInitEstimation {
|
||||
PesudoHuberParam param_;
|
||||
|
||||
public:
|
||||
@@ -289,7 +315,7 @@ struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam>
|
||||
};
|
||||
|
||||
// poisson regression for count
|
||||
class PoissonRegression : public ObjFunction {
|
||||
class PoissonRegression : public RegInitEstimation {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
@@ -384,7 +410,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
|
||||
|
||||
// cox regression for survival data (negative values mean they are censored)
|
||||
class CoxRegression : public ObjFunction {
|
||||
class CoxRegression : public RegInitEstimation {
|
||||
public:
|
||||
void Configure(Args const&) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
@@ -481,7 +507,7 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
||||
.set_body([]() { return new CoxRegression(); });
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
class GammaRegression : public RegInitEstimation {
|
||||
public:
|
||||
void Configure(Args const&) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
@@ -572,7 +598,7 @@ struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam>
|
||||
};
|
||||
|
||||
// tweedie regression
|
||||
class TweedieRegression : public ObjFunction {
|
||||
class TweedieRegression : public RegInitEstimation {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
|
||||
Reference in New Issue
Block a user