diff --git a/src/io/io.cpp b/src/io/io.cpp index f56cff679..e8b9ce337 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -22,7 +22,13 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { dmat->LoadBinary(fs, silent, fname); fs.Close(); return dmat; - } + } + if (magic == DMatrixPage::kMagic) { + DMatrixPage *dmat = new DMatrixPage(); + dmat->Load(fs, silent, fname); + // the file pointer is held by the page matrix, so fs is not closed here + return dmat; + } fs.Close(); DMatrixSimple *dmat = new DMatrixSimple(); @@ -31,6 +37,11 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { } void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { + if (strlen(fname) >= 5 && !strcmp(fname + strlen(fname) - 5, ".page")) { + /* length is checked first so the suffix pointer never starts before fname */ + DMatrixPage::Save(fname, dmat, silent); + return; + } if (dmat.magic == DMatrixSimple::kMagic) { const DMatrixSimple *p_dmat = static_cast(&dmat); p_dmat->SaveBinary(fname, silent); diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 82a373352..4d13a0bc5 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -8,6 +8,8 @@ #include "../data.h" #include "../utils/iterator.h" #include "../utils/thread_buffer.h" +#include "./simple_fmatrix-inl.hpp" + namespace xgboost { namespace io { /*! \brief page structure that can be used to store a rowbatch */ @@ -102,7 +104,7 @@ class ThreadRowPageIterator: public utils::IIterator { base_rowid_ = 0; isend_ = false; } - virtual ~ThreadRowPageIterator(void) { + virtual ~ThreadRowPageIterator(void) { } virtual void Init(void) { } @@ -188,7 +190,9 @@ class ThreadRowPageIterator: public utils::IIterator { inline void FreeSpace(PagePtr &a) { delete a; } - inline void Destroy(void) {} + inline void Destroy(void) { + fi.Close(); + } inline void BeforeFirst(void) { fi.Seek(file_begin_); } @@ -199,6 +203,63 @@ class ThreadRowPageIterator: public utils::IIterator { int ptop_; utils::ThreadBuffer itr; }; + +/*! 
\brief data matrix using page */ +class DMatrixPage : public DataMatrix { + public: + DMatrixPage(void) : DataMatrix(kMagic) { + iter_ = new ThreadRowPageIterator(); + fmat_ = new FMatrixS(iter_); + } + // virtual destructor; frees fmat_ only (iter_ was passed to FMatrixS at construction, presumably owned there - TODO confirm) + virtual ~DMatrixPage(void) { + delete fmat_; + } + virtual IFMatrix *fmat(void) const { + return fmat_; + } + /*! \brief load from fi: checks the leading magic number, reads info, then passes fi to the row iterator; prints a summary unless silent (fname is only used in that message) */ + inline void Load(utils::FileStream &fi, + bool silent = false, + const char *fname = NULL){ + int magic; + utils::Check(fi.Read(&magic, sizeof(magic)) != 0, "invalid input file format"); + utils::Check(magic == kMagic, "invalid format,magic number mismatch"); + this->info.LoadBinary(fi); + iter_->Load(fi); + if (!silent) { + printf("DMatrixPage: %lux%lu matrix is loaded", + static_cast<unsigned long>(info.num_row()), static_cast<unsigned long>(info.num_col())); /* casts make the arguments match %lu exactly */ + if (fname != NULL) { + printf(" from %s\n", fname); + } else { + printf("\n"); + } + if (info.group_ptr.size() != 0) { + printf("data contains %u groups\n", static_cast<unsigned>(info.group_ptr.size() - 1)); + } + } + } + /*! \brief save a DataMatrix as DMatrixPage: writes kMagic, mat.info, then the rows from mat's row iterator to fname; prints a summary unless silent */ + inline static void Save(const char* fname, const DataMatrix &mat, bool silent) { + utils::FileStream fs(utils::FopenCheck(fname, "wb")); + int magic = kMagic; + fs.Write(&magic, sizeof(magic)); + mat.info.SaveBinary(fs); + ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fs); + fs.Close(); + if (!silent) { + printf("DMatrixPage: %lux%lu is saved to %s\n", + static_cast<unsigned long>(mat.info.num_row()), static_cast<unsigned long>(mat.info.num_col()), fname); /* casts make the arguments match %lu exactly */ + } + } + /*! \brief the real fmatrix */ + FMatrixS *fmat_; + /*! \brief row iterator */ + ThreadRowPageIterator *iter_; + /*! \brief magic number used to identify DMatrix */ + static const int kMagic = 0xffffab02; +}; } // namespace io } // namespace xgboost #endif // XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ diff --git a/src/utils/io.h b/src/utils/io.h index 23fa0d468..dbfcee3f6 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -100,7 +100,9 @@ class ISeekStream: public IStream { /*! 
\brief implementation of file i/o stream */ class FileStream : public ISeekStream { public: - explicit FileStream(void) {} + explicit FileStream(void) { + this->fp = NULL; /* start with no file attached; Close() tests for NULL */ + } explicit FileStream(FILE *fp) { this->fp = fp; } @@ -117,7 +119,9 @@ class FileStream : public ISeekStream { return static_cast(ftell(fp)); } inline void Close(void) { - fclose(fp); + if (fp != NULL){ /* nothing to do when no file is attached or it was already closed */ + fclose(fp); fp = NULL; /* forget the handle so a repeated Close() does nothing */ + } } private: