/*! * Copyright (c) 2014-2019 by Contributors * \file sparse_page_writer.h * \author Tianqi Chen */ #ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_ #define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_ #include #include #include #include #include #include #include #include #include #if DMLC_ENABLE_STD_THREAD #include #include #endif // DMLC_ENABLE_STD_THREAD namespace xgboost { namespace data { template struct SparsePageFormatReg; /*! * \brief Format specification of SparsePage. */ template class SparsePageFormat { public: /*! \brief virtual destructor */ virtual ~SparsePageFormat() = default; /*! * \brief Load all the segments into page, advance fi to end of the block. * \param page The data to read page into. * \param fi the input stream of the file * \return true of the loading as successful, false if end of file was reached */ virtual bool Read(T* page, dmlc::SeekStream* fi) = 0; /*! * \brief read only the segments we are interested in, advance fi to end of the block. * \param page The page to load the data into. * \param fi the input stream of the file * \param sorted_index_set sorted index of segments we are interested in * \return true of the loading as successful, false if end of file was reached */ virtual bool Read(T* page, dmlc::SeekStream* fi, const std::vector& sorted_index_set) = 0; /*! * \brief save the data to fo, when a page was written. * \param fo output stream */ virtual void Write(const T& page, dmlc::Stream* fo) = 0; }; /*! * \brief Create sparse page of format. * \return The created format functors. */ template inline SparsePageFormat* CreatePageFormat(const std::string& name) { auto *e = ::dmlc::Registry>::Get()->Find(name); if (e == nullptr) { LOG(FATAL) << "Unknown format type " << name; return nullptr; } return (e->body)(); } #if DMLC_ENABLE_STD_THREAD /*! * \brief A threaded writer to write sparse batch page to sharded files. * @tparam T Type of the page. */ template class SparsePageWriter { public: /*! * \brief constructor * \param name_shards name of shard files. * \param format_shards format of each shard. * \param extra_buffer_capacity Extra buffer capacity before block. */ explicit SparsePageWriter(const std::vector& name_shards, const std::vector& format_shards, size_t extra_buffer_capacity) : num_free_buffer_(extra_buffer_capacity + name_shards.size()), clock_ptr_(0), workers_(name_shards.size()), qworkers_(name_shards.size()) { CHECK_EQ(name_shards.size(), format_shards.size()); // start writer threads for (size_t i = 0; i < name_shards.size(); ++i) { std::string name_shard = name_shards[i]; std::string format_shard = format_shards[i]; auto* wqueue = &qworkers_[i]; workers_[i].reset(new std::thread( [this, name_shard, format_shard, wqueue]() { std::unique_ptr fo(dmlc::Stream::Create(name_shard.c_str(), "w")); std::unique_ptr> fmt(CreatePageFormat(format_shard)); fo->Write(format_shard); std::shared_ptr page; while (wqueue->Pop(&page)) { if (page == nullptr) break; fmt->Write(*page, fo.get()); qrecycle_.Push(std::move(page)); } fo.reset(nullptr); LOG(INFO) << "SparsePageWriter Finished writing to " << name_shard; })); } } /*! \brief destructor, will close the files automatically */ ~SparsePageWriter() { for (auto& queue : qworkers_) { // use nullptr to signal termination. std::shared_ptr sig(nullptr); queue.Push(std::move(sig)); } for (auto& thread : workers_) { thread->join(); } } /*! * \brief Push a write job to the writer. * This function won't block, * writing is done by another thread inside writer. * \param page The page to be written */ void PushWrite(std::shared_ptr&& page) { qworkers_[clock_ptr_].Push(std::move(page)); clock_ptr_ = (clock_ptr_ + 1) % workers_.size(); } /*! * \brief Allocate a page to store results. * This function can block when the writer is too slow and buffer pages * have not yet been recycled. * \param out_page Used to store the allocated pages. */ void Alloc(std::shared_ptr* out_page) { CHECK(*out_page == nullptr); if (num_free_buffer_ != 0) { out_page->reset(new T()); --num_free_buffer_; } else { CHECK(qrecycle_.Pop(out_page)); } } private: /*! \brief number of allocated pages */ size_t num_free_buffer_; /*! \brief clock_pointer */ size_t clock_ptr_; /*! \brief writer threads */ std::vector> workers_; /*! \brief recycler queue */ dmlc::ConcurrentBlockingQueue> qrecycle_; /*! \brief worker threads */ std::vector>> qworkers_; }; #endif // DMLC_ENABLE_STD_THREAD /*! * \brief Registry entry for sparse page format. */ template struct SparsePageFormatReg : public dmlc::FunctionRegEntryBase, std::function* ()>> { }; /*! * \brief Macro to register sparse page format. * * \code * // example of registering a objective * XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw) * .describe("Raw binary data format.") * .set_body([]() { * return new RawFormat(); * }); * \endcode */ #define SparsePageFmt SparsePageFormat #define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \ DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SparsePageFmt, Name) #define CSCPageFmt SparsePageFormat #define XGBOOST_REGISTER_CSC_PAGE_FORMAT(Name) \ DMLC_REGISTRY_REGISTER(SparsePageFormatReg, CSCPageFmt, Name) #define SortedCSCPageFmt SparsePageFormat #define XGBOOST_REGISTER_SORTED_CSC_PAGE_FORMAT(Name) \ DMLC_REGISTRY_REGISTER(SparsePageFormatReg, SortedCSCPageFmt, Name) #define EllpackPageFmt SparsePageFormat #define XGBOOST_REGISTER_ELLPACK_PAGE_FORMAT(Name) \ DMLC_REGISTRY_REGISTER(SparsePageFormatReg, EllpackPageFm, Name) } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_