Fix model saving for 'count:possion': max_delta_step as Booster attribute (#3515)
* Save max_delta_step as an extra attribute of Booster Fixes #3509 and #3026, where `max_delta_step` parameter gets lost during serialization. * fix lint * Use camel case for global constant * disable local variable case in clang-tidy
This commit is contained in:
parent
cc6a5a3666
commit
8a5209c55e
@ -5,7 +5,6 @@ CheckOptions:
|
|||||||
- { key: readability-identifier-naming.TypeAliasCase, value: CamelCase }
|
- { key: readability-identifier-naming.TypeAliasCase, value: CamelCase }
|
||||||
- { key: readability-identifier-naming.TypedefCase, value: CamelCase }
|
- { key: readability-identifier-naming.TypedefCase, value: CamelCase }
|
||||||
- { key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase }
|
- { key: readability-identifier-naming.TypeTemplateParameterCase, value: CamelCase }
|
||||||
- { key: readability-identifier-naming.LocalVariableCase, value: lower_case }
|
|
||||||
- { key: readability-identifier-naming.MemberCase, value: lower_case }
|
- { key: readability-identifier-naming.MemberCase, value: lower_case }
|
||||||
- { key: readability-identifier-naming.PrivateMemberSuffix, value: '_' }
|
- { key: readability-identifier-naming.PrivateMemberSuffix, value: '_' }
|
||||||
- { key: readability-identifier-naming.ProtectedMemberSuffix, value: '_' }
|
- { key: readability-identifier-naming.ProtectedMemberSuffix, value: '_' }
|
||||||
|
|||||||
@ -21,6 +21,11 @@
|
|||||||
#include "./common/random.h"
|
#include "./common/random.h"
|
||||||
#include "common/timer.h"
|
#include "common/timer.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
const char* kMaxDeltaStepDefaultValue = "0.7";
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// implementation of base learner.
|
// implementation of base learner.
|
||||||
@ -222,7 +227,7 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
if (cfg_.count("max_delta_step") == 0 && cfg_.count("objective") != 0 &&
|
if (cfg_.count("max_delta_step") == 0 && cfg_.count("objective") != 0 &&
|
||||||
cfg_["objective"] == "count:poisson") {
|
cfg_["objective"] == "count:poisson") {
|
||||||
cfg_["max_delta_step"] = "0.7";
|
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
ConfigureUpdaters();
|
ConfigureUpdaters();
|
||||||
@ -321,21 +326,41 @@ class LearnerImpl : public Learner {
|
|||||||
|
|
||||||
// rabit save model to rabit checkpoint
|
// rabit save model to rabit checkpoint
|
||||||
void Save(dmlc::Stream* fo) const override {
|
void Save(dmlc::Stream* fo) const override {
|
||||||
fo->Write(&mparam_, sizeof(LearnerModelParam));
|
LearnerModelParam mparam = mparam_; // make a copy to potentially modify
|
||||||
|
std::vector<std::pair<std::string, std::string> > extra_attr;
|
||||||
|
// extra attributed to be added just before saving
|
||||||
|
|
||||||
|
if (name_obj_ == "count:poisson") {
|
||||||
|
auto it = cfg_.find("max_delta_step");
|
||||||
|
if (it != cfg_.end()) {
|
||||||
|
// write `max_delta_step` parameter as extra attribute of booster
|
||||||
|
mparam.contain_extra_attrs = 1;
|
||||||
|
extra_attr.emplace_back("count_poisson_max_delta_step", it->second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fo->Write(&mparam, sizeof(LearnerModelParam));
|
||||||
fo->Write(name_obj_);
|
fo->Write(name_obj_);
|
||||||
fo->Write(name_gbm_);
|
fo->Write(name_gbm_);
|
||||||
gbm_->Save(fo);
|
gbm_->Save(fo);
|
||||||
if (mparam_.contain_extra_attrs != 0) {
|
if (mparam.contain_extra_attrs != 0) {
|
||||||
std::vector<std::pair<std::string, std::string> > attr(
|
std::vector<std::pair<std::string, std::string> > attr(
|
||||||
attributes_.begin(), attributes_.end());
|
attributes_.begin(), attributes_.end());
|
||||||
|
attr.insert(attr.end(), extra_attr.begin(), extra_attr.end());
|
||||||
fo->Write(attr);
|
fo->Write(attr);
|
||||||
}
|
}
|
||||||
if (name_obj_ == "count:poisson") {
|
if (name_obj_ == "count:poisson") {
|
||||||
auto it =
|
auto it = cfg_.find("max_delta_step");
|
||||||
cfg_.find("max_delta_step");
|
if (it != cfg_.end()) {
|
||||||
if (it != cfg_.end()) fo->Write(it->second);
|
fo->Write(it->second);
|
||||||
|
} else {
|
||||||
|
// recover value of max_delta_step from extra attributes
|
||||||
|
auto it2 = attributes_.find("count_poisson_max_delta_step");
|
||||||
|
const std::string max_delta_step
|
||||||
|
= (it2 != attributes_.end()) ? it2->second : kMaxDeltaStepDefaultValue;
|
||||||
|
fo->Write(max_delta_step);
|
||||||
}
|
}
|
||||||
if (mparam_.contain_eval_metrics != 0) {
|
}
|
||||||
|
if (mparam.contain_eval_metrics != 0) {
|
||||||
std::vector<std::string> metr;
|
std::vector<std::string> metr;
|
||||||
for (auto& ev : metrics_) {
|
for (auto& ev : metrics_) {
|
||||||
metr.emplace_back(ev->Name());
|
metr.emplace_back(ev->Name());
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user