/**
 * Copyright 2022-2023 by XGBoost contributors
 */
#include "init_estimation.h"

#include <memory>  // unique_ptr

#include "../common/stats.h"             // Mean
#include "../tree/fit_stump.h"           // FitStump
#include "xgboost/base.h"                // GradientPair
#include "xgboost/data.h"                // MetaInfo
#include "xgboost/host_device_vector.h"  // HostDeviceVector
#include "xgboost/json.h"                // Json
#include "xgboost/linalg.h"              // Tensor,Vector
#include "xgboost/task.h"                // ObjInfo

namespace xgboost::obj {
/**
 * @brief Estimate the initial base score (intercept) for this objective.
 *
 * The estimate is produced in four steps:
 *   1. For regression tasks, the inputs are validated with CheckInitInputs.
 *   2. A fresh objective with the same configuration as this one is created and asked
 *      for the gradient of an all-zero prediction at iteration 0.
 *   3. tree::FitStump turns that gradient into a leaf weight (one entry per target,
 *      judging by the n_targets argument -- confirm against fit_stump.h).
 *   4. The leaf weight is reduced with common::Mean into `base_score`, and the result
 *      is passed through this objective's PredTransform.
 *
 * @param info       Dataset meta info; its labels determine the size and shape of the
 *                   temporary prediction and gradient buffers.
 * @param base_score Output. Receives the estimated score after PredTransform.
 */
void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const {
  if (this->Task().task == ObjInfo::kRegression) {
    CheckInitInputs(info);
  }

  // Zero-filled prediction and an output gradient buffer shaped like the labels, both
  // placed on the device selected by the context.
  HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->Device());
  linalg::Matrix<GradientPair> gpair(info.labels.Shape(), this->ctx_->Device());

  // Avoid altering any state in child objective: round-trip the configuration through
  // JSON and compute the gradient with a separate instance created from it.
  Json config{Object{}};
  this->SaveConfig(&config);

  std::unique_ptr<ObjFunction> new_obj{
      ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
  new_obj->LoadConfig(config);
  new_obj->GetGradient(dummy_predt, info, 0, &gpair);

  bst_target_t const n_targets = this->Targets(info);
  linalg::Vector<float> leaf_weight;
  tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);

  // workaround, we don't support multi-target due to binary model serialization for
  // base margin.
  common::Mean(this->ctx_, leaf_weight, base_score);
  this->PredTransform(base_score->Data());
}
}  // namespace xgboost::obj