diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 201100a6f..cb02ee075 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -12,6 +12,8 @@ #include // rabit library for synchronization #include +#include "../utils/io.h" +#include "../utils/base64.h" #include "./objective.h" #include "./evaluation.h" #include "../gbm/gbm.h" @@ -36,6 +38,7 @@ class BoostLearner : public rabit::ISerializable { pred_buffer_size = 0; seed_per_iteration = 0; seed = 0; + save_base64 = 0; } virtual ~BoostLearner(void) { if (obj_ != NULL) delete obj_; @@ -68,9 +71,6 @@ class BoostLearner : public rabit::ISerializable { utils::SPrintf(str_temp, sizeof(str_temp), "%lu", static_cast(buffer_size)); this->SetParam("num_pbuffer", str_temp); - if (!silent) { - utils::Printf("buffer_size=%ld\n", static_cast(buffer_size)); - } this->pred_buffer_size = buffer_size; } /*! @@ -108,6 +108,7 @@ class BoostLearner : public rabit::ISerializable { this->seed = seed; random::Seed(atoi(val)); } if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val); + if (!strcmp("save_base64", name)) save_base64 = atoi(val); if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val); if (!strcmp(name, "nthread")) { omp_set_num_threads(atoi(val)); @@ -191,9 +192,29 @@ class BoostLearner : public rabit::ISerializable { * \param fname file name */ inline void LoadModel(const char *fname) { - utils::FileStream fi(utils::FopenCheck(fname, "rb")); + FILE *fp = utils::FopenCheck(fname, "rb"); + std::string header; header.resize(4); + utils::FileStream fi(fp); + // check header for different binary encode + // can be base64 or binary + if (fi.Read(&header[0], 4) != 0) { + // base64 format + if (header == "bs64") { + utils::Base64InStream bsin(fp); + bsin.InitPosition(); + this->LoadModel(bsin); + fclose(fp); + return; + } + if (header == "binf") { + this->LoadModel(fi); + fclose(fp); + return; + } + } + fi.Seek(0); this->LoadModel(fi); - fi.Close(); + fclose(fp); } inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const { fo.Write(&mparam, sizeof(ModelParam)); @@ -206,9 +227,24 @@ class BoostLearner : public rabit::ISerializable { * \param fname file name */ inline void SaveModel(const char *fname) const { - utils::FileStream fo(utils::FopenCheck(fname, "wb")); - this->SaveModel(fo); - fo.Close(); + FILE *fp; + if (!strcmp(fname, "stdout")) { + fp = stdout; + } else { + fp = utils::FopenCheck(fname, "wb"); + } + utils::FileStream fo(fp); + std::string header; + if (save_base64 != 0|| fp == stdout) { + fo.Write("bs64\t", 5); + utils::Base64OutStream bout(fp); + this->SaveModel(bout); + bout.Finish('\n'); + } else { + fo.Write("binf", 4); + this->SaveModel(fo); + } + if (fp != stdout) fclose(fp); } /*! * \brief check if data matrix is ready to be used by training, @@ -383,6 +419,8 @@ class BoostLearner : public rabit::ISerializable { // this is important for restart from existing iterations // default set to no, but will auto switch on in distributed mode int seed_per_iteration; + // save model in base64 encoding + int save_base64; // silent during training int silent; // distributed learning mode, if any, 0:none, 1:col, 2:row diff --git a/src/utils/base64.h b/src/utils/base64.h new file mode 100644 index 000000000..36699199f --- /dev/null +++ b/src/utils/base64.h @@ -0,0 +1,205 @@ +#ifndef XGBOOST_UTILS_BASE64_H_ +#define XGBOOST_UTILS_BASE64_H_ +/*! + * \file base64.h + * \brief data stream support to input and output from/to base64 stream + * base64 is easier to store and pass as text format in mapreduce + * \author Tianqi Chen + */ +#include +#include +#include "./utils.h" +#include "./io.h" + +namespace xgboost { +namespace utils { +/*! \brief namespace of base64 decoding and encoding table */ +namespace base64 { +const char DecodeTable[] = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 62, // '+' + 0, 0, 0, + 63, // '/' + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9' + 0, 0, 0, 0, 0, 0, 0, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, + 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z' + 0, 0, 0, 0, 0, 0, + 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z' +}; +static const char EncodeTable[] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +} // namespace base64 +/*! \brief the stream that reads from base64, note we take from file pointers */ +class Base64InStream: public IStream { + public: + explicit Base64InStream(FILE *fp) : fp(fp) { + num_prev = 0; tmp_ch = 0; + } + /*! + * \brief initialize the stream position to beginning of next base64 stream + * call this function before actually start read + */ + inline void InitPosition(void) { + // get a charater + do { + tmp_ch = fgetc(fp); + } while (isspace(tmp_ch)); + } + /*! \brief whether current position is end of a base64 stream */ + inline bool IsEOF(void) const { + return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch)); + } + virtual size_t Read(void *ptr, size_t size) { + using base64::DecodeTable; + if (size == 0) return 0; + // use tlen to record left size + size_t tlen = size; + unsigned char *cptr = static_cast(ptr); + // if anything left, load from previous buffered result + if (num_prev != 0) { + if (num_prev == 2) { + if (tlen >= 2) { + *cptr++ = buf_prev[0]; + *cptr++ = buf_prev[1]; + tlen -= 2; + num_prev = 0; + } else { + // assert tlen == 1 + *cptr++ = buf_prev[0]; --tlen; + buf_prev[0] = buf_prev[1]; + num_prev = 1; + } + } else { + // assert num_prev == 1 + *cptr++ = buf_prev[0]; --tlen; num_prev = 0; + } + } + if (tlen == 0) return size; + int nvalue; + // note: everything goes with 4 bytes in Base64 + // so we process 4 bytes a unit + while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) { + // first byte + nvalue = DecodeTable[tmp_ch] << 18; + { + // second byte + Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)), + "invalid base64 format"); + nvalue |= DecodeTable[tmp_ch] << 12; + *cptr++ = (nvalue >> 16) & 0xFF; --tlen; + } + { + // third byte + Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)), + "invalid base64 format"); + // handle termination + if (tmp_ch == '=') { + Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format"); + Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)), + "invalid base64 format"); + break; + } + nvalue |= DecodeTable[tmp_ch] << 6; + if (tlen) { + *cptr++ = (nvalue >> 8) & 0xFF; --tlen; + } else { + buf_prev[num_prev++] = (nvalue >> 8) & 0xFF; + } + } + { + // fourth byte + Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)), + "invalid base64 format"); + if (tmp_ch == '=') { + Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)), + "invalid base64 format"); + break; + } + nvalue |= DecodeTable[tmp_ch]; + if (tlen) { + *cptr++ = nvalue & 0xFF; --tlen; + } else { + buf_prev[num_prev ++] = nvalue & 0xFF; + } + } + // get next char + tmp_ch = fgetc(fp); + } + if (kStrictCheck) { + Check(tlen == 0, "Base64InStream: read incomplete"); + } + return size - tlen; + } + virtual void Write(const void *ptr, size_t size) { + utils::Error("Base64InStream do not support write"); + } + + private: + FILE *fp; + unsigned char tmp_ch; + int num_prev; + unsigned char buf_prev[2]; + // whether we need to do strict check + static const bool kStrictCheck = false; +}; +/*! \brief the stream that write to base64, note we take from file pointers */ +class Base64OutStream: public IStream { + public: + explicit Base64OutStream(FILE *fp) : fp(fp) { + buf_top = 0; + } + virtual void Write(const void *ptr, size_t size) { + using base64::EncodeTable; + size_t tlen = size; + const unsigned char *cptr = static_cast(ptr); + while (tlen) { + while (buf_top < 3 && tlen != 0) { + buf[++buf_top] = *cptr++; --tlen; + } + if (buf_top == 3) { + // flush 4 bytes out + fputc(EncodeTable[buf[1] >> 2], fp); + fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp); + fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp); + fputc(EncodeTable[buf[3] & 0x3F], fp); + buf_top = 0; + } + } + } + virtual size_t Read(void *ptr, size_t size) { + Error("Base64OutStream do not support read"); + return 0; + } + /*! + * \brief finish writing of all current base64 stream, do some post processing + * \param endch charater to put to end of stream, if it is EOF, then nothing will be done + */ + inline void Finish(char endch = EOF) { + using base64::EncodeTable; + if (buf_top == 1) { + fputc(EncodeTable[buf[1] >> 2], fp); + fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp); + fputc('=', fp); + fputc('=', fp); + } + if (buf_top == 2) { + fputc(EncodeTable[buf[1] >> 2], fp); + fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp); + fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp); + fputc('=', fp); + } + buf_top = 0; + if (endch != EOF) fputc(endch, fp); + } + + private: + FILE *fp; + int buf_top; + unsigned char buf[4]; +}; +} // namespace utils +} // namespace xgboost +#endif // XGBOOST_UTILS_BASE64_H_ diff --git a/src/utils/thread_buffer.h b/src/utils/thread_buffer.h index ace50c4b8..ed36e1b43 100644 --- a/src/utils/thread_buffer.h +++ b/src/utils/thread_buffer.h @@ -1,5 +1,5 @@ -#ifndef XGBOOST_UTILS_THREAD_BUFFER_H -#define XGBOOST_UTILS_THREAD_BUFFER_H +#ifndef XGBOOST_UTILS_THREAD_BUFFER_H_ +#define XGBOOST_UTILS_THREAD_BUFFER_H_ /*! * \file thread_buffer.h * \brief multi-thread buffer, iterator, can be used to create parallel pipeline diff --git a/src/xgboost_main.cpp b/src/xgboost_main.cpp index a3f838131..7816fbfd2 100644 --- a/src/xgboost_main.cpp +++ b/src/xgboost_main.cpp @@ -31,6 +31,11 @@ class BoostLearnTask { this->SetParam(name, val); } } + // do not save anything when save to stdout + if (model_out == "stdout") { + this->SetParam("silent", "1"); + save_period = 0; + } // whether need data rank bool need_data_rank = strchr(train_path.c_str(), '%') != NULL; // if need data rank in loading, initialize rabit engine before load data @@ -41,7 +46,7 @@ class BoostLearnTask { if (!need_data_rank) rabit::Init(argc, argv); if (rabit::IsDistributed()) { std::string pname = rabit::GetProcessorName(); - printf("start %s:%d\n", pname.c_str(), rabit::GetRank()); + fprintf(stderr, "start %s:%d\n", pname.c_str(), rabit::GetRank()); } if (rabit::IsDistributed()) { this->SetParam("data_split", "col"); @@ -158,9 +163,7 @@ class BoostLearnTask { } inline void InitLearner(void) { if (model_in != "NULL") { - utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb")); - learner.LoadModel(fi); - fi.Close(); + learner.LoadModel(model_in.c_str()); } else { utils::Assert(task == "train", "model_in not specified"); learner.InitModel(); @@ -215,9 +218,7 @@ class BoostLearnTask { } inline void SaveModel(const char *fname) const { if (rabit::GetRank() != 0) return; - utils::FileStream fo(utils::FopenCheck(fname, "wb")); - learner.SaveModel(fo); - fo.Close(); + learner.SaveModel(fname); } inline void SaveModel(int i) const { char fname[256];