move stream to rabit part, support rabit on yarn

This commit is contained in:
tqchen
2015-03-09 14:43:46 -07:00
parent 9f7c6fe271
commit a8d5af39fd
14 changed files with 134 additions and 500 deletions

View File

@@ -1,205 +0,0 @@
#ifndef XGBOOST_UTILS_BASE64_H_
#define XGBOOST_UTILS_BASE64_H_
/*!
* \file base64.h
* \brief data stream support to input and output from/to base64 stream
* base64 is easier to store and pass as text format in mapreduce
* \author Tianqi Chen
*/
#include <cctype>
#include <cstdio>
#include "./utils.h"
#include "./io.h"
namespace xgboost {
namespace utils {
/*! \brief namespace of base64 decoding and encoding table */
namespace base64 {
const char DecodeTable[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
62, // '+'
0, 0, 0,
63, // '/'
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, // '0'-'9'
0, 0, 0, 0, 0, 0, 0,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, // 'A'-'Z'
0, 0, 0, 0, 0, 0,
26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, // 'a'-'z'
};
static const char EncodeTable[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
} // namespace base64
/*! \brief the stream that reads from base64, note we take from file pointers */
class Base64InStream: public IStream {
public:
explicit Base64InStream(FILE *fp) : fp(fp) {
num_prev = 0; tmp_ch = 0;
}
/*!
* \brief initialize the stream position to beginning of next base64 stream
* call this function before actually start read
*/
inline void InitPosition(void) {
// get a charater
do {
tmp_ch = fgetc(fp);
} while (isspace(tmp_ch));
}
/*! \brief whether current position is end of a base64 stream */
inline bool IsEOF(void) const {
return num_prev == 0 && (tmp_ch == EOF || isspace(tmp_ch));
}
virtual size_t Read(void *ptr, size_t size) {
using base64::DecodeTable;
if (size == 0) return 0;
// use tlen to record left size
size_t tlen = size;
unsigned char *cptr = static_cast<unsigned char*>(ptr);
// if anything left, load from previous buffered result
if (num_prev != 0) {
if (num_prev == 2) {
if (tlen >= 2) {
*cptr++ = buf_prev[0];
*cptr++ = buf_prev[1];
tlen -= 2;
num_prev = 0;
} else {
// assert tlen == 1
*cptr++ = buf_prev[0]; --tlen;
buf_prev[0] = buf_prev[1];
num_prev = 1;
}
} else {
// assert num_prev == 1
*cptr++ = buf_prev[0]; --tlen; num_prev = 0;
}
}
if (tlen == 0) return size;
int nvalue;
// note: everything goes with 4 bytes in Base64
// so we process 4 bytes a unit
while (tlen && tmp_ch != EOF && !isspace(tmp_ch)) {
// first byte
nvalue = DecodeTable[tmp_ch] << 18;
{
// second byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
nvalue |= DecodeTable[tmp_ch] << 12;
*cptr++ = (nvalue >> 16) & 0xFF; --tlen;
}
{
// third byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
// handle termination
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == '='), "invalid base64 format");
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch] << 6;
if (tlen) {
*cptr++ = (nvalue >> 8) & 0xFF; --tlen;
} else {
buf_prev[num_prev++] = (nvalue >> 8) & 0xFF;
}
}
{
// fourth byte
Check((tmp_ch = fgetc(fp), tmp_ch != EOF && !isspace(tmp_ch)),
"invalid base64 format");
if (tmp_ch == '=') {
Check((tmp_ch = fgetc(fp), tmp_ch == EOF || isspace(tmp_ch)),
"invalid base64 format");
break;
}
nvalue |= DecodeTable[tmp_ch];
if (tlen) {
*cptr++ = nvalue & 0xFF; --tlen;
} else {
buf_prev[num_prev ++] = nvalue & 0xFF;
}
}
// get next char
tmp_ch = fgetc(fp);
}
if (kStrictCheck) {
Check(tlen == 0, "Base64InStream: read incomplete");
}
return size - tlen;
}
virtual void Write(const void *ptr, size_t size) {
utils::Error("Base64InStream do not support write");
}
private:
FILE *fp;
int tmp_ch;
int num_prev;
unsigned char buf_prev[2];
// whether we need to do strict check
static const bool kStrictCheck = false;
};
/*! \brief the stream that write to base64, note we take from file pointers */
class Base64OutStream: public IStream {
public:
explicit Base64OutStream(FILE *fp) : fp(fp) {
buf_top = 0;
}
virtual void Write(const void *ptr, size_t size) {
using base64::EncodeTable;
size_t tlen = size;
const unsigned char *cptr = static_cast<const unsigned char*>(ptr);
while (tlen) {
while (buf_top < 3 && tlen != 0) {
buf[++buf_top] = *cptr++; --tlen;
}
if (buf_top == 3) {
// flush 4 bytes out
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[((buf[2] << 2) | (buf[3] >> 6)) & 0x3F], fp);
fputc(EncodeTable[buf[3] & 0x3F], fp);
buf_top = 0;
}
}
}
virtual size_t Read(void *ptr, size_t size) {
Error("Base64OutStream do not support read");
return 0;
}
/*!
* \brief finish writing of all current base64 stream, do some post processing
* \param endch charater to put to end of stream, if it is EOF, then nothing will be done
*/
inline void Finish(char endch = EOF) {
using base64::EncodeTable;
if (buf_top == 1) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[(buf[1] << 4) & 0x3F], fp);
fputc('=', fp);
fputc('=', fp);
}
if (buf_top == 2) {
fputc(EncodeTable[buf[1] >> 2], fp);
fputc(EncodeTable[((buf[1] << 4) | (buf[2] >> 4)) & 0x3F], fp);
fputc(EncodeTable[(buf[2] << 2) & 0x3F], fp);
fputc('=', fp);
}
buf_top = 0;
if (endch != EOF) fputc(endch, fp);
}
private:
FILE *fp;
int buf_top;
unsigned char buf[4];
};
} // namespace utils
} // namespace xgboost
#endif // XGBOOST_UTILS_BASE64_H_

