Add with_pbuffer info to the model, allowing an xgb model to be saved in a more memory-compact way
This commit is contained in:
parent
3b4697786e
commit
7f7947f31c
@ -64,7 +64,13 @@ class GBTree : public IGradBooster {
|
||||
}
|
||||
virtual void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||
utils::Assert(mparam.num_trees == static_cast<int>(trees.size()), "GBTree");
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
if (with_pbuffer) {
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
} else {
|
||||
ModelParam p = mparam;
|
||||
p.num_pbuffer = 0;
|
||||
fo.Write(&p, sizeof(ModelParam));
|
||||
}
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(fo);
|
||||
}
|
||||
|
||||
@ -157,11 +157,9 @@ class BoostLearner : public rabit::Serializable {
|
||||
/*!
|
||||
* \brief load model from stream
|
||||
* \param fi input stream
|
||||
* \param with_pbuffer whether to load with predict buffer
|
||||
* \param calc_num_feature whether call InitTrainer with calc_num_feature
|
||||
*/
|
||||
inline void LoadModel(utils::IStream &fi,
|
||||
bool with_pbuffer = true,
|
||||
bool calc_num_feature = true) {
|
||||
utils::Check(fi.Read(&mparam, sizeof(ModelParam)) != 0,
|
||||
"BoostLearner: wrong model format");
|
||||
@ -189,15 +187,15 @@ class BoostLearner : public rabit::Serializable {
|
||||
char tmp[32];
|
||||
utils::SPrintf(tmp, sizeof(tmp), "%u", mparam.num_class);
|
||||
obj_->SetParam("num_class", tmp);
|
||||
gbm_->LoadModel(fi, with_pbuffer);
|
||||
if (!with_pbuffer || distributed_mode == 2) {
|
||||
gbm_->LoadModel(fi, mparam.saved_with_pbuffer != 0);
|
||||
if (mparam.saved_with_pbuffer == 0) {
|
||||
gbm_->ResetPredBuffer(pred_buffer_size);
|
||||
}
|
||||
}
|
||||
// rabit load model from rabit checkpoint
|
||||
virtual void Load(rabit::Stream *fi) {
|
||||
// for row split, we should not keep pbuffer
|
||||
this->LoadModel(*fi, distributed_mode != 2, false);
|
||||
this->LoadModel(*fi, false);
|
||||
}
|
||||
// rabit save model to rabit checkpoint
|
||||
virtual void Save(rabit::Stream *fo) const {
|
||||
@ -218,18 +216,20 @@ class BoostLearner : public rabit::Serializable {
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fi);
|
||||
bsin.InitPosition();
|
||||
this->LoadModel(bsin);
|
||||
this->LoadModel(bsin, true);
|
||||
} else if (header == "binf") {
|
||||
this->LoadModel(*fi);
|
||||
this->LoadModel(*fi, true);
|
||||
} else {
|
||||
delete fi;
|
||||
fi = utils::IStream::Create(fname, "r");
|
||||
this->LoadModel(*fi);
|
||||
this->LoadModel(*fi, true);
|
||||
}
|
||||
delete fi;
|
||||
}
|
||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer) const {
|
||||
ModelParam p = mparam;
|
||||
p.saved_with_pbuffer = static_cast<int>(with_pbuffer);
|
||||
fo.Write(&p, sizeof(ModelParam));
|
||||
fo.Write(name_obj_);
|
||||
fo.Write(name_gbm_);
|
||||
gbm_->SaveModel(fo, with_pbuffer);
|
||||
@ -237,17 +237,18 @@ class BoostLearner : public rabit::Serializable {
|
||||
/*!
|
||||
* \brief save model into file
|
||||
* \param fname file name
|
||||
* \param with_pbuffer whether save pbuffer together
|
||||
*/
|
||||
inline void SaveModel(const char *fname) const {
|
||||
inline void SaveModel(const char *fname, bool with_pbuffer) const {
|
||||
utils::IStream *fo = utils::IStream::Create(fname, "w");
|
||||
if (save_base64 != 0 || !strcmp(fname, "stdout")) {
|
||||
fo->Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fo);
|
||||
this->SaveModel(bout);
|
||||
this->SaveModel(bout, with_pbuffer);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo->Write("binf", 4);
|
||||
this->SaveModel(*fo);
|
||||
this->SaveModel(*fo, with_pbuffer);
|
||||
}
|
||||
delete fo;
|
||||
}
|
||||
@ -442,14 +443,17 @@ class BoostLearner : public rabit::Serializable {
|
||||
unsigned num_feature;
|
||||
/* \brief number of class, if it is multi-class classification */
|
||||
int num_class;
|
||||
/*! \brief whether the model itself is saved with pbuffer */
|
||||
int saved_with_pbuffer;
|
||||
/*! \brief reserved field */
|
||||
int reserved[31];
|
||||
int reserved[30];
|
||||
/*! \brief constructor */
|
||||
ModelParam(void) {
|
||||
std::memset(this, 0, sizeof(ModelParam));
|
||||
base_score = 0.5f;
|
||||
num_feature = 0;
|
||||
num_class = 0;
|
||||
std::memset(reserved, 0, sizeof(reserved));
|
||||
saved_with_pbuffer = 0;
|
||||
}
|
||||
/*!
|
||||
* \brief set parameters from outside
|
||||
|
||||
@ -87,6 +87,7 @@ class BoostLearnTask {
|
||||
if (!strcmp("name_pred", name)) name_pred = val;
|
||||
if (!strcmp("dsplit", name)) data_split = val;
|
||||
if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val);
|
||||
if (!strcmp("save_pbuffer", name)) save_with_pbuffer = atoi(val);
|
||||
if (!strncmp("eval[", name, 5)) {
|
||||
char evname[256];
|
||||
utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display");
|
||||
@ -115,6 +116,7 @@ class BoostLearnTask {
|
||||
model_dir_path = "./";
|
||||
data_split = "NONE";
|
||||
load_part = 0;
|
||||
save_with_pbuffer = 0;
|
||||
data = NULL;
|
||||
}
|
||||
~BoostLearnTask(void){
|
||||
@ -241,7 +243,7 @@ class BoostLearnTask {
|
||||
}
|
||||
inline void SaveModel(const char *fname) const {
|
||||
if (rabit::GetRank() != 0) return;
|
||||
learner.SaveModel(fname);
|
||||
learner.SaveModel(fname, save_with_pbuffer != 0);
|
||||
}
|
||||
inline void SaveModel(int i) const {
|
||||
char fname[256];
|
||||
@ -297,6 +299,8 @@ class BoostLearnTask {
|
||||
int pred_margin;
|
||||
/*! \brief whether dump statistics along with model */
|
||||
int dump_model_stats;
|
||||
/*! \brief whether save prediction buffer */
|
||||
int save_with_pbuffer;
|
||||
/*! \brief name of feature map */
|
||||
std::string name_fmap;
|
||||
/*! \brief name of dump file */
|
||||
|
||||
@ -58,13 +58,13 @@ class Booster: public learner::BoostLearner {
|
||||
}
|
||||
inline void LoadModelFromBuffer(const void *buf, size_t size) {
|
||||
utils::MemoryFixSizeBuffer fs((void*)buf, size);
|
||||
learner::BoostLearner::LoadModel(fs);
|
||||
learner::BoostLearner::LoadModel(fs, true);
|
||||
this->init_model = true;
|
||||
}
|
||||
inline const char *GetModelRaw(bst_ulong *out_len) {
|
||||
model_str.resize(0);
|
||||
utils::MemoryBufferStream fs(&model_str);
|
||||
learner::BoostLearner::SaveModel(fs);
|
||||
learner::BoostLearner::SaveModel(fs, false);
|
||||
*out_len = static_cast<bst_ulong>(model_str.length());
|
||||
if (*out_len == 0) {
|
||||
return NULL;
|
||||
@ -323,7 +323,7 @@ extern "C"{
|
||||
static_cast<Booster*>(handle)->LoadModel(fname);
|
||||
}
|
||||
void XGBoosterSaveModel(const void *handle, const char *fname) {
|
||||
static_cast<const Booster*>(handle)->SaveModel(fname);
|
||||
static_cast<const Booster*>(handle)->SaveModel(fname, false);
|
||||
}
|
||||
void XGBoosterLoadModelFromBuffer(void *handle, const void *buf, bst_ulong len) {
|
||||
static_cast<Booster*>(handle)->LoadModelFromBuffer(buf, len);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user