add with pbuffer info to model, allow xgb model to be saved in a more memory compact way

This commit is contained in:
tqchen 2015-05-06 15:43:15 -07:00
parent 3b4697786e
commit 7f7947f31c
4 changed files with 34 additions and 20 deletions

View File

@ -64,7 +64,13 @@ class GBTree : public IGradBooster {
} }
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const { virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree"); utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
fo.Write(&mparam, sizeof(ModelParam)); if (with_pbuffer) {
fo.Write(&mparam, sizeof(ModelParam));
} else {
ModelParam p = mparam;
p.num_pbuffer = 0;
fo.Write(&p, sizeof(ModelParam));
}
for (size_t i = 0; i < trees.size(); ++i) { for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fo); trees[i]->SaveModel(fo);
} }

View File

@ -157,11 +157,9 @@ class BoostLearner : public rabit::Serializable {
/*! /*!
* \brief load model from stream * \brief load model from stream
* \param fi input stream * \param fi input stream
* \param with_pbuffer whether to load with predict buffer
* \param calc_num_feature whether call InitTrainer with calc_num_feature * \param calc_num_feature whether call InitTrainer with calc_num_feature
*/ */
inline void LoadModel(utils::IStream &fi, inline void LoadModel(utils::IStream &fi,
bool with_pbuffer = true,
bool calc_num_feature = true) { bool calc_num_feature = true) {
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0, utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
"BoostLearner: wrong model format"); "BoostLearner: wrong model format");
@ -189,15 +187,15 @@ class BoostLearner : public rabit::Serializable {
char tmp[32]; char tmp[32];
utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class); utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class);
obj_->SetParam("num_class", tmp); obj_->SetParam("num_class", tmp);
gbm_->LoadModel(fi, with_pbuffer); gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0);
if (!with_pbuffer || distributed_mode == 2) { if (mparam.saved_with_pbuffer == 0) {
gbm_->ResetPredBuffer(pred_buffer_size); gbm_->ResetPredBuffer(pred_buffer_size);
} }
} }
// rabit load model from rabit checkpoint // rabit load model from rabit checkpoint
virtual void Load(rabit::Stream *fi) { virtual void Load(rabit::Stream *fi) {
// for row split, we should not keep pbuffer // for row split, we should not keep pbuffer
this->LoadModel(*fi, distributed_mode != 2, false); this->LoadModel(*fi, false);
} }
// rabit save model to rabit checkpoint // rabit save model to rabit checkpoint
virtual void Save(rabit::Stream *fo) const { virtual void Save(rabit::Stream *fo) const {
@ -218,18 +216,20 @@ class BoostLearner : public rabit::Serializable {
if (header == "bs64") { if (header == "bs64") {
utils::Base64InStream bsin(fi); utils::Base64InStream bsin(fi);
bsin.InitPosition(); bsin.InitPosition();
this->LoadModel(bsin); this->LoadModel(bsin, true);
} else if (header == "binf") { } else if (header == "binf") {
this->LoadModel(*fi); this->LoadModel(*fi, true);
} else { } else {
delete fi; delete fi;
fi = utils::IStream::Create(fname, "r"); fi = utils::IStream::Create(fname, "r");
this->LoadModel(*fi); this->LoadModel(*fi, true);
} }
delete fi; delete fi;
} }
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const { inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
fo.Write(&mparam, sizeof(ModelParam)); ModelParam p = mparam;
p.saved_with_pbuffer = static_cast<int>(with_pbuffer);
fo.Write(&p, sizeof(ModelParam));
fo.Write(name_obj_); fo.Write(name_obj_);
fo.Write(name_gbm_); fo.Write(name_gbm_);
gbm_->SaveModel(fo, with_pbuffer); gbm_->SaveModel(fo, with_pbuffer);
@ -237,17 +237,18 @@ class BoostLearner : public rabit::Serializable {
/*! /*!
* \brief save model into file * \brief save model into file
* \param fname file name * \param fname file name
* \param with_pbuffer whether save pbuffer together
*/ */
inline void SaveModel(const char *fname) const { inline void SaveModel(const char *fname, bool with_pbuffer) const {
utils::IStream *fo = utils::IStream::Create(fname, "w"); utils::IStream *fo = utils::IStream::Create(fname, "w");
if (save_base64 != 0 || !strcmp(fname, "stdout")) { if (save_base64 != 0 || !strcmp(fname, "stdout")) {
fo->Write("bs64\t", 5); fo->Write("bs64\t", 5);
utils::Base64OutStream bout(fo); utils::Base64OutStream bout(fo);
this->SaveModel(bout); this->SaveModel(bout, with_pbuffer);
bout.Finish('\n'); bout.Finish('\n');
} else { } else {
fo->Write("binf", 4); fo->Write("binf", 4);
this->SaveModel(*fo); this->SaveModel(*fo, with_pbuffer);
} }
delete fo; delete fo;
} }
@ -442,14 +443,17 @@ class BoostLearner : public rabit::Serializable {
unsigned num_feature; unsigned num_feature;
/* \brief number of class, if it is multi-class classification */ /* \brief number of class, if it is multi-class classification */
int num_class; int num_class;
/*! \brief whether the model itself is saved with pbuffer */
int saved_with_pbuffer;
/*! \brief reserved field */ /*! \brief reserved field */
int reserved[31]; int reserved[30];
/*! \brief constructor */ /*! \brief constructor */
ModelParam(void) { ModelParam(void) {
std::memset(this, 0, sizeof(ModelParam));
base_score = 0.5f; base_score = 0.5f;
num_feature = 0; num_feature = 0;
num_class = 0; num_class = 0;
std::memset(reserved, 0, sizeof(reserved)); saved_with_pbuffer = 0;
} }
/*! /*!
* \brief set parameters from outside * \brief set parameters from outside

View File

@ -87,6 +87,7 @@ class BoostLearnTask {
if (!strcmp("name_pred", name)) name_pred = val; if (!strcmp("name_pred", name)) name_pred = val;
if (!strcmp("dsplit", name)) data_split = val; if (!strcmp("dsplit", name)) data_split = val;
if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val); if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val);
if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val);
if (!strncmp("eval[", name, 5)) { if (!strncmp("eval[", name, 5)) {
char evname[256]; char evname[256];
utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display"); utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display");
@ -115,6 +116,7 @@ class BoostLearnTask {
model_dir_path = "./"; model_dir_path = "./";
data_split = "NONE"; data_split = "NONE";
load_part = 0; load_part = 0;
save_with_pbuffer = 0;
data = NULL; data = NULL;
} }
~BoostLearnTask(void){ ~BoostLearnTask(void){
@ -241,7 +243,7 @@ class BoostLearnTask {
} }
inline void SaveModel(const char *fname) const { inline void SaveModel(const char *fname) const {
if (rabit::GetRank() != 0) return; if (rabit::GetRank() != 0) return;
learner.SaveModel(fname); learner.SaveModel(fname, save_with_pbuffer != 0);
} }
inline void SaveModel(int i) const { inline void SaveModel(int i) const {
char fname[256]; char fname[256];
@ -297,6 +299,8 @@ class BoostLearnTask {
int pred_margin; int pred_margin;
/*! \brief whether dump statistics along with model */ /*! \brief whether dump statistics along with model */
int dump_model_stats; int dump_model_stats;
/*! \brief whether save prediction buffer */
int save_with_pbuffer;
/*! \brief name of feature map */ /*! \brief name of feature map */
std::string name_fmap; std::string name_fmap;
/*! \brief name of dump file */ /*! \brief name of dump file */

View File

@ -58,13 +58,13 @@ class Booster: public learner::BoostLearner {
} }
inline void LoadModelFromBuffer(const void *buf, size_t size) { inline void LoadModelFromBuffer(const void *buf, size_t size) {
utils::MemoryFixSizeBuffer fs((void*)buf, size); utils::MemoryFixSizeBuffer fs((void*)buf, size);
learner::BoostLearner::LoadModel(fs); learner::BoostLearner::LoadModel(fs, true);
this->init_model = true; this->init_model = true;
} }
inline const char *GetModelRaw(bst_ulong *out_len) { inline const char *GetModelRaw(bst_ulong *out_len) {
model_str.resize(0); model_str.resize(0);
utils::MemoryBufferStream fs(&model_str); utils::MemoryBufferStream fs(&model_str);
learner::BoostLearner::SaveModel(fs); learner::BoostLearner::SaveModel(fs, false);
*out_len = static_cast<bst_ulong>(model_str.length()); *out_len = static_cast<bst_ulong>(model_str.length());
if (*out_len == 0) { if (*out_len == 0) {
return NULL; return NULL;
@ -323,7 +323,7 @@ extern "C"{
static_cast<Booster*>(handle)->LoadModel(fname); static_cast<Booster*>(handle)->LoadModel(fname);
} }
void XGBoosterSaveModel(const void *handle, const char *fname) { void XGBoosterSaveModel(const void *handle, const char *fname) {
static_cast<const Booster*>(handle)->SaveModel(fname); static_cast<const Booster*>(handle)->SaveModel(fname, false);
} }
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) { void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len); static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);