208 lines
5.6 KiB
C++
208 lines
5.6 KiB
C++
#ifndef XGBOOST_UTILS_IO_H
|
|
#define XGBOOST_UTILS_IO_H
|
|
#include <cstdio>
|
|
#include <vector>
|
|
#include <string>
|
|
#include "./utils.h"
|
|
/*!
|
|
* \file io.h
|
|
* \brief general stream interface for serialization, I/O
|
|
* \author Tianqi Chen
|
|
*/
|
|
namespace xgboost {
|
|
namespace utils {
|
|
/*!
 * \brief interface of stream I/O, used to serialize model
 *
 * Implementations provide the two raw byte-level primitives; the helper
 * overloads declared below build a length-prefixed binary format for
 * std::vector and std::string on top of them.
 */
class IStream {
 public:
  /*!
   * \brief read data from stream
   * \param ptr pointer to memory buffer
   * \param size size of block
   * \return usually the size of data read; the helper overloads in this
   *         class only rely on a return value of 0 meaning failure
   */
  virtual size_t Read(void *ptr, size_t size) = 0;
  /*!
   * \brief write data to stream
   * \param ptr pointer to memory buffer
   * \param size size of block
   */
  virtual void Write(const void *ptr, size_t size) = 0;
  /*! \brief virtual destructor */
  virtual ~IStream(void) {}

