diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 24b11ad2a..b16bde6a4 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -45,7 +45,7 @@ class EllpackPageSourceImpl : public DataSource { dh::BulkAllocator ba_; /*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */ EllpackInfo ellpack_info_; - std::unique_ptr> source_; + std::unique_ptr> source_; std::string cache_info_; }; @@ -98,11 +98,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat, WriteEllpackPages(dmat, cache_info); monitor_.StopCuda("WriteEllpackPages"); - source_.reset(new SparsePageSource(cache_info_, kPageType_)); + source_.reset(new ExternalMemoryPrefetcher( + ParseCacheInfo(cache_info_, kPageType_))); } void EllpackPageSourceImpl::BeforeFirst() { - source_.reset(new SparsePageSource(cache_info_, kPageType_)); + source_.reset(new ExternalMemoryPrefetcher( + ParseCacheInfo(cache_info_, kPageType_))); source_->BeforeFirst(); } diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index 67712e26e..e04bb0c8c 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -23,58 +23,24 @@ const MetaInfo& SparsePageDMatrix::Info() const { return row_source_->info; } -template -class SparseBatchIteratorImpl : public BatchIteratorImpl { - public: - explicit SparseBatchIteratorImpl(S* source) : source_(source) { - CHECK(source_ != nullptr); - } - T& operator*() override { return source_->Value(); } - const T& operator*() const override { return source_->Value(); } - void operator++() override { at_end_ = !source_->Next(); } - bool AtEnd() const override { return at_end_; } - - private: - S* source_{nullptr}; - bool at_end_{ false }; -}; - BatchSet SparsePageDMatrix::GetRowBatches() { - auto cast = dynamic_cast*>(row_source_.get()); - CHECK(cast); - cast->BeforeFirst(); - cast->Next(); - auto begin_iter = BatchIterator( - new SparseBatchIteratorImpl, SparsePage>(cast)); - return BatchSet(begin_iter); + return row_source_->GetBatchSet(); } BatchSet SparsePageDMatrix::GetColumnBatches() { // Lazily instantiate if (!column_source_) { - SparsePageSource::CreateColumnPage(this, cache_info_, false); - column_source_.reset(new SparsePageSource(cache_info_, ".col.page")); + column_source_.reset(new CSCPageSource(this, cache_info_)); } - column_source_->BeforeFirst(); - column_source_->Next(); - auto begin_iter = BatchIterator( - new SparseBatchIteratorImpl, CSCPage>(column_source_.get())); - return BatchSet(begin_iter); + return column_source_->GetBatchSet(); } BatchSet SparsePageDMatrix::GetSortedColumnBatches() { // Lazily instantiate if (!sorted_column_source_) { - SparsePageSource::CreateColumnPage(this, cache_info_, true); - sorted_column_source_.reset( - new SparsePageSource(cache_info_, ".sorted.col.page")); + sorted_column_source_.reset(new SortedCSCPageSource(this, cache_info_)); } - sorted_column_source_->BeforeFirst(); - sorted_column_source_->Next(); - auto begin_iter = BatchIterator( - new SparseBatchIteratorImpl, SortedCSCPage>( - sorted_column_source_.get())); - return BatchSet(begin_iter); + return sorted_column_source_->GetBatchSet(); } BatchSet SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) { diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 12a40e795..a227ceccc 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -22,24 +22,15 @@ namespace data { // Used for external memory. class SparsePageDMatrix : public DMatrix { public: - explicit SparsePageDMatrix(std::unique_ptr>&& source, - std::string cache_info) - : row_source_(std::move(source)), cache_info_(std::move(cache_info)) {} - template explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread, const std::string& cache_prefix, size_t page_size = kPageSize) : cache_info_(std::move(cache_prefix)) { - if (!data::SparsePageSource::CacheExist(cache_prefix, - ".row.page")) { - data::SparsePageSource::CreateRowPage( - adapter, missing, nthread, cache_prefix, page_size); - } - row_source_.reset( - new data::SparsePageSource(cache_prefix, ".row.page")); + row_source_.reset(new data::SparsePageSource(adapter, missing, nthread, + cache_prefix, page_size)); } - // Set number of threads but keep old value so we can reset it after + // Set number of threads but keep old value so we can reset it after ~SparsePageDMatrix() override = default; MetaInfo& Info() override; @@ -57,9 +48,9 @@ class SparsePageDMatrix : public DMatrix { BatchSet GetEllpackBatches(const BatchParam& param) override; // source data pointers. - std::unique_ptr> row_source_; - std::unique_ptr> column_source_; - std::unique_ptr> sorted_column_source_; + std::unique_ptr row_source_; + std::unique_ptr column_source_; + std::unique_ptr sorted_column_source_; std::unique_ptr ellpack_source_; // saved batch param BatchParam batch_param_; diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 172eb14bc..63466493e 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -17,6 +17,7 @@ #include #include #include +#include #include "xgboost/base.h" #include "xgboost/data.h" @@ -24,6 +25,7 @@ #include "adapter.h" #include "sparse_page_writer.h" #include "../common/common.h" +#include namespace { @@ -49,6 +51,26 @@ GetCacheShards(const std::string& cache_info) { namespace xgboost { namespace data { +template +class SparseBatchIteratorImpl : public BatchIteratorImpl { + public: + explicit SparseBatchIteratorImpl(S* source) : source_(source) { + CHECK(source_ != nullptr); + source_->BeforeFirst(); + source_->Next(); + } + T& operator*() override { return source_->Value(); } + const T& operator*() const override { return source_->Value(); } + void operator++() override { at_end_ = !source_->Next(); } + bool AtEnd() const override { return at_end_; } + + private: + S* source_{nullptr}; + bool at_end_{ false }; +}; + + /*! \brief magic number used to identify Page */ + static const int kMagic = 0xffffab02; /*! * \brief decide the format from cache prefix. * \return pair of row format, column format type of the cache prefix. @@ -89,116 +111,149 @@ inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string return info; } -/*! - * \brief External memory data source. - * \code - * std::unique_ptr source(new SimpleCSRSource(cache_prefix)); - * // add data to source - * DMatrix* dmat = DMatrix::Create(std::move(source)); - * \encode - */ -template -class SparsePageSource : public DataSource { - public: - /*! - * \brief Create source from cache files the cache_prefix. - * \param cache_prefix The prefix of cache we want to solve. +inline void TryDeleteCacheFile(const std::string& file) { + if (std::remove(file.c_str()) != 0) { + LOG(WARNING) << "Couldn't remove external memory cache file " << file + << "; you may want to remove it manually"; + } +} + +inline void CheckCacheFileExists(const std::string& file) { + std::ifstream f(file.c_str()); + if (f.good()) { + LOG(FATAL) << "Cache file " << file + << " exists already; Is there another DMatrix with the same " + "cache prefix? Otherwise please remove it manually."; + } +} + + /** + * \brief Given a set of cache files and page type, this object iterates over batches using prefetching for improved performance. Not thread safe. + * + * \tparam PageT Type of the page t. */ - explicit SparsePageSource(const std::string& cache_info, - const std::string& page_type) noexcept(false) + template +class ExternalMemoryPrefetcher : dmlc::DataIter { + public: + explicit ExternalMemoryPrefetcher(const CacheInfo& info) noexcept(false) : base_rowid_(0), page_(nullptr), clock_ptr_(0) { // read in the info files - std::vector cache_shards = GetCacheShards(cache_info); - CHECK_NE(cache_shards.size(), 0U); + CHECK_NE(info.name_shards.size(), 0U); { - std::string name_info = cache_shards[0]; - std::unique_ptr finfo(dmlc::Stream::Create(name_info.c_str(), "r")); + std::unique_ptr finfo( + dmlc::Stream::Create(info.name_info.c_str(), "r")); int tmagic; CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic)); CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch"; - this->info.LoadBinary(finfo.get()); } - files_.resize(cache_shards.size()); - formats_.resize(cache_shards.size()); - prefetchers_.resize(cache_shards.size()); + files_.resize(info.name_shards.size()); + formats_.resize(info.name_shards.size()); + prefetchers_.resize(info.name_shards.size()); // read in the cache files. - for (size_t i = 0; i < cache_shards.size(); ++i) { - std::string name_row = cache_shards[i] + page_type; + for (size_t i = 0; i < info.name_shards.size(); ++i) { + std::string name_row = info.name_shards.at(i); files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str())); std::unique_ptr& fi = files_[i]; std::string format; CHECK(fi->Read(&format)) << "Invalid page format"; - formats_[i].reset(CreatePageFormat(format)); - std::unique_ptr>& fmt = formats_[i]; + formats_[i].reset(CreatePageFormat(format)); + std::unique_ptr>& fmt = formats_[i]; size_t fbegin = fi->Tell(); - prefetchers_[i].reset(new dmlc::ThreadedIter(4)); - prefetchers_[i]->Init([&fi, &fmt] (T** dptr) { - if (*dptr == nullptr) { - *dptr = new T(); - } - return fmt->Read(*dptr, fi.get()); - }, [&fi, fbegin] () { fi->Seek(fbegin); }); + prefetchers_[i].reset(new dmlc::ThreadedIter(4)); + prefetchers_[i]->Init( + [&fi, &fmt](PageT** dptr) { + if (*dptr == nullptr) { + *dptr = new PageT(); + } + return fmt->Read(*dptr, fi.get()); + }, + [&fi, fbegin]() { fi->Seek(fbegin); }); } } - /*! \brief destructor */ - ~SparsePageSource() override { + ~ExternalMemoryPrefetcher() override { delete page_; } // implement Next bool Next() override { + CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher"; // doing clock rotation over shards. if (page_ != nullptr) { size_t n = prefetchers_.size(); prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_); } + if (prefetchers_[clock_ptr_]->Next(&page_)) { page_->SetBaseRowId(base_rowid_); base_rowid_ += page_->Size(); // advance clock clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size(); + mutex_.unlock(); return true; } else { + mutex_.unlock(); return false; } } // implement BeforeFirst void BeforeFirst() override { + CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher"; base_rowid_ = 0; clock_ptr_ = 0; for (auto& p : prefetchers_) { p->BeforeFirst(); } + mutex_.unlock(); } // implement Value - T& Value() { - return *page_; - } + PageT& Value() { return *page_; } - const T& Value() const override { - return *page_; - } + const PageT& Value() const override { return *page_; } + private: + std::mutex mutex_; + /*! \brief number of rows */ + size_t base_rowid_; + /*! \brief page currently on hold. */ + PageT* page_; + /*! \brief internal clock ptr */ + size_t clock_ptr_; + /*! \brief file pointer to the row blob file. */ + std::vector> files_; + /*! \brief Sparse page format file. */ + std::vector>> formats_; + /*! \brief internal prefetcher. */ + std::vector>> prefetchers_; +}; + +class SparsePageSource { + public: template - static void CreateRowPage(AdapterT* adapter, float missing, int nthread, - const std::string& cache_info, - const size_t page_size = DMatrix::kPageSize) { + SparsePageSource(AdapterT* adapter, float missing, int nthread, + const std::string& cache_info, + const size_t page_size = DMatrix::kPageSize) { const std::string page_type = ".row.page"; - auto cinfo = ParseCacheInfo(cache_info, page_type); + cache_info_ = ParseCacheInfo(cache_info, page_type); + + // Warn user if old cache files + CheckCacheFileExists(cache_info_.name_info); + for (auto file : cache_info_.name_shards) { + CheckCacheFileExists(file); + } + { - SparsePageWriter writer(cinfo.name_shards, - cinfo.format_shards, 6); + SparsePageWriter writer(cache_info_.name_shards, + cache_info_.format_shards, 6); std::shared_ptr page; writer.Alloc(&page); page->Clear(); uint64_t inferred_num_columns = 0; uint64_t inferred_num_rows = 0; - MetaInfo info; size_t bytes_write = 0; double tstart = dmlc::GetTime(); // print every 4 sec. @@ -232,7 +287,8 @@ class SparsePageSource : public DataSource { // get group for (size_t i = 0; i < batch.Size(); ++i) { const uint64_t cur_group_id = batch.Qid()[i]; - if (last_group_id == default_max || last_group_id != cur_group_id) { + if (last_group_id == default_max || + last_group_id != cur_group_id) { info.group_ptr_.push_back(group_size); } last_group_id = cur_group_id; @@ -300,61 +356,53 @@ class SparsePageSource : public DataSource { writer.PushWrite(std::move(page)); } std::unique_ptr fo( - dmlc::Stream::Create(cinfo.name_info.c_str(), "w")); + dmlc::Stream::Create(cache_info_.name_info.c_str(), "w")); int tmagic = kMagic; fo->Write(&tmagic, sizeof(tmagic)); // Either every row has query ID or none at all CHECK(qids.empty() || qids.size() == info.num_row_); info.SaveBinary(fo.get()); } - LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to " - << cinfo.name_info; - } - /*! - * \brief Create source cache by copy content from DMatrix. - * Creates transposed column page, may be sorted or not. - * \param cache_info The cache_info of cache file location. - * \param sorted Whether columns should be pre-sorted - */ - static void CreateColumnPage(DMatrix* src, - const std::string& cache_info, bool sorted) { - const std::string page_type = sorted ? ".sorted.col.page" : ".col.page"; - CreatePageFromDMatrix(src, cache_info, page_type); + LOG(INFO) << "SparsePageSource Finished writing to " + << cache_info_.name_info; + + external_prefetcher_.reset( + new ExternalMemoryPrefetcher(cache_info_)); } - /*! - * \brief Check if the cache file already exists. - * \param cache_info The cache prefix of files. - * \param page_type Type of the page. - * \return Whether cache file already exists. - */ - static bool CacheExist(const std::string& cache_info, - const std::string& page_type) { - std::vector cache_shards = GetCacheShards(cache_info); - CHECK_NE(cache_shards.size(), 0U); - { - std::string name_info = cache_shards[0]; - std::unique_ptr finfo(dmlc::Stream::Create(name_info.c_str(), "r", true)); - if (finfo == nullptr) return false; + ~SparsePageSource() { + external_prefetcher_.reset(); + TryDeleteCacheFile(cache_info_.name_info); + for (auto file : cache_info_.name_shards) { + TryDeleteCacheFile(file); } - for (const std::string& prefix : cache_shards) { - std::string name_row = prefix + page_type; - std::unique_ptr frow(dmlc::Stream::Create(name_row.c_str(), "r", true)); - if (frow == nullptr) return false; - } - return true; } - /*! \brief magic number used to identify Page */ - static const int kMagic = 0xffffab02; + BatchSet GetBatchSet() { + auto begin_iter = BatchIterator( + new SparseBatchIteratorImpl, + SparsePage>(external_prefetcher_.get())); + return BatchSet(begin_iter); + } + MetaInfo info; private: - static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info, - const std::string& page_type, - const size_t page_size = DMatrix::kPageSize) { - auto cinfo = ParseCacheInfo(cache_info, page_type); + std::unique_ptr> external_prefetcher_; + CacheInfo cache_info_; +}; + +class CSCPageSource { + public: + CSCPageSource(DMatrix* src, const std::string& cache_info, + const size_t page_size = DMatrix::kPageSize) { + std::string page_type = ".col.page"; + cache_info_ = ParseCacheInfo(cache_info, page_type); + for (auto file : cache_info_.name_shards) { + CheckCacheFileExists(file); + } { - SparsePageWriter writer(cinfo.name_shards, cinfo.format_shards, 6); + SparsePageWriter writer(cache_info_.name_shards, + cache_info_.format_shards, 6); std::shared_ptr page; writer.Alloc(&page); page->Clear(); @@ -362,15 +410,7 @@ class SparsePageSource : public DataSource { size_t bytes_write = 0; double tstart = dmlc::GetTime(); for (auto& batch : src->GetBatches()) { - if (page_type == ".col.page") { - page->PushCSC(batch.GetTranspose(src->Info().num_col_)); - } else if (page_type == ".sorted.col.page") { - SparsePage tmp = batch.GetTranspose(src->Info().num_col_); - page->PushCSC(tmp); - page->SortRows(); - } else { - LOG(FATAL) << "Unknown page type: " << page_type; - } + page->PushCSC(batch.GetTranspose(src->Info().num_col_)); if (page->MemCostBytes() >= page_size) { bytes_write += page->MemCostBytes(); @@ -386,23 +426,94 @@ class SparsePageSource : public DataSource { if (page->data.Size() != 0) { writer.PushWrite(std::move(page)); } + LOG(INFO) << "CSCPageSource: Finished writing to " + << cache_info_.name_info; } - LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info; + external_prefetcher_.reset( + new ExternalMemoryPrefetcher(cache_info_)); } - /*! \brief number of rows */ - size_t base_rowid_; - /*! \brief page currently on hold. */ - T* page_; - /*! \brief internal clock ptr */ - size_t clock_ptr_; - /*! \brief file pointer to the row blob file. */ - std::vector> files_; - /*! \brief Sparse page format file. */ - std::vector>> formats_; - /*! \brief internal prefetcher. */ - std::vector>> prefetchers_; + ~CSCPageSource() { + external_prefetcher_.reset(); + for (auto file : cache_info_.name_shards) { + TryDeleteCacheFile(file); + } + } + + BatchSet GetBatchSet() { + auto begin_iter = BatchIterator( + new SparseBatchIteratorImpl, CSCPage>( + external_prefetcher_.get())); + return BatchSet(begin_iter); + } + + private: + std::unique_ptr> external_prefetcher_; + CacheInfo cache_info_; }; + +class SortedCSCPageSource { + public: + SortedCSCPageSource(DMatrix* src, const std::string& cache_info, + const size_t page_size = DMatrix::kPageSize) { + std::string page_type = ".sorted.col.page"; + cache_info_ = ParseCacheInfo(cache_info, page_type); + for (auto file : cache_info_.name_shards) { + CheckCacheFileExists(file); + } + { + SparsePageWriter writer(cache_info_.name_shards, + cache_info_.format_shards, 6); + std::shared_ptr page; + writer.Alloc(&page); + page->Clear(); + + size_t bytes_write = 0; + double tstart = dmlc::GetTime(); + for (auto& batch : src->GetBatches()) { + SparsePage tmp = batch.GetTranspose(src->Info().num_col_); + page->PushCSC(tmp); + page->SortRows(); + + if (page->MemCostBytes() >= page_size) { + bytes_write += page->MemCostBytes(); + writer.PushWrite(std::move(page)); + writer.Alloc(&page); + page->Clear(); + double tdiff = dmlc::GetTime() - tstart; + LOG(INFO) << "Writing to " << cache_info << " in " + << ((bytes_write >> 20UL) / tdiff) << " MB/s, " + << (bytes_write >> 20UL) << " written"; + } + } + if (page->data.Size() != 0) { + writer.PushWrite(std::move(page)); + } + LOG(INFO) << "SortedCSCPageSource: Finished writing to " + << cache_info_.name_info; + } + external_prefetcher_.reset( + new ExternalMemoryPrefetcher(cache_info_)); + } + ~SortedCSCPageSource() { + external_prefetcher_.reset(); + for (auto file : cache_info_.name_shards) { + TryDeleteCacheFile(file); + } + } + + BatchSet GetBatchSet() { + auto begin_iter = BatchIterator( + new SparseBatchIteratorImpl, + SortedCSCPage>(external_prefetcher_.get())); + return BatchSet(begin_iter); + } + + private: + std::unique_ptr> external_prefetcher_; + CacheInfo cache_info_; +}; + } // namespace data } // namespace xgboost #endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index c64b04ce5..f1356eddb 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -1,10 +1,10 @@ // Copyright by Contributors #include -#include -#include "../../../src/data/sparse_page_dmatrix.h" -#include "../../../src/data/adapter.h" -#include "../helpers.h" #include +#include +#include "../../../src/data/adapter.h" +#include "../../../src/data/sparse_page_dmatrix.h" +#include "../helpers.h" using namespace xgboost; // NOLINT @@ -12,8 +12,8 @@ TEST(SparsePageDMatrix, MetaInfo) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix * dmat = xgboost::DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", false, false); + xgboost::DMatrix *dmat = xgboost::DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", false, false); std::cout << tmp_file << std::endl; EXPECT_TRUE(FileExists(tmp_file + ".cache")); @@ -44,21 +44,21 @@ TEST(SparsePageDMatrix, ColAccess) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; CreateSimpleTestData(tmp_file); - xgboost::DMatrix * dmat = xgboost::DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false); + xgboost::DMatrix *dmat = + xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false); EXPECT_EQ(dmat->GetColDensity(0), 1); EXPECT_EQ(dmat->GetColDensity(1), 0.5); // Loop over the batches and assert the data is as expected - for (auto const& col_batch : dmat->GetBatches()) { + for (auto const &col_batch : dmat->GetBatches()) { EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_); EXPECT_EQ(col_batch[1][0].fvalue, 10.0f); EXPECT_EQ(col_batch[1].size(), 1); } // Loop over the batches and assert the data is as expected - for (auto const& col_batch : dmat->GetBatches()) { + for (auto const &col_batch : dmat->GetBatches()) { EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_); EXPECT_EQ(col_batch[1][0].fvalue, 10.0f); EXPECT_EQ(col_batch[1].size(), 1); @@ -70,25 +70,61 @@ TEST(SparsePageDMatrix, ColAccess) { EXPECT_TRUE(FileExists(tmp_file + ".cache.sorted.col.page")); delete dmat; + + EXPECT_FALSE(FileExists(tmp_file + ".cache")); + EXPECT_FALSE(FileExists(tmp_file + ".cache.row.page")); + EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page")); + EXPECT_FALSE(FileExists(tmp_file + ".cache.sorted.col.page")); } +TEST(SparsePageDMatrix, ExistingCacheFile) { + dmlc::TemporaryDirectory tmpdir; + std::string filename = tmpdir.path + "/big.libsvm"; + std::unique_ptr dmat = + xgboost::CreateSparsePageDMatrix(12, 64, filename); + EXPECT_ANY_THROW({ + std::unique_ptr dmat2 = + xgboost::CreateSparsePageDMatrix(12, 64, filename); + }); +} + +#if defined(_OPENMP) +TEST(SparsePageDMatrix, ThreadSafetyException) { + dmlc::TemporaryDirectory tmpdir; + std::string filename = tmpdir.path + "/test"; + std::unique_ptr dmat = + xgboost::CreateSparsePageDMatrix(12, 64, filename); + + bool exception = false; + int threads = 1000; +#pragma omp parallel for + for (auto i = 0; i < threads; i++) { + try { + auto iter = dmat->GetBatches().begin(); + ++iter; + } catch (...) { + exception = true; + } + } + EXPECT_TRUE(exception); +} +#endif + // Multi-batches access TEST(SparsePageDMatrix, ColAccessBatches) { dmlc::TemporaryDirectory tmpdir; std::string filename = tmpdir.path + "/big.libsvm"; // Create multiple sparse pages - std::unique_ptr dmat { - xgboost::CreateSparsePageDMatrix(1024, 1024, filename) - }; + std::unique_ptr dmat{ + xgboost::CreateSparsePageDMatrix(1024, 1024, filename)}; auto n_threads = omp_get_max_threads(); omp_set_num_threads(16); - for (auto const& page : dmat->GetBatches()) { + for (auto const &page : dmat->GetBatches()) { ASSERT_EQ(dmat->Info().num_col_, page.Size()); } omp_set_num_threads(n_threads); } - TEST(SparsePageDMatrix, Empty) { dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm"; @@ -96,34 +132,40 @@ TEST(SparsePageDMatrix, Empty) { std::vector feature_idx = {}; std::vector row_ptr = {}; - data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0); - data::SparsePageDMatrix dmat(&csr_adapter, - std::numeric_limits::quiet_NaN(), 1,tmp_file); - EXPECT_EQ(dmat.Info().num_nonzero_, 0); - EXPECT_EQ(dmat.Info().num_row_, 0); - EXPECT_EQ(dmat.Info().num_col_, 0); - for (auto &batch : dmat.GetBatches()) { - EXPECT_EQ(batch.Size(), 0); + { + data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), + data.data(), 0, 0, 0); + data::SparsePageDMatrix dmat( + &csr_adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); + EXPECT_EQ(dmat.Info().num_nonzero_, 0); + EXPECT_EQ(dmat.Info().num_row_, 0); + EXPECT_EQ(dmat.Info().num_col_, 0); + for (auto &batch : dmat.GetBatches()) { + EXPECT_EQ(batch.Size(), 0); + } } - data::DenseAdapter dense_adapter(nullptr, 0, 0); - data::SparsePageDMatrix dmat2(&dense_adapter, - std::numeric_limits::quiet_NaN(), 1,tmp_file); - EXPECT_EQ(dmat2.Info().num_nonzero_, 0); - EXPECT_EQ(dmat2.Info().num_row_, 0); - EXPECT_EQ(dmat2.Info().num_col_, 0); - for (auto &batch : dmat2.GetBatches()) { - EXPECT_EQ(batch.Size(), 0); + { + data::DenseAdapter dense_adapter(nullptr, 0, 0); + data::SparsePageDMatrix dmat2( + &dense_adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); + EXPECT_EQ(dmat2.Info().num_nonzero_, 0); + EXPECT_EQ(dmat2.Info().num_row_, 0); + EXPECT_EQ(dmat2.Info().num_col_, 0); + for (auto &batch : dmat2.GetBatches()) { + EXPECT_EQ(batch.Size(), 0); + } } - - data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0); - data::SparsePageDMatrix dmat3(&csc_adapter, - std::numeric_limits::quiet_NaN(), 1,tmp_file); - EXPECT_EQ(dmat3.Info().num_nonzero_, 0); - EXPECT_EQ(dmat3.Info().num_row_, 0); - EXPECT_EQ(dmat3.Info().num_col_, 0); - for (auto &batch : dmat3.GetBatches()) { - EXPECT_EQ(batch.Size(), 0); + { + data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0); + data::SparsePageDMatrix dmat3( + &csc_adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); + EXPECT_EQ(dmat3.Info().num_nonzero_, 0); + EXPECT_EQ(dmat3.Info().num_row_, 0); + EXPECT_EQ(dmat3.Info().num_col_, 0); + for (auto &batch : dmat3.GetBatches()) { + EXPECT_EQ(batch.Size(), 0); + } } } @@ -134,12 +176,14 @@ TEST(SparsePageDMatrix, MissingData) { std::vector feature_idx = {0, 1, 0}; std::vector row_ptr = {0, 2, 3}; - data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2); - data::SparsePageDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), 1,tmp_file); + data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, + 3, 2); + data::SparsePageDMatrix dmat( + &adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); EXPECT_EQ(dmat.Info().num_nonzero_, 2); const std::string tmp_file2 = tempdir.path + "/simple2.libsvm"; - data::SparsePageDMatrix dmat2(&adapter, 1.0, 1,tmp_file2); + data::SparsePageDMatrix dmat2(&adapter, 1.0, 1, tmp_file2); EXPECT_EQ(dmat2.Info().num_nonzero_, 1); } @@ -150,8 +194,10 @@ TEST(SparsePageDMatrix, EmptyRow) { std::vector feature_idx = {0, 1}; std::vector row_ptr = {0, 2, 2}; - data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2); - data::SparsePageDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), 1,tmp_file); + data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, + 2, 2); + data::SparsePageDMatrix dmat( + &adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); EXPECT_EQ(dmat.Info().num_nonzero_, 2); EXPECT_EQ(dmat.Info().num_row_, 2); EXPECT_EQ(dmat.Info().num_col_, 2); @@ -173,9 +219,8 @@ TEST(SparsePageDMatrix, FromDense) { for (auto &batch : dmat.GetBatches()) { for (auto i = 0ull; i < batch.Size(); i++) { auto inst = batch[i]; - for(auto j = 0ull; j < inst.size(); j++) - { - EXPECT_EQ(inst[j].fvalue, data[i*n+j]); + for (auto j = 0ull; j < inst.size(); j++) { + EXPECT_EQ(inst[j].fvalue, data[i * n + j]); EXPECT_EQ(inst[j].index, j); } } @@ -215,9 +260,9 @@ TEST(SparsePageDMatrix, FromCSC) { TEST(SparsePageDMatrix, FromFile) { std::string filename = "test.libsvm"; - CreateBigTestData(filename,20); + CreateBigTestData(filename, 20); std::unique_ptr> parser( - dmlc::Parser::Create(filename.c_str(), 0, 1, "auto")); + dmlc::Parser::Create(filename.c_str(), 0, 1, "auto")); data::FileAdapter adapter(parser.get()); dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/simple.libsvm";