View File

@@ -5,6 +5,7 @@
#include <string>
#include <cstring>
#include "./utils.h"
#include "../sync/sync.h"
/*!
* \file io.h
* \brief general stream interface for serialization, I/O
@@ -12,168 +13,13 @@
*/
namespace xgboost {
namespace utils {
/*!
* \brief interface of stream I/O, used to serialize model
*/
class IStream {
public:
/*!
* \brief read data from stream
* \param ptr pointer to memory buffer
* \param size size of block
* \return usually is the size of data readed
*/
virtual size_t Read(void *ptr, size_t size) = 0;
/*!
* \brief write data to stream
* \param ptr pointer to memory buffer
* \param size size of block
*/
virtual void Write(const void *ptr, size_t size) = 0;
/*! \brief virtual destructor */
virtual ~IStream(void) {}
public:
// helper functions to write various of data structures
/*!
* \brief binary serialize a vector
* \param vec vector to be serialized
*/
template<typename T>
inline void Write(const std::vector<T> &vec) {
uint64_t sz = static_cast<uint64_t>(vec.size());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&vec[0], sizeof(T) * sz);
}
}
/*!
* \brief binary load a vector
* \param out_vec vector to be loaded
* \return whether load is successfull
*/
template<typename T>
inline bool Read(std::vector<T> *out_vec) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_vec->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false;
}
return true;
}
/*!
* \brief binary serialize a string
* \param str the string to be serialized
*/
inline void Write(const std::string &str) {
uint64_t sz = static_cast<uint64_t>(str.length());
this->Write(&sz, sizeof(sz));
if (sz != 0) {
this->Write(&str[0], sizeof(char) * sz);
}
}
/*!
* \brief binary load a string
* \param out_str string to be loaded
* \return whether load is successful
*/
inline bool Read(std::string *out_str) {
uint64_t sz;
if (this->Read(&sz, sizeof(sz)) == 0) return false;
out_str->resize(sz);
if (sz != 0) {
if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false;
}
return true;
}
};
/*! \brief interface of i/o stream that support seek */
class ISeekStream: public IStream {
public:
/*! \brief seek to certain position of the file */
virtual void Seek(size_t pos) = 0;
/*! \brief tell the position of the stream */
virtual size_t Tell(void) = 0;
};
/*! \brief fixed size memory buffer */
struct MemoryFixSizeBuffer : public ISeekStream {
public:
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
curr_ptr_ = 0;
}
virtual ~MemoryFixSizeBuffer(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ + size <= buffer_size_,
"read can not have position excceed buffer length");
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, p_buffer_ + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
utils::Assert(curr_ptr_ + size <= buffer_size_,
"write position exceed fixed buffer size");
std::memcpy(p_buffer_ + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
char *p_buffer_;
/*! \brief current pointer */
size_t buffer_size_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryFixSizeBuffer
/*! \brief a in memory buffer that can be read and write as stream interface */
struct MemoryBufferStream : public ISeekStream {
public:
MemoryBufferStream(std::string *p_buffer)
: p_buffer_(p_buffer) {
curr_ptr_ = 0;
}
virtual ~MemoryBufferStream(void) {}
virtual size_t Read(void *ptr, size_t size) {
utils::Assert(curr_ptr_ <= p_buffer_->length(),
"read can not have position excceed buffer length");
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
if (nread != 0) std::memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
curr_ptr_ += nread;
return nread;
}
virtual void Write(const void *ptr, size_t size) {
if (size == 0) return;
if (curr_ptr_ + size > p_buffer_->length()) {
p_buffer_->resize(curr_ptr_+size);
}
std::memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
curr_ptr_ += size;
}
virtual void Seek(size_t pos) {
curr_ptr_ = static_cast<size_t>(pos);
}
virtual size_t Tell(void) {
return curr_ptr_;
}
private:
/*! \brief in memory buffer */
std::string *p_buffer_;
/*! \brief current pointer */
size_t curr_ptr_;
}; // class MemoryBufferStream
// reuse the definitions of streams
typedef rabit::IStream IStream;
typedef rabit::utils::ISeekStream ISeekStream;
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;
typedef rabit::io::Base64InStream Base64InStream;
typedef rabit::io::Base64OutStream Base64OutStream;
/*! \brief implementation of file i/o stream */
class FileStream : public ISeekStream {
@@ -194,6 +40,9 @@ class FileStream : public ISeekStream {
virtual size_t Tell(void) {
return std::ftell(fp);
}
virtual bool AtEnd(void) const {
return std::feof(fp);
}
inline void Close(void) {
if (fp != NULL){
std::fclose(fp); fp = NULL;