Model IO in JSON. (#5110)

This commit is contained in:
Jiaming Yuan
2019-12-11 11:20:40 +08:00
committed by GitHub
parent c7cc657a4d
commit 208ab3b1ff
25 changed files with 667 additions and 165 deletions

View File

@@ -266,14 +266,61 @@ class LearnerImpl : public Learner {
}
}
void LoadModel(dmlc::Stream* fi) override {
// They are the same right now until we can split up the saved parameter from model.
this->Load(fi);
void LoadModel(Json const& in) override {
CHECK(IsA<Object>(in));
Version::Load(in, false);
auto const& learner = get<Object>(in["Learner"]);
mparam_.FromJson(learner.at("learner_model_param"));
auto const& objective_fn = learner.at("objective");
std::string name = get<String>(objective_fn["name"]);
tparam_.UpdateAllowUnknown(Args{{"objective", name}});
obj_.reset(ObjFunction::Create(name, &generic_parameters_));
obj_->LoadConfig(objective_fn);
auto const& gradient_booster = learner.at("gradient_booster");
name = get<String>(gradient_booster["name"]);
tparam_.UpdateAllowUnknown(Args{{"booster", name}});
gbm_.reset(GradientBooster::Create(tparam_.booster,
&generic_parameters_, &learner_model_param_,
cache_));
gbm_->LoadModel(gradient_booster);
learner_model_param_ = LearnerModelParam(mparam_,
obj_->ProbToMargin(mparam_.base_score));
auto const& j_attributes = get<Object const>(learner.at("attributes"));
attributes_.clear();
for (auto const& kv : j_attributes) {
attributes_[kv.first] = get<String const>(kv.second);
}
this->need_configuration_ = true;
}
void SaveModel(dmlc::Stream* fo) const override {
// They are the same right now until we can split up the saved parameter from model.
this->Save(fo);
void SaveModel(Json* p_out) const override {
CHECK(!this->need_configuration_) << "Call Configure before saving model.";
Version::Save(p_out);
Json& out { *p_out };
out["Learner"] = Object();
auto& learner = out["Learner"];
learner["learner_model_param"] = mparam_.ToJson();
learner["gradient_booster"] = Object();
auto& gradient_booster = learner["gradient_booster"];
gbm_->SaveModel(&gradient_booster);
learner["objective"] = Object();
auto& objective_fn = learner["objective"];
obj_->SaveConfig(&objective_fn);
learner["attributes"] = Object();
for (auto const& kv : attributes_) {
learner["attributes"][kv.first] = String(kv.second);
}
}
void Load(dmlc::Stream* fi) override {
@@ -747,7 +794,6 @@ class LearnerImpl : public Learner {
LearnerTrainParam tparam_;
// configurations
std::map<std::string, std::string> cfg_;
// FIXME(trivialfis): Legacy field used to store extra attributes into binary model.
std::map<std::string, std::string> attributes_;
std::vector<std::string> metric_names_;
static std::string const kEvalMetric; // NOLINT