add base64 model format
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
#include <limits>
|
||||
// rabit library for synchronization
|
||||
#include <rabit.h>
|
||||
#include "../utils/io.h"
|
||||
#include "../utils/base64.h"
|
||||
#include "./objective.h"
|
||||
#include "./evaluation.h"
|
||||
#include "../gbm/gbm.h"
|
||||
@@ -36,6 +38,7 @@ class BoostLearner : public rabit::ISerializable {
|
||||
pred_buffer_size = 0;
|
||||
seed_per_iteration = 0;
|
||||
seed = 0;
|
||||
save_base64 = 0;
|
||||
}
|
||||
virtual ~BoostLearner(void) {
|
||||
if (obj_ != NULL) delete obj_;
|
||||
@@ -68,9 +71,6 @@ class BoostLearner : public rabit::ISerializable {
|
||||
utils::SPrintf(str_temp, sizeof(str_temp), "%lu",
|
||||
static_cast<unsigned long>(buffer_size));
|
||||
this->SetParam("num_pbuffer", str_temp);
|
||||
if (!silent) {
|
||||
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
|
||||
}
|
||||
this->pred_buffer_size = buffer_size;
|
||||
}
|
||||
/*!
|
||||
@@ -108,6 +108,7 @@ class BoostLearner : public rabit::ISerializable {
|
||||
this->seed = seed; random::Seed(atoi(val));
|
||||
}
|
||||
if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
|
||||
if (!strcmp("save_base64", name)) save_base64 = atoi(val);
|
||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||
if (!strcmp(name, "nthread")) {
|
||||
omp_set_num_threads(atoi(val));
|
||||
@@ -191,9 +192,29 @@ class BoostLearner : public rabit::ISerializable {
|
||||
* \param fname file name
|
||||
*/
|
||||
inline void LoadModel(const char *fname) {
|
||||
utils::FileStream fi(utils::FopenCheck(fname, "rb"));
|
||||
FILE *fp = utils::FopenCheck(fname, "rb");
|
||||
std::string header; header.resize(4);
|
||||
utils::FileStream fi(fp);
|
||||
// check header for different binary encode
|
||||
// can be base64 or binary
|
||||
if (fi.Read(&header[0], 4) != 0) {
|
||||
// base64 format
|
||||
if (header == "bs64") {
|
||||
utils::Base64InStream bsin(fp);
|
||||
bsin.InitPosition();
|
||||
this->LoadModel(bsin);
|
||||
fclose(fp);
|
||||
return;
|
||||
}
|
||||
if (header == "binf") {
|
||||
this->LoadModel(fi);
|
||||
fclose(fp);
|
||||
return;
|
||||
}
|
||||
}
|
||||
fi.Seek(0);
|
||||
this->LoadModel(fi);
|
||||
fi.Close();
|
||||
fclose(fp);
|
||||
}
|
||||
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
|
||||
fo.Write(&mparam, sizeof(ModelParam));
|
||||
@@ -206,9 +227,24 @@ class BoostLearner : public rabit::ISerializable {
|
||||
* \param fname file name
|
||||
*/
|
||||
inline void SaveModel(const char *fname) const {
|
||||
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
|
||||
this->SaveModel(fo);
|
||||
fo.Close();
|
||||
FILE *fp;
|
||||
if (!strcmp(fname, "stdout")) {
|
||||
fp = stdout;
|
||||
} else {
|
||||
fp = utils::FopenCheck(fname, "wb");
|
||||
}
|
||||
utils::FileStream fo(fp);
|
||||
std::string header;
|
||||
if (save_base64 != 0|| fp == stdout) {
|
||||
fo.Write("bs64\t", 5);
|
||||
utils::Base64OutStream bout(fp);
|
||||
this->SaveModel(bout);
|
||||
bout.Finish('\n');
|
||||
} else {
|
||||
fo.Write("binf", 4);
|
||||
this->SaveModel(fo);
|
||||
}
|
||||
if (fp != stdout) fclose(fp);
|
||||
}
|
||||
/*!
|
||||
* \brief check if data matrix is ready to be used by training,
|
||||
@@ -383,6 +419,8 @@ class BoostLearner : public rabit::ISerializable {
|
||||
// this is important for restart from existing iterations
|
||||
// default set to no, but will auto switch on in distributed mode
|
||||
int seed_per_iteration;
|
||||
// save model in base64 encoding
|
||||
int save_base64;
|
||||
// silent during training
|
||||
int silent;
|
||||
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
||||
|
||||
Reference in New Issue
Block a user