Fix model saving for 'count:poisson': max_delta_step as Booster attribute (#3515)

* Save max_delta_step as an extra attribute of Booster

Fixes #3509 and #3026, where `max_delta_step` parameter gets lost during serialization.

* fix lint

* Use camel case for global constant

* disable local variable case in clang-tidy
This commit is contained in:
Philip Hyunsu Cho 2018-07-27 09:55:54 -07:00 committed by GitHub
parent cc6a5a3666
commit 8a5209c55e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 8 deletions

View File

@ -5,7 +5,6 @@ CheckOptions:
- { key: readability-identifier-naming.TypeAliasCase, value: CamelCase } - { key: readability-identifier-naming.TypeAliasCase, value: CamelCase }
- { key: readability-identifier-naming.TypedefCase, value: CamelCase } - { key: readability-identifier-naming.TypedefCase, value: CamelCase }
- { key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase } - { key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase }
- { key: readability-identifier-naming.LocalVariableCase, value: lower_case }
- { key: readability-identifier-naming.MemberCase, value: lower_case } - { key: readability-identifier-naming.MemberCase, value: lower_case }
- { key: readability-identifier-naming.PrivateMemberSuffix, value: '_' } - { key: readability-identifier-naming.PrivateMemberSuffix, value: '_' }
- { key: readability-identifier-naming.ProtectedMemberSuffix, value: '_' } - { key: readability-identifier-naming.ProtectedMemberSuffix, value: '_' }

View File

@ -21,6 +21,11 @@
#include "./common/random.h" #include "./common/random.h"
#include "common/timer.h" #include "common/timer.h"
namespace {
// Default value of the `max_delta_step` parameter, kept in string form.
// It is applied when the objective is "count:poisson" and the user did not
// set `max_delta_step`, and it is the fallback written while saving a model
// when neither the configuration nor the booster attributes carry a value.
// Both the characters and the pointer are const so the default cannot be
// rebound at runtime.
const char* const kMaxDeltaStepDefaultValue = "0.7";
}  // anonymous namespace
namespace xgboost { namespace xgboost {
// implementation of base learner. // implementation of base learner.
@ -222,7 +227,7 @@ class LearnerImpl : public Learner {
if (cfg_.count("max_delta_step") == 0 && cfg_.count("objective") != 0 && if (cfg_.count("max_delta_step") == 0 && cfg_.count("objective") != 0 &&
cfg_["objective"] == "count:poisson") { cfg_["objective"] == "count:poisson") {
cfg_["max_delta_step"] = "0.7"; cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
} }
ConfigureUpdaters(); ConfigureUpdaters();
@ -321,21 +326,41 @@ class LearnerImpl : public Learner {
// rabit save model to rabit checkpoint // rabit save model to rabit checkpoint
void Save(dmlc::Stream* fo) const override { void Save(dmlc::Stream* fo) const override {
fo->Write(&mparam_, sizeof(LearnerModelParam)); LearnerModelParam mparam = mparam_; // make a copy to potentially modify
std::vector<std::pair<std::string, std::string> > extra_attr;
// extra attributed to be added just before saving
if (name_obj_ == "count:poisson") {
auto it = cfg_.find("max_delta_step");
if (it != cfg_.end()) {
// write `max_delta_step` parameter as extra attribute of booster
mparam.contain_extra_attrs = 1;
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
}
}
fo->Write(&mparam, sizeof(LearnerModelParam));
fo->Write(name_obj_); fo->Write(name_obj_);
fo->Write(name_gbm_); fo->Write(name_gbm_);
gbm_->Save(fo); gbm_->Save(fo);
if (mparam_.contain_extra_attrs != 0) { if (mparam.contain_extra_attrs != 0) {
std::vector<std::pair<std::string, std::string> > attr( std::vector<std::pair<std::string, std::string> > attr(
attributes_.begin(), attributes_.end()); attributes_.begin(), attributes_.end());
attr.insert(attr.end(), extra_attr.begin(), extra_attr.end());
fo->Write(attr); fo->Write(attr);
} }
if (name_obj_ == "count:poisson") { if (name_obj_ == "count:poisson") {
auto it = auto it = cfg_.find("max_delta_step");
cfg_.find("max_delta_step"); if (it != cfg_.end()) {
if (it != cfg_.end()) fo->Write(it->second); fo->Write(it->second);
} else {
// recover value of max_delta_step from extra attributes
auto it2 = attributes_.find("count_poisson_max_delta_step");
const std::string max_delta_step
= (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue;
fo->Write(max_delta_step);
}
} }
if (mparam_.contain_eval_metrics != 0) { if (mparam.contain_eval_metrics != 0) {
std::vector<std::string> metr; std::vector<std::string> metr;
for (auto& ev : metrics_) { for (auto& ev : metrics_) {
metr.emplace_back(ev->Name()); metr.emplace_back(ev->Name());