Model IO in JSON. (#5110)
This commit is contained in:
@@ -266,14 +266,61 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
}
|
||||
|
||||
void LoadModel(dmlc::Stream* fi) override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Load(fi);
|
||||
void LoadModel(Json const& in) override {
|
||||
CHECK(IsA<Object>(in));
|
||||
Version::Load(in, false);
|
||||
auto const& learner = get<Object>(in["Learner"]);
|
||||
mparam_.FromJson(learner.at("learner_model_param"));
|
||||
|
||||
auto const& objective_fn = learner.at("objective");
|
||||
|
||||
std::string name = get<String>(objective_fn["name"]);
|
||||
tparam_.UpdateAllowUnknown(Args{{"objective", name}});
|
||||
obj_.reset(ObjFunction::Create(name, &generic_parameters_));
|
||||
obj_->LoadConfig(objective_fn);
|
||||
|
||||
auto const& gradient_booster = learner.at("gradient_booster");
|
||||
name = get<String>(gradient_booster["name"]);
|
||||
tparam_.UpdateAllowUnknown(Args{{"booster", name}});
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster,
|
||||
&generic_parameters_, &learner_model_param_,
|
||||
cache_));
|
||||
gbm_->LoadModel(gradient_booster);
|
||||
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
|
||||
auto const& j_attributes = get<Object const>(learner.at("attributes"));
|
||||
attributes_.clear();
|
||||
for (auto const& kv : j_attributes) {
|
||||
attributes_[kv.first] = get<String const>(kv.second);
|
||||
}
|
||||
|
||||
this->need_configuration_ = true;
|
||||
}
|
||||
|
||||
void SaveModel(dmlc::Stream* fo) const override {
|
||||
// They are the same right now until we can split up the saved parameter from model.
|
||||
this->Save(fo);
|
||||
void SaveModel(Json* p_out) const override {
|
||||
CHECK(!this->need_configuration_) << "Call Configure before saving model.";
|
||||
|
||||
Version::Save(p_out);
|
||||
Json& out { *p_out };
|
||||
|
||||
out["Learner"] = Object();
|
||||
auto& learner = out["Learner"];
|
||||
|
||||
learner["learner_model_param"] = mparam_.ToJson();
|
||||
learner["gradient_booster"] = Object();
|
||||
auto& gradient_booster = learner["gradient_booster"];
|
||||
gbm_->SaveModel(&gradient_booster);
|
||||
|
||||
learner["objective"] = Object();
|
||||
auto& objective_fn = learner["objective"];
|
||||
obj_->SaveConfig(&objective_fn);
|
||||
|
||||
learner["attributes"] = Object();
|
||||
for (auto const& kv : attributes_) {
|
||||
learner["attributes"][kv.first] = String(kv.second);
|
||||
}
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
@@ -747,7 +794,6 @@ class LearnerImpl : public Learner {
|
||||
LearnerTrainParam tparam_;
|
||||
// configurations
|
||||
std::map<std::string, std::string> cfg_;
|
||||
// FIXME(trivialfis): Legacy field used to store extra attributes into binary model.
|
||||
std::map<std::string, std::string> attributes_;
|
||||
std::vector<std::string> metric_names_;
|
||||
static std::string const kEvalMetric; // NOLINT
|
||||
|
||||
Reference in New Issue
Block a user