 public:
  // helper functions to read and write various data structures
  /*!
   * \brief binary serialize a vector
   *
   * Writes the element count as a uint64_t, followed by the raw bytes of
   * the elements when the vector is not empty.
   * \param vec vector to be serialized
   */
  template<typename T>
  inline void Write(const std::vector<T> &vec) {
    uint64_t sz = static_cast<uint64_t>(vec.size());
    this->Write(&sz, sizeof(sz));
    if (sz != 0) {
      this->Write(&vec[0], sizeof(T) * vec.size());
    }
  }
  /*!
   * \brief binary load a vector
   *
   * Reads a uint64_t element count, resizes the vector to that count and
   * then reads the raw element bytes into it.
   * \param out_vec vector to be loaded
   * \return whether load is successful; false when a raw Read returns 0 or
   *         when the stored element count is larger than the vector can hold
   */
  template<typename T>
  inline bool Read(std::vector<T> *out_vec) {
    // zero-initialized so a short read never leaves indeterminate bytes
    uint64_t sz = 0;
    if (this->Read(&sz, sizeof(sz)) == 0) return false;
    // a corrupt length prefix must be reported as a failed load: resize()
    // would throw std::length_error, and where size_t is narrower than
    // 64 bits the count would be silently truncated
    if (sz > static_cast<uint64_t>(out_vec->max_size())) return false;
    const size_t nelem = static_cast<size_t>(sz);
    out_vec->resize(nelem);
    if (nelem != 0) {
      if (this->Read(&(*out_vec)[0], sizeof(T) * nelem) == 0) return false;
    }
    return true;
  }
  /*!
   * \brief binary serialize a string
   *
   * Writes the length as a uint64_t, followed by the characters when the
   * string is not empty.
   * \param str the string to be serialized
   */
  inline void Write(const std::string &str) {
    uint64_t sz = static_cast<uint64_t>(str.length());
    this->Write(&sz, sizeof(sz));
    if (sz != 0) {
      this->Write(str.data(), sizeof(char) * str.length());
    }
  }
  /*!
   * \brief binary load a string
   *
   * Reads a uint64_t length, resizes the string to that length and then
   * reads the characters into it.
   * \param out_str string to be loaded
   * \return whether load is successful; false when a raw Read returns 0 or
   *         when the stored length is larger than the string can hold
   */
  inline bool Read(std::string *out_str) {
    // zero-initialized so a short read never leaves indeterminate bytes
    uint64_t sz = 0;
    if (this->Read(&sz, sizeof(sz)) == 0) return false;
    // same guard as the vector overload: fail instead of throwing from
    // resize() or truncating the length
    if (sz > static_cast<uint64_t>(out_str->max_size())) return false;
    const size_t nchar = static_cast<size_t>(sz);
    out_str->resize(nchar);
    if (nchar != 0) {
      if (this->Read(&(*out_str)[0], sizeof(char) * nchar) == 0) return false;
    }
    return true;
  }
};
|
|
|
|
/*! \brief interface of i/o stream that support seek */
|
|
class ISeekStream: public IStream {
|
|
public:
|
|
/*! \brief seek to certain position of the file */
|
|
virtual void Seek(size_t pos) = 0;
|
|
/*! \brief tell the position of the stream */
|
|
virtual size_t Tell(void) = 0;
|
|
};
|
|
|
|
/*! \brief fixed size memory buffer */
|
|
struct MemoryFixSizeBuffer : public ISeekStream {
|
|
public:
|
|
MemoryFixSizeBuffer(void *p_buffer, size_t buffer_size)
|
|
: p_buffer_(reinterpret_cast<char*>(p_buffer)), buffer_size_(buffer_size) {
|
|
curr_ptr_ = 0;
|
|
}
|
|
virtual ~MemoryFixSizeBuffer(void) {}
|
|
virtual size_t Read(void *ptr, size_t size) {
|
|
utils::Assert(curr_ptr_ <= buffer_size_,
|
|
"read can not have position excceed buffer length");
|
|
size_t nread = std::min(buffer_size_ - curr_ptr_, size);
|
|
if (nread != 0) memcpy(ptr, p_buffer_ + curr_ptr_, nread);
|
|
curr_ptr_ += nread;
|
|
return nread;
|
|
}
|
|
virtual void Write(const void *ptr, size_t size) {
|
|
if (size == 0) return;
|
|
utils::Assert(curr_ptr_ + size <= buffer_size_,
|
|
"write position exceed fixed buffer size");
|
|
memcpy(p_buffer_ + curr_ptr_, ptr, size);
|
|
curr_ptr_ += size;
|
|
}
|
|
virtual void Seek(size_t pos) {
|
|
curr_ptr_ = static_cast<size_t>(pos);
|
|
}
|
|
virtual size_t Tell(void) {
|
|
return curr_ptr_;
|
|
}
|
|
|
|
private:
|
|
/*! \brief in memory buffer */
|
|
char *p_buffer_;
|
|
/*! \brief current pointer */
|
|
size_t buffer_size_;
|
|
/*! \brief current pointer */
|
|
size_t curr_ptr_;
|
|
}; // class MemoryFixSizeBuffer
|
|
|
|
/*! \brief a in memory buffer that can be read and write as stream interface */
|
|
struct MemoryBufferStream : public ISeekStream {
|
|
public:
|
|
MemoryBufferStream(std::string *p_buffer)
|
|
: p_buffer_(p_buffer) {
|
|
curr_ptr_ = 0;
|
|
}
|
|
virtual ~MemoryBufferStream(void) {}
|
|
virtual size_t Read(void *ptr, size_t size) {
|
|
utils::Assert(curr_ptr_ <= p_buffer_->length(),
|
|
"read can not have position excceed buffer length");
|
|
size_t nread = std::min(p_buffer_->length() - curr_ptr_, size);
|
|
if (nread != 0) memcpy(ptr, &(*p_buffer_)[0] + curr_ptr_, nread);
|
|
curr_ptr_ += nread;
|
|
return nread;
|
|
}
|
|
virtual void Write(const void *ptr, size_t size) {
|
|
if (size == 0) return;
|
|
if (curr_ptr_ + size > p_buffer_->length()) {
|
|
p_buffer_->resize(curr_ptr_+size);
|
|
}
|
|
memcpy(&(*p_buffer_)[0] + curr_ptr_, ptr, size);
|
|
curr_ptr_ += size;
|
|
}
|
|
virtual void Seek(size_t pos) {
|
|
curr_ptr_ = static_cast<size_t>(pos);
|
|
}
|
|
virtual size_t Tell(void) {
|
|
return curr_ptr_;
|
|
}
|
|
|
|
private:
|
|
/*! \brief in memory buffer */
|
|
std::string *p_buffer_;
|
|
/*! \brief current pointer */
|
|
size_t curr_ptr_;
|
|
}; // class MemoryBufferStream
|
|
|
|
/*! \brief implementation of file i/o stream */
|
|
class FileStream : public ISeekStream {
|
|
public:
|
|
explicit FileStream(FILE *fp) : fp(fp) {}
|
|
explicit FileStream(void) {
|
|
this->fp = NULL;
|
|
}
|
|
virtual size_t Read(void *ptr, size_t size) {
|
|
return std::fread(ptr, size, 1, fp);
|
|
}
|
|
virtual void Write(const void *ptr, size_t size) {
|
|
std::fwrite(ptr, size, 1, fp);
|
|
}
|
|
virtual void Seek(size_t pos) {
|
|
std::fseek(fp, pos, SEEK_SET);
|
|
}
|
|
virtual size_t Tell(void) {
|
|
return std::ftell(fp);
|
|
}
|
|
inline void Close(void) {
|
|
if (fp != NULL){
|
|
std::fclose(fp); fp = NULL;
|
|
}
|
|
}
|
|
|
|
private:
|
|
FILE *fp;
|
|
};
|
|
} // namespace utils
|
|
} // namespace xgboost
|
|
#endif
|