add base64 model format

This commit is contained in:
tqchen
2014-12-24 02:33:50 -08:00
parent c8396ca24e
commit 6d7ef172ef
4 changed files with 261 additions and 17 deletions

View File

@@ -12,6 +12,8 @@
#include <limits>
// rabit library for synchronization
#include <rabit.h>
#include "../utils/io.h"
#include "../utils/base64.h"
#include "./objective.h"
#include "./evaluation.h"
#include "../gbm/gbm.h"
@@ -36,6 +38,7 @@ class BoostLearner : public rabit::ISerializable {
pred_buffer_size = 0;
seed_per_iteration = 0;
seed = 0;
save_base64 = 0;
}
virtual ~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
@@ -68,9 +71,6 @@ class BoostLearner : public rabit::ISerializable {
utils::SPrintf(str_temp, sizeof(str_temp), "%lu",
static_cast<unsigned long>(buffer_size));
this->SetParam("num_pbuffer", str_temp);
if (!silent) {
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
}
this->pred_buffer_size = buffer_size;
}
/*!
@@ -108,6 +108,7 @@ class BoostLearner : public rabit::ISerializable {
this->seed = seed; random::Seed(atoi(val));
}
if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
if (!strcmp("save_base64", name)) save_base64 = atoi(val);
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
if (!strcmp(name, "nthread")) {
omp_set_num_threads(atoi(val));
@@ -191,9 +192,29 @@ class BoostLearner : public rabit::ISerializable {
* \param fname file name
*/
inline void LoadModel(const char *fname) {
  // Open the file exactly once; both the raw stream and the base64
  // decoder below read from this same handle.
  FILE *fp = utils::FopenCheck(fname, "rb");
  utils::FileStream fi(fp);
  // The first four bytes may be a format tag:
  //   "bs64" -> base64-encoded model, "binf" -> tagged binary model.
  // Anything else is treated as an untagged binary model, which is
  // read from the start of the file.
  std::string header(4, '\0');
  const bool has_header = fi.Read(&header[0], 4) != 0;
  if (has_header && header == "bs64") {
    // base64 format: the decoder continues from the current position of fp
    utils::Base64InStream bsin(fp);
    bsin.InitPosition();
    this->LoadModel(bsin);
  } else if (has_header && header == "binf") {
    // tagged binary format: payload starts right after the tag
    this->LoadModel(fi);
  } else {
    // no recognised tag: rewind so the bytes just read are parsed as model data
    fi.Seek(0);
    this->LoadModel(fi);
  }
  // single point of release for the handle (closed exactly once)
  fclose(fp);
}
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
@@ -206,9 +227,24 @@ class BoostLearner : public rabit::ISerializable {
* \param fname file name
*/
inline void SaveModel(const char *fname) const {
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
this->SaveModel(fo);
fo.Close();
FILE *fp;
if (!strcmp(fname, "stdout")) {
fp = stdout;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
std::string header;
if (save_base64 != 0|| fp == stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
this->SaveModel(bout);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
this->SaveModel(fo);
}
if (fp != stdout) fclose(fp);
}
/*!
* \brief check if data matrix is ready to be used by training,
@@ -383,6 +419,8 @@ class BoostLearner : public rabit::ISerializable {
// this is important for restart from existing iterations
// default set to no, but will auto switch on in distributed mode
int seed_per_iteration;
// save model in base64 encoding
int save_base64;
// silent during training
int silent;
// distributed learning mode, if any, 0:none, 1:col, 2:row