#ifndef XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ #define XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_ /*! * \file page_row_iter-inl.hpp * row iterator based on sparse page * \author Tianqi Chen */ #include #include "../data.h" #include "../utils/iterator.h" #include "../utils/thread_buffer.h" #include "./simple_fmatrix-inl.hpp" namespace xgboost { namespace io { /*! \brief page structure that can be used to store a rowbatch */ struct RowBatchPage { public: explicit RowBatchPage(size_t page_size) : kPageSize(page_size) { data_ = new int[kPageSize]; utils::Assert(data_ != NULL, "fail to allocate row batch page"); this->Clear(); } ~RowBatchPage(void) { if (data_ != NULL) delete [] data_; } /*! * \brief Push one row into page * \param row an instance row * \return false or true to push into */ inline bool PushRow(const RowBatch::Inst &row) { const size_t dsize = row.length * sizeof(RowBatch::Entry); if (FreeBytes() < dsize+ sizeof(int)) return false; row_ptr(Size() + 1) = row_ptr(Size()) + row.length; memcpy(data_ptr(row_ptr(Size())) , row.data, dsize); ++data_[0]; return true; } /*! * \brief get a row batch representation from the page * \param p_rptr a temporal space that can be used to provide * ind_ptr storage for RowBatch * \return a new RowBatch object */ inline RowBatch GetRowBatch(std::vector *p_rptr, size_t base_rowid) { RowBatch batch; batch.base_rowid = base_rowid; batch.data_ptr = this->data_ptr(0); batch.size = static_cast(this->Size()); std::vector &rptr = *p_rptr; rptr.resize(this->Size() + 1); for (size_t i = 0; i < rptr.size(); ++i) { rptr[i] = static_cast(this->row_ptr(static_cast(i))); } batch.ind_ptr = &rptr[0]; return batch; } /*! \brief get i-th row from the batch */ inline RowBatch::Inst operator[](int i) { return RowBatch::Inst(data_ptr(0) + row_ptr(i), static_cast(row_ptr(i+1) - row_ptr(i))); } /*! * \brief clear the page, cleanup the content */ inline void Clear(void) { memset(&data_[0], 0, sizeof(int) * kPageSize); } /*! * \brief load one page form instream * \return true if loading is successful */ inline bool Load(utils::IStream &fi) { return fi.Read(&data_[0], sizeof(int) * kPageSize) != 0; } /*! \brief save one page into outstream */ inline void Save(utils::IStream &fo) { fo.Write(&data_[0], sizeof(int) * kPageSize); } /*! \return number of elements */ inline int Size(void) const { return data_[0]; } protected: /*! \return number of elements */ inline size_t FreeBytes(void) { return (kPageSize - (Size() + 2)) * sizeof(int) - row_ptr(Size()) * sizeof(RowBatch::Entry); } /*! \brief equivalent row pointer at i */ inline int& row_ptr(int i) { return data_[kPageSize - i - 1]; } inline RowBatch::Entry* data_ptr(int i) { return (RowBatch::Entry*)(&data_[1]) + i; } // content of data int *data_; // page size const size_t kPageSize; }; /*! \brief thread buffer iterator */ class ThreadRowPageIterator: public utils::IIterator { public: ThreadRowPageIterator(void) { itr.SetParam("buffer_size", "2"); page_ = NULL; base_rowid_ = 0; } virtual ~ThreadRowPageIterator(void) {} virtual void Init(void) { } virtual void BeforeFirst(void) { itr.BeforeFirst(); base_rowid_ = 0; } virtual bool Next(void) { if (!itr.Next(page_)) return false; out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_); base_rowid_ += out_.size; return true; } virtual const RowBatch &Value(void) const { return out_; } /*! \brief load and initialize the iterator with fi */ inline void Load(const utils::FileStream &fi) { itr.get_factory().SetFile(fi); itr.Init(); this->BeforeFirst(); } /*! * \brief save a row iterator to output stream, in row iterator format */ inline static void Save(utils::IIterator *iter, utils::IStream &fo) { RowBatchPage page(kPageSize); iter->BeforeFirst(); while (iter->Next()) { const RowBatch &batch = iter->Value(); for (size_t i = 0; i < batch.size; ++i) { if (!page.PushRow(batch[i])) { page.Save(fo); page.Clear(); utils::Check(page.PushRow(batch[i]), "row is too big"); } } } if (page.Size() != 0) page.Save(fo); } /*! \brief page size 64 MB */ static const size_t kPageSize = 64 << 18; private: // base row id size_t base_rowid_; // temporal ptr std::vector tmp_ptr_; // output data RowBatch out_; // page pointer type typedef RowBatchPage* PagePtr; // loader factory for page struct Factory { public: long file_begin_; utils::FileStream fi; Factory(void) {} inline void SetFile(const utils::FileStream &fi) { this->fi = fi; file_begin_ = this->fi.Tell(); } inline bool Init(void) { return true; } inline void SetParam(const char *name, const char *val) {} inline bool LoadNext(PagePtr &val) { return val->Load(fi); } inline PagePtr Create(void) { PagePtr a = new RowBatchPage(kPageSize); return a; } inline void FreeSpace(PagePtr &a) { delete a; } inline void Destroy(void) { fi.Close(); } inline void BeforeFirst(void) { fi.Seek(file_begin_); } }; protected: PagePtr page_; utils::ThreadBuffer itr; }; /*! \brief data matrix using page */ template class DMatrixPageBase : public DataMatrix { public: DMatrixPageBase(void) : DataMatrix(kMagic) { iter_ = new ThreadRowPageIterator(); } // virtual destructor virtual ~DMatrixPageBase(void) { // do not delete row iterator, since it is owned by fmat // to be cleaned up in a more clear way } /*! \brief load and initialize the iterator with fi */ inline void Load(utils::FileStream &fi, bool silent = false, const char *fname = NULL, bool skip_magic_check = false) { int tmagic; utils::Check(fi.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); if (!skip_magic_check) { utils::Check(tmagic == magic, "invalid format,magic number mismatch"); } this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { utils::Printf("DMatrixPage: %lux%lu matrix is loaded", static_cast(info.num_row()), static_cast(info.num_col())); if (fname != NULL) { utils::Printf(" from %s\n", fname); } else { utils::Printf("\n"); } if (info.group_ptr.size() != 0) { utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size() - 1); } } } /*! \brief save a DataMatrix as DMatrixPage*/ inline static void Save(const char* fname, const DataMatrix &mat, bool silent) { utils::FileStream fs(utils::FopenCheck(fname, "wb")); int magic = kMagic; fs.Write(&magic, sizeof(magic)); mat.info.SaveBinary(fs); ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fs); fs.Close(); if (!silent) { utils::Printf("DMatrixPage: %lux%lu is saved to %s\n", static_cast(mat.info.num_row()), static_cast(mat.info.num_col()), fname); } } /*! \brief magic number used to identify DMatrix */ static const int kMagic = TKMagic; protected: /*! \brief row iterator */ ThreadRowPageIterator *iter_; }; class DMatrixPage : public DMatrixPageBase<0xffffab02> { public: DMatrixPage(void) { fmat_ = new FMatrixS(iter_); } virtual ~DMatrixPage(void) { delete fmat_; } virtual IFMatrix *fmat(void) const { return fmat_; } /*! \brief the real fmatrix */ IFMatrix *fmat_; }; } // namespace io } // namespace xgboost #endif // XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_