diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 645695dec..7cc36c16b 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -30,6 +30,7 @@ #include "../src/data/data.cc" #include "../src/data/simple_csr_source.cc" #include "../src/data/simple_dmatrix.cc" +#include "../src/data/sparse_page_raw_format.cc" #if DMLC_ENABLE_STD_THREAD #include "../src/data/sparse_page_source.cc" diff --git a/src/data/data.cc b/src/data/data.cc index d73c1e8bb..30da58e8e 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -4,6 +4,7 @@ */ #include #include +#include #include #include "./sparse_batch_page.h" #include "./simple_dmatrix.h" @@ -15,6 +16,10 @@ #include "./sparse_page_dmatrix.h" #endif +namespace dmlc { +DMLC_REGISTRY_ENABLE(::xgboost::data::SparsePageFormatReg); +} // namespace dmlc + namespace xgboost { // implementation of inline functions void MetaInfo::Clear() { @@ -231,3 +236,24 @@ DMatrix* DMatrix::Create(std::unique_ptr&& source, } } } // namespace xgboost + +namespace xgboost { +namespace data { +SparsePage::Format* SparsePage::Format::Create(const std::string& name) { + auto *e = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get()->Find(name); + if (e == nullptr) { + LOG(FATAL) << "Unknown format type " << name; + } + return (e->body)(); +} + +std::string SparsePage::Format::DecideFormat(const std::string& cache_prefix) { + size_t pos = cache_prefix.rfind(".fmt-"); + if (pos != std::string::npos) { + return cache_prefix.substr(pos + 5, cache_prefix.length()); + } else { + return "raw"; + } +} +} // namespace data +} // namespace xgboost diff --git a/src/data/sparse_batch_page.h b/src/data/sparse_batch_page.h index 145495e29..78534bfd6 100644 --- a/src/data/sparse_batch_page.h +++ b/src/data/sparse_batch_page.h @@ -14,6 +14,7 @@ #include #include #include +#include namespace xgboost { namespace data { @@ -22,6 +23,9 @@ namespace data { */ class SparsePage { public: + /*! \brief Format of the sparse page. */ + class Format; + /*! \brief offset of the segments */ std::vector offset; /*! \brief the data of the segments */ @@ -35,87 +39,6 @@ class SparsePage { inline size_t Size() const { return offset.size() - 1; } - /*! - * \brief load only the segments we are interested in - * \param fi the input stream of the file - * \param sorted_index_set sorted index of segments we are interested in - * \return true of the loading as successful, false if end of file was reached - */ - inline bool Load(dmlc::SeekStream *fi, - const std::vector &sorted_index_set) { - if (!fi->Read(&disk_offset_)) return false; - // setup the offset - offset.clear(); offset.push_back(0); - for (size_t i = 0; i < sorted_index_set.size(); ++i) { - bst_uint fid = sorted_index_set[i]; - CHECK_LT(fid + 1, disk_offset_.size()); - size_t size = disk_offset_[fid + 1] - disk_offset_[fid]; - offset.push_back(offset.back() + size); - } - data.resize(offset.back()); - // read in the data - size_t begin = fi->Tell(); - size_t curr_offset = 0; - for (size_t i = 0; i < sorted_index_set.size();) { - bst_uint fid = sorted_index_set[i]; - if (disk_offset_[fid] != curr_offset) { - CHECK_GT(disk_offset_[fid], curr_offset); - fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry)); - curr_offset = disk_offset_[fid]; - } - size_t j, size_to_read = 0; - for (j = i; j < sorted_index_set.size(); ++j) { - if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) { - size_to_read += offset[j + 1] - offset[j]; - } else { - break; - } - } - - if (size_to_read != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(data) + offset[i], - size_to_read * sizeof(SparseBatch::Entry)), - size_to_read * sizeof(SparseBatch::Entry)) - << "Invalid SparsePage file"; - curr_offset += size_to_read; - } - i = j; - } - // seek to end of record - if (curr_offset != disk_offset_.back()) { - fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry)); - } - return true; - } - /*! - * \brief load all the segments - * \param fi the input stream of the file - * \return true of the loading as successful, false if end of file was reached - */ - inline bool Load(dmlc::Stream *fi) { - if (!fi->Read(&offset)) return false; - CHECK_NE(offset.size(), 0) << "Invalid SparsePage file"; - data.resize(offset.back()); - if (data.size() != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(data), data.size() * sizeof(SparseBatch::Entry)), - data.size() * sizeof(SparseBatch::Entry)) - << "Invalid SparsePage file"; - } - return true; - } - /*! - * \brief save the data to fo, when a page was written - * to disk it must contain all the elements in the - * \param fo output stream - */ - inline void Save(dmlc::Stream *fo) const { - CHECK(offset.size() != 0 && offset[0] == 0); - CHECK_EQ(offset.back(), data.size()); - fo->Write(offset); - if (data.size() != 0) { - fo->Write(dmlc::BeginPtr(data), data.size() * sizeof(SparseBatch::Entry)); - } - } /*! \return estimation of memory cost of this page */ inline size_t MemCostBytes(void) const { return offset.size() * sizeof(size_t) + data.size() * sizeof(SparseBatch::Entry); @@ -126,28 +49,7 @@ class SparsePage { offset.push_back(0); data.clear(); } - /*! - * \brief load all the segments and add it to existing batch - * \param fi the input stream of the file - * \return true of the loading as successful, false if end of file was reached - */ - inline bool PushLoad(dmlc::Stream *fi) { - if (!fi->Read(&disk_offset_)) return false; - data.resize(offset.back() + disk_offset_.back()); - if (disk_offset_.back() != 0) { - CHECK_EQ(fi->Read(dmlc::BeginPtr(data) + offset.back(), - disk_offset_.back() * sizeof(SparseBatch::Entry)), - disk_offset_.back() * sizeof(SparseBatch::Entry)) - << "Invalid SparsePage file"; - } - size_t top = offset.back(); - size_t begin = offset.size(); - offset.resize(offset.size() + disk_offset_.size()); - for (size_t i = 0; i < disk_offset_.size(); ++i) { - offset[i + begin] = top + disk_offset_[i]; - } - return true; - } + /*! * \brief Push row batch into the page * \param batch the row batch @@ -223,11 +125,72 @@ class SparsePage { out.size = offset.size() - 1; return out; } - - private: - /*! \brief external memory column offset */ - std::vector disk_offset_; }; + +/*! + * \brief Format specification of SparsePage. + */ +class SparsePage::Format { + public: + /*! \brief virtual destructor */ + virtual ~Format() {} + /*! + * \brief Load all the segments into page, advance fi to end of the block. + * \param page The data to read page into. + * \param fi the input stream of the file + * \return true of the loading as successful, false if end of file was reached + */ + virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0; + /*! + * \brief read only the segments we are interested in, advance fi to end of the block. + * \param page The page to load the data into. + * \param fi the input stream of the file + * \param sorted_index_set sorted index of segments we are interested in + * \return true of the loading as successful, false if end of file was reached + */ + virtual bool Read(SparsePage* page, + dmlc::SeekStream* fi, + const std::vector& sorted_index_set) = 0; + /*! + * \brief save the data to fo, when a page was written. + * \param fo output stream + */ + virtual void Write(const SparsePage& page, dmlc::Stream* fo) const = 0; + /*! + * \brief Create sparse page of format. + * \return The created format functors. + */ + static Format* Create(const std::string& name); + /*! + * \brief decide the format from cache prefix. + * \return format type of the cache prefix. + */ + static std::string DecideFormat(const std::string& cache_prefix); +}; + +/*! + * \brief Registry entry for sparse page format. + */ +struct SparsePageFormatReg + : public dmlc::FunctionRegEntryBase > { +}; + +/*! + * \brief Macro to register sparse page format. + * + * \code + * // example of registering a objective + * XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw) + * .describe("Raw binary data format.") + * .set_body([]() { + * return new RawFormat(); + * }); + * \endcode + */ +#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \ + DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name) + } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_SPARSE_BATCH_PAGE_H_ diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index ba4f566f5..e25ea6e25 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -17,20 +17,24 @@ namespace data { SparsePageDMatrix::ColPageIter::ColPageIter(std::unique_ptr&& fi) : fi_(std::move(fi)), page_(nullptr) { - - load_all_ = false; + + std::string format; + CHECK(fi_->Read(&format)) << "Invalid page format"; + format_.reset(SparsePage::Format::Create(format)); + size_t fbegin = fi_->Tell(); + prefetcher_.Init([this](SparsePage** dptr) { if (*dptr == nullptr) { *dptr = new SparsePage(); } if (load_all_) { - return (*dptr)->Load(fi_.get()); + return format_->Read(*dptr, fi_.get()); } else { - return (*dptr)->Load(fi_.get(), index_set_); + return format_->Read(*dptr, fi_.get(), index_set_); } - }, [this] () { - fi_->Seek(0); + }, [this, fbegin] () { + fi_->Seek(fbegin); index_set_ = set_index_set_; load_all_ = set_load_all_; }); @@ -222,6 +226,11 @@ void SparsePageDMatrix::InitColAccess(const std::vector& enabled, std::string col_data_name = cache_prefix_ + ".col.page"; std::unique_ptr fo(dmlc::Stream::Create(col_data_name.c_str(), "w")); + // find format. + std::string name_format = SparsePage::Format::DecideFormat(cache_prefix_); + fo->Write(name_format); + std::unique_ptr format(SparsePage::Format::Create(name_format)); + double tstart = dmlc::GetTime(); size_t bytes_write = 0; SparsePage* pcol = nullptr; @@ -230,7 +239,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector& enabled, for (size_t i = 0; i < pcol->Size(); ++i) { col_size_[i] += pcol->offset[i + 1] - pcol->offset[i]; } - pcol->Save(fo.get()); + format->Write(*pcol, fo.get()); size_t spage = pcol->MemCostBytes(); bytes_write += spage; double tdiff = dmlc::GetTime() - tstart; diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 12e9ae165..9d0a2e344 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -92,6 +92,10 @@ class SparsePageDMatrix : public DMatrix { private: // data file pointer. std::unique_ptr fi_; + // the temp page. + SparsePage* page_; + // page format. + std::unique_ptr format_; // The index set to be loaded. std::vector index_set_; // The index set by the outsiders @@ -100,8 +104,6 @@ class SparsePageDMatrix : public DMatrix { bool set_load_all_, load_all_; // data prefetcher. dmlc::ThreadedIter prefetcher_; - // the temp page. - SparsePage* page_; // temporal space for batch ColBatch out_; // the pointer data. diff --git a/src/data/sparse_page_raw_format.cc b/src/data/sparse_page_raw_format.cc new file mode 100644 index 000000000..867ffad1c --- /dev/null +++ b/src/data/sparse_page_raw_format.cc @@ -0,0 +1,96 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file sparse_page_raw_format.cc + * Raw binary format of sparse page. + */ +#include +#include "./sparse_batch_page.h" + +namespace xgboost { +namespace data { + +class SparsePageRawFormat : public SparsePage::Format { + public: + bool Read(SparsePage* page, dmlc::SeekStream* fi) override { + if (!fi->Read(&(page->offset))) return false; + CHECK_NE(page->offset.size(), 0) << "Invalid SparsePage file"; + page->data.resize(page->offset.back()); + if (page->data.size() != 0) { + CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data), + (page->data).size() * sizeof(SparseBatch::Entry)), + (page->data).size() * sizeof(SparseBatch::Entry)) + << "Invalid SparsePage file"; + } + return true; + } + + bool Read(SparsePage* page, + dmlc::SeekStream* fi, + const std::vector& sorted_index_set) override { + if (!fi->Read(&disk_offset_)) return false; + // setup the offset + page->offset.clear(); + page->offset.push_back(0); + for (size_t i = 0; i < sorted_index_set.size(); ++i) { + bst_uint fid = sorted_index_set[i]; + CHECK_LT(fid + 1, disk_offset_.size()); + size_t size = disk_offset_[fid + 1] - disk_offset_[fid]; + page->offset.push_back(page->offset.back() + size); + } + page->data.resize(page->offset.back()); + // read in the data + size_t begin = fi->Tell(); + size_t curr_offset = 0; + for (size_t i = 0; i < sorted_index_set.size();) { + bst_uint fid = sorted_index_set[i]; + if (disk_offset_[fid] != curr_offset) { + CHECK_GT(disk_offset_[fid], curr_offset); + fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry)); + curr_offset = disk_offset_[fid]; + } + size_t j, size_to_read = 0; + for (j = i; j < sorted_index_set.size(); ++j) { + if (disk_offset_[sorted_index_set[j]] == disk_offset_[fid] + size_to_read) { + size_to_read += page->offset[j + 1] - page->offset[j]; + } else { + break; + } + } + + if (size_to_read != 0) { + CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data) + page->offset[i], + size_to_read * sizeof(SparseBatch::Entry)), + size_to_read * sizeof(SparseBatch::Entry)) + << "Invalid SparsePage file"; + curr_offset += size_to_read; + } + i = j; + } + // seek to end of record + if (curr_offset != disk_offset_.back()) { + fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry)); + } + return true; + } + + void Write(const SparsePage& page, dmlc::Stream* fo) const override { + CHECK(page.offset.size() != 0 && page.offset[0] == 0); + CHECK_EQ(page.offset.back(), page.data.size()); + fo->Write(page.offset); + if (page.data.size() != 0) { + fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(SparseBatch::Entry)); + } + } + + private: + /*! \brief external memory column offset */ + std::vector disk_offset_; +}; + +XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw) +.describe("Raw binary data format.") +.set_body([]() { + return new SparsePageRawFormat(); + }); +} // namespace data +} // namespace xgboost diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index 4ee6f17c4..8159d3807 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -24,12 +24,18 @@ SparsePageSource::SparsePageSource(const std::string& cache_prefix) // read in the cache files. std::string name_row = cache_prefix + ".row.page"; fi_.reset(dmlc::SeekStream::CreateForRead(name_row.c_str())); + + std::string format; + CHECK(fi_->Read(&format)) << "Invalid page format"; + format_.reset(SparsePage::Format::Create(format)); + size_t fbegin = fi_->Tell(); + prefetcher_.Init([this] (SparsePage** dptr) { if (*dptr == nullptr) { *dptr = new SparsePage(); } - return (*dptr)->Load(fi_.get()); - }, [this] () { fi_->Seek(0); }); + return format_->Read(*dptr, fi_.get()); + }, [this, fbegin] () { fi_->Seek(fbegin); }); } SparsePageSource::~SparsePageSource() { @@ -72,6 +78,10 @@ void SparsePageSource::Create(dmlc::Parser* src, std::string name_info = cache_prefix; std::string name_row = cache_prefix + ".row.page"; std::unique_ptr fo(dmlc::Stream::Create(name_row.c_str(), "w")); + std::string name_format = SparsePage::Format::DecideFormat(cache_prefix); + fo->Write(name_format); + std::unique_ptr format(SparsePage::Format::Create(name_format)); + MetaInfo info; SparsePage page; size_t bytes_write = 0; @@ -95,7 +105,7 @@ void SparsePageSource::Create(dmlc::Parser* src, page.Push(batch); if (page.MemCostBytes() >= kPageSize) { bytes_write += page.MemCostBytes(); - page.Save(fo.get()); + format->Write(page, fo.get()); page.Clear(); double tdiff = dmlc::GetTime() - tstart; LOG(CONSOLE) << "Writing to " << name_row << " in " @@ -105,7 +115,7 @@ void SparsePageSource::Create(dmlc::Parser* src, } if (page.data.size() != 0) { - page.Save(fo.get()); + format->Write(page, fo.get()); } fo.reset(dmlc::Stream::Create(name_info.c_str(), "w")); @@ -122,6 +132,10 @@ void SparsePageSource::Create(DMatrix* src, std::string name_info = cache_prefix; std::string name_row = cache_prefix + ".row.page"; std::unique_ptr fo(dmlc::Stream::Create(name_row.c_str(), "w")); + // find format. + std::string name_format = SparsePage::Format::DecideFormat(cache_prefix); + fo->Write(name_format); + std::unique_ptr format(SparsePage::Format::Create(name_format)); SparsePage page; size_t bytes_write = 0; @@ -132,7 +146,7 @@ void SparsePageSource::Create(DMatrix* src, page.Push(iter->Value()); if (page.MemCostBytes() >= kPageSize) { bytes_write += page.MemCostBytes(); - page.Save(fo.get()); + format->Write(page, fo.get()); page.Clear(); double tdiff = dmlc::GetTime() - tstart; LOG(CONSOLE) << "Writing to " << name_row << " in " @@ -142,7 +156,7 @@ void SparsePageSource::Create(DMatrix* src, } if (page.data.size() != 0) { - page.Save(fo.get()); + format->Write(page, fo.get()); } fo.reset(dmlc::Stream::Create(name_info.c_str(), "w")); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index cdbfd9020..79c55b4ba 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -75,6 +75,8 @@ class SparsePageSource : public DataSource { std::string cache_prefix_; /*! \brief file pointer to the row blob file. */ std::unique_ptr fi_; + /*! \brief Sparse page format file. */ + std::unique_ptr format_; /*! \brief internal prefetcher. */ dmlc::ThreadedIter prefetcher_; };