From 8a5209c55eabf80264f3036227990b237f51513a Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 27 Jul 2018 09:55:54 -0700 Subject: [PATCH] Fix model saving for 'count:poisson': max_delta_step as Booster attribute (#3515) * Save max_delta_step as an extra attribute of Booster Fixes #3509 and #3026, where `max_delta_step` parameter gets lost during serialization. * fix lint * Use camel case for global constant * disable local variable case in clang-tidy --- .clang-tidy | 1 - src/learner.cc | 39 ++++++++++++++++++++++++++++++++------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/.clang-tidy b/.clang-tidy index 6000c08a6..959b0a438 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -5,7 +5,6 @@ CheckOptions: - { key: readability-identifier-naming.TypeAliasCase, value: CamelCase } - { key: readability-identifier-naming.TypedefCase, value: CamelCase } - { key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase } - - { key: readability-identifier-naming.LocalVariableCase, value: lower_case } - { key: readability-identifier-naming.MemberCase, value: lower_case } - { key: readability-identifier-naming.PrivateMemberSuffix, value: '_' } - { key: readability-identifier-naming.ProtectedMemberSuffix, value: '_' } diff --git a/src/learner.cc b/src/learner.cc index abe7c6d89..57c361b6f 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -21,6 +21,11 @@ #include "./common/random.h" #include "common/timer.h" +namespace { + +const char* kMaxDeltaStepDefaultValue = "0.7"; + +} // anonymous namespace namespace xgboost { // implementation of base learner. 
@@ -222,7 +227,7 @@ class LearnerImpl : public Learner { if (cfg_.count("max_delta_step") == 0 && cfg_.count("objective") != 0 && cfg_["objective"] == "count:poisson") { - cfg_["max_delta_step"] = "0.7"; + cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue; } ConfigureUpdaters(); @@ -321,21 +326,41 @@ class LearnerImpl : public Learner { // rabit save model to rabit checkpoint void Save(dmlc::Stream* fo) const override { - fo->Write(&mparam_, sizeof(LearnerModelParam)); + LearnerModelParam mparam = mparam_; // make a copy to potentially modify + std::vector > extra_attr; + // extra attributed to be added just before saving + + if (name_obj_ == "count:poisson") { + auto it = cfg_.find("max_delta_step"); + if (it != cfg_.end()) { + // write `max_delta_step` parameter as extra attribute of booster + mparam.contain_extra_attrs = 1; + extra_attr.emplace_back("count_poisson_max_delta_step", it->second); + } + } + fo->Write(&mparam, sizeof(LearnerModelParam)); fo->Write(name_obj_); fo->Write(name_gbm_); gbm_->Save(fo); - if (mparam_.contain_extra_attrs != 0) { + if (mparam.contain_extra_attrs != 0) { std::vector > attr( attributes_.begin(), attributes_.end()); + attr.insert(attr.end(), extra_attr.begin(), extra_attr.end()); fo->Write(attr); } if (name_obj_ == "count:poisson") { - auto it = - cfg_.find("max_delta_step"); - if (it != cfg_.end()) fo->Write(it->second); + auto it = cfg_.find("max_delta_step"); + if (it != cfg_.end()) { + fo->Write(it->second); + } else { + // recover value of max_delta_step from extra attributes + auto it2 = attributes_.find("count_poisson_max_delta_step"); + const std::string max_delta_step + = (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue; + fo->Write(max_delta_step); + } } - if (mparam_.contain_eval_metrics != 0) { + if (mparam.contain_eval_metrics != 0) { std::vector metr; for (auto& ev : metrics_) { metr.emplace_back(ev->Name());