Config for linear updaters. (#5222)

This commit is contained in:
Jiaming Yuan
2020-01-25 11:26:46 +08:00
committed by GitHub
parent 40680368cf
commit 3eb1279bbf
14 changed files with 110 additions and 11 deletions

View File

@@ -7,6 +7,7 @@
#include "./param.h"
#include "../common/timer.h"
#include "coordinate_common.h"
#include "xgboost/json.h"
namespace xgboost {
namespace linear {
@@ -32,6 +33,18 @@ class CoordinateUpdater : public LinearUpdater {
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
monitor_.Init("CoordinateUpdater");
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("linear_train_param"), &tparam_);
fromJson(config.at("coordinate_param"), &cparam_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["linear_train_param"] = toJson(tparam_);
out["coordinate_param"] = toJson(cparam_);
}
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
gbm::GBLinearModel *model, double sum_instance_weight) override {
tparam_.DenormalizePenalties(sum_instance_weight);

View File

@@ -37,10 +37,22 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// set training parameter
void Configure(Args const& args) override {
tparam_.UpdateAllowUnknown(args);
coord_param_.UpdateAllowUnknown(args);
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
monitor_.Init("GPUCoordinateUpdater");
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("linear_train_param"), &tparam_);
fromJson(config.at("coordinate_param"), &coord_param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["linear_train_param"] = toJson(tparam_);
out["coordinate_param"] = toJson(coord_param_);
}
void LazyInitDevice(DMatrix *p_fmat, const LearnerModelParam &model_param) {
if (learner_param_->gpu_id < 0) return;

View File

@@ -23,6 +23,15 @@ class ShotgunUpdater : public LinearUpdater {
}
selector_.reset(FeatureSelector::Create(param_.feature_selector));
}
void LoadConfig(Json const& in) override {
auto const& config = get<Object const>(in);
fromJson(config.at("linear_train_param"), &param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["linear_train_param"] = toJson(param_);
}
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
gbm::GBLinearModel *model, double sum_instance_weight) override {
auto &gpair = in_gpair->HostVector();