#ifndef XGBOOST_UTILS_IO_H #define XGBOOST_UTILS_IO_H #include #include #include #include "./utils.h" /*! * \file io.h * \brief general stream interface for serialization, I/O * \author Tianqi Chen */ namespace xgboost { namespace utils { /*! * \brief interface of stream I/O, used to serialize model */ class IStream { public: /*! * \brief read data from stream * \param ptr pointer to memory buffer * \param size size of block * \return usually is the size of data readed */ virtual size_t Read(void *ptr, size_t size) = 0; /*! * \brief write data to stream * \param ptr pointer to memory buffer * \param size size of block */ virtual void Write(const void *ptr, size_t size) = 0; /*! \brief virtual destructor */ virtual ~IStream(void) {} public: // helper functions to write various of data structures /*! * \brief binary serialize a vector * \param vec vector to be serialized */ template inline void Write(const std::vector &vec) { uint64_t sz = static_cast(vec.size()); this->Write(&sz, sizeof(sz)); if (sz != 0) { this->Write(&vec[0], sizeof(T) * sz); } } /*! * \brief binary load a vector * \param out_vec vector to be loaded * \return whether load is successfull */ template inline bool Read(std::vector *out_vec) { uint64_t sz; if (this->Read(&sz, sizeof(sz)) == 0) return false; out_vec->resize(sz); if (sz != 0) { if (this->Read(&(*out_vec)[0], sizeof(T) * sz) == 0) return false; } return true; } /*! * \brief binary serialize a string * \param str the string to be serialized */ inline void Write(const std::string &str) { uint64_t sz = static_cast(str.length()); this->Write(&sz, sizeof(sz)); if (sz != 0) { this->Write(&str[0], sizeof(char) * sz); } } /*! * \brief binary load a string * \param out_str string to be loaded * \return whether load is successful */ inline bool Read(std::string *out_str) { uint64_t sz; if (this->Read(&sz, sizeof(sz)) == 0) return false; out_str->resize(sz); if (sz != 0) { if (this->Read(&(*out_str)[0], sizeof(char) * sz) == 0) return false; } return true; } }; /*! \brief interface of i/o stream that support seek */ class ISeekStream: public IStream { public: /*! \brief seek to certain position of the file */ virtual void Seek(size_t pos) = 0; /*! \brief tell the position of the stream */ virtual size_t Tell(void) = 0; }; /*! \brief fixed size memory buffer */ struct MemoryFixSizeBuffer : public ISeekStream { public: MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size) : p_buffer_(reinterpret_cast(p_buffer)), buffer_size_(buffer_size) { curr_ptr_ = 0; } virtual ~MemoryFixSizeBuffer(void) {} virtual size_t Read(void *ptr, size_t size) { utils::Assert(curr_ptr_ <= buffer_size_, "read can not have position excceed buffer length"); size_t nread = std::min(buffer_size_ - curr_ptr_, size); if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread); curr_ptr_ += nread; return nread; } virtual void Write(const void *ptr, size_t size) { if (size == 0) return; utils::Assert(curr_ptr_ + size <= buffer_size_, "write position exceed fixed buffer size"); memcpy(p_buffer_ + curr_ptr_, ptr, size); curr_ptr_ += size; } virtual void Seek(size_t pos) { curr_ptr_ = static_cast(pos); } virtual size_t Tell(void) { return curr_ptr_; } private: /*! \brief in memory buffer */ char *p_buffer_; /*! \brief current pointer */ size_t buffer_size_; /*! \brief current pointer */ size_t curr_ptr_; }; // class MemoryFixSizeBuffer /*! \brief a in memory buffer that can be read and write as stream interface */ struct MemoryBufferStream : public ISeekStream { public: MemoryBufferStream(std::string *p_buffer) : p_buffer_(p_buffer) { curr_ptr_ = 0; } virtual ~MemoryBufferStream(void) {} virtual size_t Read(void *ptr, size_t size) { utils::Assert(curr_ptr_ <= p_buffer_->length(), "read can not have position excceed buffer length"); size_t nread = std::min(p_buffer_->length() - curr_ptr_, size); if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread); curr_ptr_ += nread; return nread; } virtual void Write(const void *ptr, size_t size) { if (size == 0) return; if (curr_ptr_ + size > p_buffer_->length()) { p_buffer_->resize(curr_ptr_+size); } memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size); curr_ptr_ += size; } virtual void Seek(size_t pos) { curr_ptr_ = static_cast(pos); } virtual size_t Tell(void) { return curr_ptr_; } private: /*! \brief in memory buffer */ std::string *p_buffer_; /*! \brief current pointer */ size_t curr_ptr_; }; // class MemoryBufferStream /*! \brief implementation of file i/o stream */ class FileStream : public ISeekStream { public: explicit FileStream(FILE *fp) : fp(fp) {} explicit FileStream(void) { this->fp = NULL; } virtual size_t Read(void *ptr, size_t size) { return std::fread(ptr, size, 1, fp); } virtual void Write(const void *ptr, size_t size) { std::fwrite(ptr, size, 1, fp); } virtual void Seek(size_t pos) { std::fseek(fp, pos, SEEK_SET); } virtual size_t Tell(void) { return std::ftell(fp); } inline void Close(void) { if (fp != NULL){ std::fclose(fp); fp = NULL; } } private: FILE *fp; }; } // namespace utils } // namespace xgboost #endif