add base64 model format

This commit is contained in:
tqchen 2014-12-24 02:33:50 -08:00
parent c8396ca24e
commit 6d7ef172ef
4 changed files with 261 additions and 17 deletions

View File

@ -12,6 +12,8 @@
#include <limits>
// rabit library for synchronization
#include <rabit.h>
#include "../utils/io.h"
#include "../utils/base64.h"
#include "./objective.h"
#include "./evaluation.h"
#include "../gbm/gbm.h"
@ -36,6 +38,7 @@ class BoostLearner : public rabit::ISerializable {
pred_buffer_size = 0;
seed_per_iteration = 0;
seed = 0;
save_base64 = 0;
}
virtual ~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
@ -68,9 +71,6 @@ class BoostLearner : public rabit::ISerializable {
utils::SPrintf(str_temp, sizeof(str_temp), "%lu",
static_cast<unsigned long>(buffer_size));
this->SetParam("num_pbuffer", str_temp);
if (!silent) {
utils::Printf("buffer_size=%ld\n", static_cast<long>(buffer_size));
}
this->pred_buffer_size = buffer_size;
}
/*!
@ -108,6 +108,7 @@ class BoostLearner : public rabit::ISerializable {
this->seed = seed; random::Seed(atoi(val));
}
if (!strcmp("seed_per_iter", name)) seed_per_iteration = atoi(val);
if (!strcmp("save_base64", name)) save_base64 = atoi(val);
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
if (!strcmp(name, "nthread")) {
omp_set_num_threads(atoi(val));
@ -191,9 +192,29 @@ class BoostLearner : public rabit::ISerializable {
* \param fname file name
*/
inline void LoadModel(const char *fname) {
utils::FileStream fi(utils::FopenCheck(fname, "rb"));
FILE *fp = utils::FopenCheck(fname, "rb");
std::string header; header.resize(4);
utils::FileStream fi(fp);
// check header for different binary encode
// can be base64 or binary
if (fi.Read(&header[0], 4) != 0) {
// base64 format
if (header == "bs64") {
utils::Base64InStream bsin(fp);
bsin.InitPosition();
this->LoadModel(bsin);
fclose(fp);
return;
}
if (header == "binf") {
this->LoadModel(fi);
fclose(fp);
return;
}
}
fi.Seek(0);
this->LoadModel(fi);
fi.Close();
fclose(fp);
}
inline void SaveModel(utils::IStream &fo, bool with_pbuffer = true) const {
fo.Write(&mparam, sizeof(ModelParam));
@ -206,9 +227,24 @@ class BoostLearner : public rabit::ISerializable {
* \param fname file name
*/
inline void SaveModel(const char *fname) const {
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
this->SaveModel(fo);
fo.Close();
FILE *fp;
if (!strcmp(fname, "stdout")) {
fp = stdout;
} else {
fp = utils::FopenCheck(fname, "wb");
}
utils::FileStream fo(fp);
std::string header;
if (save_base64 != 0|| fp == stdout) {
fo.Write("bs64\t", 5);
utils::Base64OutStream bout(fp);
this->SaveModel(bout);
bout.Finish('\n');
} else {
fo.Write("binf", 4);
this->SaveModel(fo);
}
if (fp != stdout) fclose(fp);
}
/*!
* \brief check if data matrix is ready to be used by training,
@ -383,6 +419,8 @@ class BoostLearner : public rabit::ISerializable {
// this is important for restart from existing iterations
// default set to no, but will auto switch on in distributed mode
int seed_per_iteration;
// save model in base64 encoding
int save_base64;
// silent during training
int silent;
// distributed learning mode, if any, 0:none, 1:col, 2:row

205
src/utils/base64.h Normal file
View File

@ -0,0 +1,205 @@
#ifndef XGBOOST_UTILS_BASE64_H_
#define XGBOOST_UTILS_BASE64_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen
*/
#include <cctype>
#include <cstdio>
#include "./utils.h"
#include "./io.h"
namespace xgboost {
namespace utils {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
num_prev = 0; tmp_ch = 0;
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* call this function before actually start read
*/
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
inline bool IsEOF(void) const {
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
}
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev != 0) {
if (num_prev == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev = 1;
}
} else {
// assert num_prev == 1
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
// first byte
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev ++] = nvalue & 0xFF;
}
}
// get next char
tmp_ch = fgetc(fp);
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write");
}
private:
FILE *fp;
unsigned char tmp_ch;
int num_prev;
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(FILE *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf_top < 3 && tlen != 0) {
buf[++buf_top] = *cptr++; --tlen;
}
if (buf_top == 3) {
// flush 4 bytes out
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
fputc(EncodeTable[buf[3] & 0x3F], fp);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
Error("Base64OutStream do not support read");
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
*/
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
fputc('=', fp);
fputc('=', fp);
}
if (buf_top == 2) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
fputc('=', fp);
}
buf_top = 0;
if (endch != EOF) fputc(endch, fp);
}
private:
FILE *fp;
int buf_top;
unsigned char buf[4];
};
} // namespace utils
} // namespace xgboost
#endif // XGBOOST_UTILS_BASE64_H_

View File

@ -1,5 +1,5 @@
#ifndef XGBOOST_UTILS_THREAD_BUFFER_H
#define XGBOOST_UTILS_THREAD_BUFFER_H
#ifndef XGBOOST_UTILS_THREAD_BUFFER_H_
#define XGBOOST_UTILS_THREAD_BUFFER_H_
/*!
* \file thread_buffer.h
* \brief multi-thread buffer, iterator, can be used to create parallel pipeline

View File

@ -31,6 +31,11 @@ class BoostLearnTask {
this->SetParam(name, val);
}
}
// do not save anything when save to stdout
if (model_out == "stdout") {
this->SetParam("silent", "1");
save_period = 0;
}
// whether need data rank
bool need_data_rank = strchr(train_path.c_str(), '%') != NULL;
// if need data rank in loading, initialize rabit engine before load data
@ -41,7 +46,7 @@ class BoostLearnTask {
if (!need_data_rank) rabit::Init(argc, argv);
if (rabit::IsDistributed()) {
std::string pname = rabit::GetProcessorName();
printf("start %s:%d\n", pname.c_str(), rabit::GetRank());
fprintf(stderr, "start %s:%d\n", pname.c_str(), rabit::GetRank());
}
if (rabit::IsDistributed()) {
this->SetParam("data_split", "col");
@ -158,9 +163,7 @@ class BoostLearnTask {
}
inline void InitLearner(void) {
if (model_in != "NULL") {
utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb"));
learner.LoadModel(fi);
fi.Close();
learner.LoadModel(model_in.c_str());
} else {
utils::Assert(task == "train", "model_in not specified");
learner.InitModel();
@ -215,9 +218,7 @@ class BoostLearnTask {
}
inline void SaveModel(const char *fname) const {
if (rabit::GetRank() != 0) return;
utils::FileStream fo(utils::FopenCheck(fname, "wb"));
learner.SaveModel(fo);
fo.Close();
learner.SaveModel(fname);
}
inline void SaveModel(int i) const {
char fname[256];