/*! * Copyright (c) 2014-2021 by Contributors * \file sparse_page_source.h */ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #include // std::min #include #include #include #include #include #include #include #include "rabit/rabit.h" #include "xgboost/base.h" #include "xgboost/data.h" #include "adapter.h" #include "sparse_page_writer.h" #include "proxy_dmatrix.h" #include "../common/common.h" namespace xgboost { namespace data { inline void TryDeleteCacheFile(const std::string& file) { if (std::remove(file.c_str()) != 0) { LOG(WARNING) << "Couldn't remove external memory cache file " << file << "; you may want to remove it manually"; } } struct Cache { // whether the write to the cache is complete bool written; std::string name; std::string format; // offset into binary cache file. std::vector offset; Cache(bool w, std::string n, std::string fmt) : written{w}, name{std::move(n)}, format{std::move(fmt)} { offset.push_back(0); } static std::string ShardName(std::string name, std::string format) { CHECK_EQ(format.front(), '.'); return name + format; } std::string ShardName() { return ShardName(this->name, this->format); } // The write is completed. void Commit() { if (!written) { std::partial_sum(offset.begin(), offset.end(), offset.begin()); written = true; } } }; // Prevents multi-threaded call. class TryLockGuard { std::mutex& lock_; public: explicit TryLockGuard(std::mutex& lock) : lock_{lock} { // NOLINT CHECK(lock_.try_lock()) << "Multiple threads attempting to use Sparse DMatrix."; } ~TryLockGuard() { lock_.unlock(); } }; template class SparsePageSourceImpl : public BatchIteratorImpl { protected: // Prevents calling this iterator from multiple places(or threads). std::mutex single_threaded_; std::shared_ptr page_; bool at_end_ {false}; float missing_; int nthreads_; bst_feature_t n_features_; uint32_t count_{0}; uint32_t n_batches_ {0}; std::shared_ptr cache_info_; std::unique_ptr fo_; using Ring = std::vector>>; // A ring storing futures to data. Since the DMatrix iterator is forward only, so we // can pre-fetch data in a ring. std::unique_ptr ring_{new Ring}; bool ReadCache() { CHECK(!at_end_); if (!cache_info_->written) { return false; } if (fo_) { fo_.reset(); // flush the data to disk. ring_->resize(n_batches_); } // An heuristic for number of pre-fetched batches. We can make it part of BatchParam // to let user adjust number of pre-fetched batches when needed. uint32_t constexpr kPreFetch = 4; size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; size_t fetch_it = count_; for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring if (ring_->at(fetch_it).valid()) { continue; } auto const *self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { std::unique_ptr> fmt{CreatePageFormat("raw")}; auto n = self->cache_info_->ShardName(); size_t offset = self->cache_info_->offset.at(fetch_it); std::unique_ptr fi{ dmlc::SeekStream::CreateForRead(n.c_str())}; fi->Seek(offset); CHECK_EQ(fi->Tell(), offset); auto page = std::make_shared(); CHECK(fmt->Read(page.get(), fi.get())); return page; }); } CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const &f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; page_ = (*ring_)[count_].get(); return true; } void WriteCache() { CHECK(!cache_info_->written); std::unique_ptr> fmt{CreatePageFormat("raw")}; if (!fo_) { auto n = cache_info_->ShardName(); fo_.reset(dmlc::Stream::Create(n.c_str(), "w")); } auto bytes = fmt->Write(*page_, fo_.get()); cache_info_->offset.push_back(bytes); } virtual void Fetch() = 0; public: SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache) : missing_{missing}, nthreads_{nthreads}, n_features_{n_features}, n_batches_{n_batches}, cache_info_{std::move(cache)} {} SparsePageSourceImpl(SparsePageSourceImpl const &that) = delete; ~SparsePageSourceImpl() override { for (auto& fu : *ring_) { if (fu.valid()) { fu.get(); } } } uint32_t Iter() const { return count_; } const S &operator*() const override { CHECK(page_); return *page_; } std::shared_ptr Page() const override { return page_; } bool AtEnd() const override { return at_end_; } virtual void Reset() { TryLockGuard guard{single_threaded_}; at_end_ = false; count_ = 0; this->Fetch(); } }; #if defined(XGBOOST_USE_CUDA) void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page); #else inline void DevicePush(DMatrixProxy* proxy, float missing, SparsePage* page) { common::AssertGPUSupport(); } #endif class SparsePageSource : public SparsePageSourceImpl { DataIterProxy iter_; DMatrixProxy* proxy_; size_t base_row_id_ {0}; void Fetch() final { page_ = std::make_shared(); if (!this->ReadCache()) { bool type_error { false }; CHECK(proxy_); HostAdapterDispatch(proxy_, [&](auto const &adapter_batch) { page_->Push(adapter_batch, this->missing_, this->nthreads_); }, &type_error); if (type_error) { DevicePush(proxy_, missing_, page_.get()); } page_->SetBaseRowId(base_row_id_); base_row_id_ += page_->Size(); n_batches_++; this->WriteCache(); } } public: SparsePageSource( DataIterProxy iter, DMatrixProxy *proxy, float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache) : SparsePageSourceImpl(missing, nthreads, n_features, n_batches, cache), iter_{iter}, proxy_{proxy} { if (!cache_info_->written) { iter_.Reset(); CHECK_EQ(iter_.Next(), 1) << "Must have at least 1 batch."; } this->Fetch(); } SparsePageSource& operator++() final { TryLockGuard guard{single_threaded_}; count_++; if (cache_info_->written) { at_end_ = (count_ == n_batches_); } else { at_end_ = !iter_.Next(); } if (at_end_) { cache_info_->Commit(); if (n_batches_ != 0) { CHECK_EQ(count_, n_batches_); } CHECK_GE(count_, 1); proxy_ = nullptr; } else { this->Fetch(); } return *this; } void Reset() override { if (proxy_) { TryLockGuard guard{single_threaded_}; iter_.Reset(); } SparsePageSourceImpl::Reset(); TryLockGuard guard{single_threaded_}; base_row_id_ = 0; } }; // A mixin for advancing the iterator. template class PageSourceIncMixIn : public SparsePageSourceImpl { protected: std::shared_ptr source_; public: using SparsePageSourceImpl::SparsePageSourceImpl; PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; ++(*source_); ++this->count_; this->at_end_ = source_->AtEnd(); if (this->at_end_) { this->cache_info_->Commit(); if (this->n_batches_ != 0) { CHECK_EQ(this->count_, this->n_batches_); } CHECK_GE(this->count_, 1); } else { this->Fetch(); } CHECK_EQ(source_->Iter(), this->count_); return *this; } }; class CSCPageSource : public PageSourceIncMixIn { protected: void Fetch() final { if (!this->ReadCache()) { auto const &csr = source_->Page(); this->page_.reset(new CSCPage{}); // we might be able to optimize this by merging transpose and pushcsc this->page_->PushCSC(csr->GetTranspose(n_features_)); page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } } public: CSCPageSource( float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache, std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) { this->source_ = source; this->Fetch(); } }; class SortedCSCPageSource : public PageSourceIncMixIn { protected: void Fetch() final { if (!this->ReadCache()) { auto const &csr = this->source_->Page(); this->page_.reset(new SortedCSCPage{}); // we might be able to optimize this by merging transpose and pushcsc this->page_->PushCSC(csr->GetTranspose(n_features_)); CHECK_EQ(this->page_->Size(), n_features_); CHECK_EQ(this->page_->data.Size(), csr->data.Size()); this->page_->SortRows(); page_->SetBaseRowId(csr->base_rowid); this->WriteCache(); } } public: SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache, std::shared_ptr source) : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) { this->source_ = source; this->Fetch(); } }; } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_