Add with_pbuffer info to the model; allow xgb models to be saved in a more memory-compact way

This commit is contained in:
tqchen
2015-05-06 15:43:15 -07:00
parent 3b4697786e
commit 7f7947f31c
4 changed files with 34 additions and 20 deletions

View File

@@ -157,11 +157,9 @@ class BoostLearner : public rabit::Serializable {
/*!
* \brief load model from stream
* \param fi input stream
* \param with_pbuffer whether to load with predict buffer
* \param calc_num_feature whether to call InitTrainer with calc_num_feature
*/
inline void LoadModel(utils::IStream &fi,
bool with_pbuffer = true,
bool calc_num_feature = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format");
@@ -189,15 +187,15 @@ class BoostLearner : public rabit::Serializable {
char tmp[32];
utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class);
obj_->SetParam("num_class", tmp);
gbm_->LoadModel(fi, with_pbuffer);
if (!with_pbuffer || distributed_mode == 2) {
gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0);
if (mparam.saved_with_pbuffer == 0) {
gbm_->ResetPredBuffer(pred_buffer_size);
}
}
// rabit load model from rabit checkpoint
// Restores the learner from the checkpoint stream by delegating to LoadModel.
// NOTE(review): this listing appears to interleave the pre-change and
// post-change lines of the commit, so two LoadModel calls are shown below;
// confirm which one is in effect before relying on either.
virtual void Load(rabit::Stream *fi) {
// for row split, we should not keep pbuffer
// three-argument call: (stream, with_pbuffer, calc_num_feature); the predict
// buffer is requested only when distributed_mode != 2
this->LoadModel(*fi, distributed_mode != 2, false);
// two-argument call: (stream, calc_num_feature); calc_num_feature is false here
this->LoadModel(*fi, false);
}
// rabit save model to rabit checkpoint
virtual void Save(rabit::Stream *fo) const {
@@ -218,18 +216,20 @@ class BoostLearner : public rabit::Serializable {
if (header == "bs64") {
utils::Base64InStream bsin(fi);
bsin.InitPosition();
this->LoadModel(bsin);
this->LoadModel(bsin, true);
} else if (header == "binf") {
this->LoadModel(*fi);
this->LoadModel(*fi, true);
} else {
delete fi;
fi = utils::IStream::Create(fname, "r");
this->LoadModel(*fi);
this->LoadModel(*fi, true);
}
delete fi;
}
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
ModelParam p = mparam;
p.saved_with_pbuffer = static_cast<int>(with_pbuffer);
fo.Write(&p, sizeof(ModelParam));
fo.Write(name_obj_);
fo.Write(name_gbm_);
gbm_->SaveModel(fo, with_pbuffer);
@@ -237,17 +237,18 @@ class BoostLearner : public rabit::Serializable {
/*!
* \brief save model into file; the output starts with a format tag,
*  "bs64" followed by a tab for base64 output or "binf" for raw binary output
* \param fname file name; base64 output is used when fname is "stdout"
*  or when save_base64 is nonzero
* \param with_pbuffer whether to save the pbuffer together with the model;
*  forwarded unchanged to the stream overload of SaveModel
*/
// NOTE(review): this listing appears to interleave the pre-change and
// post-change lines of the commit (two signature lines, paired SaveModel
// calls); confirm which variant is in effect.
inline void SaveModel(const char *fname) const {
inline void SaveModel(const char *fname, bool with_pbuffer) const {
// open the target for writing; released by the delete at the end of the function
utils::IStream *fo = utils::IStream::Create(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
// base64 tag plus a tab separator, then the model is written through the encoder
fo->Write("bs64\t", 5);
utils::Base64OutStream bout(fo);
this->SaveModel(bout);
this->SaveModel(bout, with_pbuffer);
// presumably flushes the encoder and appends the trailing newline;
// TODO confirm against Base64OutStream
bout.Finish('\n');
} else {
// raw binary tag, then the model is written directly to the file stream
fo->Write("binf", 4);
this->SaveModel(*fo);
this->SaveModel(*fo, with_pbuffer);
}
delete fo;
}
@@ -442,14 +443,17 @@ class BoostLearner : public rabit::Serializable {
unsigned num_feature;
/*! \brief number of classes, if it is multi-class classification */
int num_class;
/*! \brief whether the model itself is saved with pbuffer */
int saved_with_pbuffer;
/*! \brief reserved field */
int reserved[31];
int reserved[30];
/*!
* \brief constructor; clears the struct and assigns the defaults:
*  base_score = 0.5f, num_feature = 0, num_class = 0, saved_with_pbuffer = 0
*/
ModelParam(void) {
// zero every byte of the struct, including the reserved slots
std::memset(this, 0, sizeof(ModelParam));
base_score = 0.5f;
num_feature = 0;
num_class = 0;
// NOTE(review): this listing appears to show pre- and post-change lines
// together; this reserved-only memset repeats what the whole-struct
// memset above already does. Confirm which one is in effect.
std::memset(reserved, 0, sizeof(reserved));
// 0 = no pbuffer stored with the model; LoadModel passes
// (saved_with_pbuffer != 0) to gbm_->LoadModel and calls ResetPredBuffer when it is 0
saved_with_pbuffer = 0;
}
/*!
* \brief set parameters from outside