From 4b9aeea89c4c6fba24a8e0d487df65babb67392f Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Sep 2014 13:14:54 -0700 Subject: [PATCH] finish the fmatrix --- src/io/page_dmatrix-inl.hpp | 18 +++--- src/io/page_fmatrix-inl.hpp | 108 +++++++++++++++++++++++------------- src/utils/io.h | 1 - src/utils/matrix_csr.h | 34 ++++++++++-- 4 files changed, 106 insertions(+), 55 deletions(-) diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 76767d942..83c745599 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -223,16 +223,16 @@ class DMatrixPage : public DataMatrix { this->info.LoadBinary(fi); iter_->Load(fi); if (!silent) { - printf("DMatrixPage: %lux%lu matrix is loaded", - static_cast(info.num_row()), - static_cast(info.num_col())); + utils::Printf("DMatrixPage: %lux%lu matrix is loaded", + static_cast(info.num_row()), + static_cast(info.num_col())); if (fname != NULL) { - printf(" from %s\n", fname); + utils::Printf(" from %s\n", fname); } else { - printf("\n"); + utils::Printf("\n"); } if (info.group_ptr.size() != 0) { - printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); + utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); } } } @@ -245,9 +245,9 @@ class DMatrixPage : public DataMatrix { ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fs); fs.Close(); if (!silent) { - printf("DMatrixPage: %lux%lu is saved to %s\n", - static_cast(mat.info.num_row()), - static_cast(mat.info.num_col()), fname); + utils::Printf("DMatrixPage: %lux%lu is saved to %s\n", + static_cast(mat.info.num_row()), + static_cast(mat.info.num_col()), fname); } } /*! \brief the real fmatrix */ diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index b2ce76faf..7e9903be4 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -7,10 +7,11 @@ */ #include "../data.h" #include "../utils/iterator.h" +#include "../utils/io.h" +#include "../utils/matrix_csr.h" #include "../utils/thread_buffer.h" namespace xgboost { namespace io { - class CSCMatrixManager { public: /*! \brief in memory page */ @@ -56,6 +57,10 @@ class CSCMatrixManager { }; /*! \brief define type of page pointer */ typedef Page *PagePtr; + // constructor + CSCMatrixManager(void) { + fi_ = NULL; + } /*! \brief get column pointer */ inline const std::vector &col_ptr(void) const { return col_ptr_; @@ -89,7 +94,8 @@ class CSCMatrixManager { } inline void Setup(utils::ISeekStream *fi, double page_ratio) { fi_ = fi; - fi_->Read(&begin_meta_ , sizeof(size_t)); + fi_->Read(&begin_meta_ , sizeof(begin_meta_)); + begin_data_ = static_cast(fi->Tell()); fi_->Seek(begin_meta_); fi_->Read(&col_ptr_); size_t psmax = 0; @@ -121,7 +127,7 @@ class CSCMatrixManager { size_t len = col_ptr_[cidx+1] - col_ptr_[cidx]; if (p_page->NumFreeEntry() < len) return false; ColBatch::Entry *p_data = p_page->AllocEntry(len); - fi_->Seek(col_ptr_[cidx] * sizeof(ColBatch::Entry) + sizeof(size_t)); + fi_->Seek(col_ptr_[cidx] * sizeof(ColBatch::Entry) + begin_data_); utils::Check(fi_->Read(p_data, sizeof(ColBatch::Entry) * len) != 0, "invalid column buffer format"); p_page->col_data.push_back(ColBatch::Inst(p_data, len)); @@ -137,6 +143,8 @@ class CSCMatrixManager { /*! \brief column index to be after calling before first */ std::vector col_todo_; // the following are input content + /*! \brief beginning position of data content */ + size_t begin_data_; /*! \brief size of data content */ size_t begin_meta_; /*! \brief input stream */ @@ -147,36 +155,25 @@ class CSCMatrixManager { class ThreadColPageIterator : public utils::IIterator { public: - ThreadColPageIterator(void) { + explicit ThreadColPageIterator(utils::ISeekStream *fi, + float page_ratio, bool silent) { itr_.SetParam("buffer_size", "2"); - page_ = NULL; - fi_ = NULL; - silent = 0; + itr_.get_factory().Setup(fi, page_ratio); + if (!silent) { + utils::Printf("ThreadColPageIterator: finish initialzing, %u columns\n", + static_cast(col_ptr().size() - 1)); + } } virtual ~ThreadColPageIterator(void) { - if (fi_ != NULL) { - fi_->Close(); delete fi_; - } - } - virtual void Init(void) { - fi_ = new utils::FileStream(utils::FopenCheck(col_pagefile_.c_str(), "rb")); - itr_.get_factory().Setup(fi_, col_pageratio_); - if (silent == 0) { - printf("ThreadColPageIterator: finish initialzing from %s, %u columns\n", - col_pagefile_.c_str(), static_cast(col_ptr().size() - 1)); - } - } - virtual void SetParam(const char *name, const char *val) { - if (!strcmp("col_pageratio", val)) col_pageratio_ = atof(val); - if (!strcmp("col_pagefile", val)) col_pagefile_ = val; - if (!strcmp("silent", val)) silent = atoi(val); } virtual void BeforeFirst(void) { itr_.BeforeFirst(); } virtual bool Next(void) { - if(!itr_.Next(page_)) return false; - out_ = page_->GetBatch(); + // page to be loaded + CSCMatrixManager::PagePtr page; + if(!itr_.Next(page)) return false; + out_ = page->GetBatch(); return true; } virtual const ColBatch &Value(void) const{ @@ -190,18 +187,8 @@ class ThreadColPageIterator : public utils::IIterator { } private: - // shutup - int silent; - // input file - utils::FileStream *fi_; - // size of page - float col_pageratio_; - // name of file - std::string col_pagefile_; // output data ColBatch out_; - // page to be loaded - CSCMatrixManager::PagePtr page_; // internal iterator utils::ThreadBuffer itr_; }; @@ -212,14 +199,18 @@ class ThreadColPageIterator : public utils::IIterator { class FMatrixPage : public IFMatrix { public: /*! \brief constructor */ - FMatrixPage(utils::IIterator *iter) { + FMatrixPage(utils::IIterator *iter, std::string fname_buffer) { this->row_iter_ = iter; this->col_iter_ = NULL; + this->fi_ = NULL; } // destructor virtual ~FMatrixPage(void) { if (row_iter_ != NULL) delete row_iter_; if (col_iter_ != NULL) delete col_iter_; + if (fi_ != NULL) { + fi_->Close(); delete fi_; + } } /*! \return whether column access is enabled */ virtual bool HaveColAccess(void) const { @@ -275,18 +266,44 @@ class FMatrixPage : public IFMatrix { } protected: + /*! + * \brief try load column data from file + */ + inline bool LoadColData(void) { + FILE *fp = fopen64(fname_cbuffer_.c_str(), "rb"); + if (fp == NULL) return false; + fi_ = new utils::FileStream(fp); + static_cast(fi_)->Read(&buffered_rowset_); + col_iter_ = new ThreadColPageIterator(fi_, 2.0f, false); + return true; + } /*! * \brief intialize column data * \param pkeep probability to keep a row */ inline void InitColData(float pkeep) { - buffered_rowset_.clear(); + buffered_rowset_.clear(); + utils::FileStream fo(utils::FopenCheck(fname_cbuffer_.c_str(), "wb+")); + // use 64M buffer + utils::SparseCSRFileBuilder builder(&fo, 64<<20); + // start working row_iter_->BeforeFirst(); while (row_iter_->Next()) { const RowBatch &batch = row_iter_->Value(); - + for (size_t i = 0; i < batch.size; ++i) { + if (pkeep == 1.0f || random::SampleBinary(pkeep)) { + buffered_rowset_.push_back(static_cast(batch.base_rowid+i)); + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + builder.AddBudget(inst[j].index); + } + } + } } + // write buffered rowset + static_cast(&fo)->Write(buffered_rowset_); + builder.InitStorage(); row_iter_->BeforeFirst(); size_t ktop = 0; while (row_iter_->Next()) { @@ -295,11 +312,18 @@ class FMatrixPage : public IFMatrix { if (ktop < buffered_rowset_.size() && buffered_rowset_[ktop] == batch.base_rowid + i) { ++ktop; - // TODO1 + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + builder.PushElem(inst[j].index, + ColBatch::Entry((bst_uint)(batch.base_rowid+i), + inst[j].fvalue)); + } } } } - // sort columns + builder.Finalize(); + builder.SortRows(ColBatch::Entry::CmpValue, 5); + fo.Close(); } private: @@ -307,6 +331,10 @@ class FMatrixPage : public IFMatrix { utils::IIterator *row_iter_; // column iterator ThreadColPageIterator *col_iter_; + // file pointer to data + utils::FileStream *fi_; + // file name of column buffer + std::string fname_cbuffer_; /*! \brief list of row index that are buffered */ std::vector buffered_rowset_; }; diff --git a/src/utils/io.h b/src/utils/io.h index d98b3e4dc..37f489955 100644 --- a/src/utils/io.h +++ b/src/utils/io.h @@ -125,7 +125,6 @@ class FileStream : public ISeekStream { private: FILE *fp; }; - } // namespace utils } // namespace xgboost #endif diff --git a/src/utils/matrix_csr.h b/src/utils/matrix_csr.h index b2768b2ea..e4c410511 100644 --- a/src/utils/matrix_csr.h +++ b/src/utils/matrix_csr.h @@ -9,6 +9,7 @@ #include #include "./io.h" #include "./utils.h" +#include "./omp.h" namespace xgboost { namespace utils { @@ -155,9 +156,9 @@ struct SparseCSRFileBuilder { for (size_t i = 1; i < rptr.size(); i++) { nelem += rptr[i]; rptr[i] = nelem; - } - SizeType begin_meta = sizeof(SizeType) + nelem * sizeof(IndexType); - fo->Seek(0); + } + begin_data = static_cast(fo->Tell()) + sizeof(SizeType); + SizeType begin_meta = begin_data + nelem * sizeof(IndexType); fo->Write(&begin_meta, sizeof(begin_meta)); fo->Seek(begin_meta); fo->Write(rptr); @@ -184,7 +185,28 @@ struct SparseCSRFileBuilder { utils::Assert(saved_offset[i] == rptr[i+1], "some block not write out"); } } - + /*! \brief content must be in wb+ */ + template + inline void SortRows(Comparator comp, size_t step) { + for (size_t i = 0; i < rptr.size() - 1; i += step) { + bst_omp_uint begin = static_cast(i); + bst_omp_uint end = static_cast(std::min(rptr.size(), i + step)); + if (rptr[end] != rptr[begin]) { + fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); + buffer_data.resize(rptr[end] - rptr[begin]); + fo->Read(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); + // do parallel sorting + #pragma omp parallel for schedule(static) + for (bst_omp_uint j = begin; j < end; ++j){ + std::sort(&buffer_data[0] + rptr[j] - rptr[begin], + &buffer_data[0] + rptr[j+1] - rptr[begin], + comp); + } + fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); + fo->Write(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); + } + } + } protected: inline void WriteBuffer(void) { SizeType start = 0; @@ -202,7 +224,7 @@ struct SparseCSRFileBuilder { size_t nelem = buffer_rptr[i+1] - buffer_rptr[i]; if (nelem != 0) { utils::Assert(saved_offset[i] < rptr[i+1], "data exceed bound"); - fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + sizeof(SizeType)); + fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + begin_data); fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType)); saved_offset[i] += nelem; } @@ -219,6 +241,8 @@ struct SparseCSRFileBuilder { std::vector rptr; /*! \brief saved top space of each item */ std::vector saved_offset; + /*! \brief beginning position of data */ + size_t begin_data; // ----- the following are buffer space /*! \brief maximum size of content buffer*/ size_t buffer_size;