Config for linear updaters. (#5222)
This commit is contained in:
@@ -102,11 +102,18 @@ class GBLinear : public GradientBooster {
|
||||
void LoadConfig(Json const& in) override {
|
||||
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
||||
fromJson(in["gblinear_train_param"], ¶m_);
|
||||
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
||||
this->updater_->LoadConfig(in["updater"]);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String{"gblinear"};
|
||||
out["gblinear_train_param"] = toJson(param_);
|
||||
|
||||
out["updater"] = Object();
|
||||
auto& j_updater = out["updater"];
|
||||
CHECK(this->updater_);
|
||||
this->updater_->SaveConfig(&j_updater);
|
||||
}
|
||||
|
||||
void DoBoost(DMatrix *p_fmat,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "./param.h"
|
||||
#include "../common/timer.h"
|
||||
#include "coordinate_common.h"
|
||||
#include "xgboost/json.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace linear {
|
||||
@@ -32,6 +33,18 @@ class CoordinateUpdater : public LinearUpdater {
|
||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||
monitor_.Init("CoordinateUpdater");
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("linear_train_param"), &tparam_);
|
||||
fromJson(config.at("coordinate_param"), &cparam_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["linear_train_param"] = toJson(tparam_);
|
||||
out["coordinate_param"] = toJson(cparam_);
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
tparam_.DenormalizePenalties(sum_instance_weight);
|
||||
|
||||
@@ -37,10 +37,22 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
// set training parameter
|
||||
void Configure(Args const& args) override {
|
||||
tparam_.UpdateAllowUnknown(args);
|
||||
coord_param_.UpdateAllowUnknown(args);
|
||||
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
|
||||
monitor_.Init("GPUCoordinateUpdater");
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("linear_train_param"), &tparam_);
|
||||
fromJson(config.at("coordinate_param"), &coord_param_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["linear_train_param"] = toJson(tparam_);
|
||||
out["coordinate_param"] = toJson(coord_param_);
|
||||
}
|
||||
|
||||
void LazyInitDevice(DMatrix *p_fmat, const LearnerModelParam &model_param) {
|
||||
if (learner_param_->gpu_id < 0) return;
|
||||
|
||||
|
||||
@@ -23,6 +23,15 @@ class ShotgunUpdater : public LinearUpdater {
|
||||
}
|
||||
selector_.reset(FeatureSelector::Create(param_.feature_selector));
|
||||
}
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
fromJson(config.at("linear_train_param"), ¶m_);
|
||||
}
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["linear_train_param"] = toJson(param_);
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
auto &gpair = in_gpair->HostVector();
|
||||
|
||||
Reference in New Issue
Block a user