Refactor SparsePageSource, delete cache files after use (#5321)

* Refactor sparse page source

* Delete temporary cache files

* Log fatal if cache exists

* Log fatal if multiple threads used with prefetcher
This commit is contained in:
Rory Mitchell 2020-02-19 16:43:41 +13:00 committed by GitHub
parent b2b2c4e231
commit bc96ceb8b2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 336 additions and 221 deletions

View File

@ -45,7 +45,7 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
dh::BulkAllocator ba_;
/*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
EllpackInfo ellpack_info_;
std::unique_ptr<SparsePageSource<EllpackPage>> source_;
std::unique_ptr<ExternalMemoryPrefetcher<EllpackPage>> source_;
std::string cache_info_;
};
@ -98,11 +98,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
WriteEllpackPages(dmat, cache_info);
monitor_.StopCuda("WriteEllpackPages");
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
ParseCacheInfo(cache_info_, kPageType_)));
}
void EllpackPageSourceImpl::BeforeFirst() {
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
ParseCacheInfo(cache_info_, kPageType_)));
source_->BeforeFirst();
}

View File

@ -23,58 +23,24 @@ const MetaInfo& SparsePageDMatrix::Info() const {
return row_source_->info;
}
template<typename S, typename T>
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SparseBatchIteratorImpl(S* source) : source_(source) {
CHECK(source_ != nullptr);
}
T& operator*() override { return source_->Value(); }
const T& operator*() const override { return source_->Value(); }
void operator++() override { at_end_ = !source_->Next(); }
bool AtEnd() const override { return at_end_; }
private:
S* source_{nullptr};
bool at_end_{ false };
};
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
auto cast = dynamic_cast<SparsePageSource<SparsePage>*>(row_source_.get());
CHECK(cast);
cast->BeforeFirst();
cast->Next();
auto begin_iter = BatchIterator<SparsePage>(
new SparseBatchIteratorImpl<SparsePageSource<SparsePage>, SparsePage>(cast));
return BatchSet<SparsePage>(begin_iter);
return row_source_->GetBatchSet();
}
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
// Lazily instantiate
if (!column_source_) {
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, false);
column_source_.reset(new SparsePageSource<CSCPage>(cache_info_, ".col.page"));
column_source_.reset(new CSCPageSource(this, cache_info_));
}
column_source_->BeforeFirst();
column_source_->Next();
auto begin_iter = BatchIterator<CSCPage>(
new SparseBatchIteratorImpl<SparsePageSource<CSCPage>, CSCPage>(column_source_.get()));
return BatchSet<CSCPage>(begin_iter);
return column_source_->GetBatchSet();
}
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
// Lazily instantiate
if (!sorted_column_source_) {
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, true);
sorted_column_source_.reset(
new SparsePageSource<SortedCSCPage>(cache_info_, ".sorted.col.page"));
sorted_column_source_.reset(new SortedCSCPageSource(this, cache_info_));
}
sorted_column_source_->BeforeFirst();
sorted_column_source_->Next();
auto begin_iter = BatchIterator<SortedCSCPage>(
new SparseBatchIteratorImpl<SparsePageSource<SortedCSCPage>, SortedCSCPage>(
sorted_column_source_.get()));
return BatchSet<SortedCSCPage>(begin_iter);
return sorted_column_source_->GetBatchSet();
}
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {

View File

@ -22,22 +22,13 @@ namespace data {
// Used for external memory.
class SparsePageDMatrix : public DMatrix {
public:
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
std::string cache_info)
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
template <typename AdapterT>
explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix,
size_t page_size = kPageSize)
: cache_info_(std::move(cache_prefix)) {
if (!data::SparsePageSource<SparsePage>::CacheExist(cache_prefix,
".row.page")) {
data::SparsePageSource<SparsePage>::CreateRowPage(
adapter, missing, nthread, cache_prefix, page_size);
}
row_source_.reset(
new data::SparsePageSource<SparsePage>(cache_prefix, ".row.page"));
row_source_.reset(new data::SparsePageSource(adapter, missing, nthread,
cache_prefix, page_size));
}
// Set number of threads but keep old value so we can reset it after
~SparsePageDMatrix() override = default;
@ -57,9 +48,9 @@ class SparsePageDMatrix : public DMatrix {
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
// source data pointers.
std::unique_ptr<DataSource<SparsePage>> row_source_;
std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
std::unique_ptr<SparsePageSource> row_source_;
std::unique_ptr<CSCPageSource> column_source_;
std::unique_ptr<SortedCSCPageSource> sorted_column_source_;
std::unique_ptr<EllpackPageSource> ellpack_source_;
// saved batch param
BatchParam batch_param_;

View File

@ -17,6 +17,7 @@
#include <string>
#include <utility>
#include <vector>
#include <fstream>
#include "xgboost/base.h"
#include "xgboost/data.h"
@ -24,6 +25,7 @@
#include "adapter.h"
#include "sparse_page_writer.h"
#include "../common/common.h"
#include <xgboost/data.h>
namespace {
@ -49,6 +51,26 @@ GetCacheShards(const std::string& cache_info) {
namespace xgboost {
namespace data {
/*!
 * \brief Adapts a page source to the BatchIteratorImpl interface.
 *
 * The iterator does not own the source; it only forwards to it. Constructing
 * an iterator rewinds the source and loads its first page.
 *
 * \tparam S Source type; must provide BeforeFirst(), Next() and Value().
 * \tparam T Page type yielded when the iterator is dereferenced.
 */
template <typename S, typename T>
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
 public:
  /*!
   * \brief Rewind the source and position the iterator on its first page.
   * \param source Non-null source to iterate over; must outlive the iterator.
   */
  explicit SparseBatchIteratorImpl(S* source) : source_(source) {
    CHECK(source_ != nullptr);
    source_->BeforeFirst();
    // Record whether a first page exists at all. Ignoring this result would
    // leave at_end_ == false for an empty source and let callers dereference
    // a page that was never loaded.
    at_end_ = !source_->Next();
  }
  /*! \brief Page the source is currently holding. */
  T& operator*() override { return source_->Value(); }
  /*! \brief Page the source is currently holding (const access). */
  const T& operator*() const override { return source_->Value(); }
  /*! \brief Advance the source; the iterator is at end once Next() fails. */
  void operator++() override { at_end_ = !source_->Next(); }
  /*! \brief Whether the last attempt to load a page failed. */
  bool AtEnd() const override { return at_end_; }

 private:
  /*! \brief Borrowed source; never deleted here. */
  S* source_{nullptr};
  /*! \brief Set once the source reports no more pages. */
  bool at_end_{false};
};
/*!
 * \brief magic number used to identify Page
 *
 * Written as the first field of the cache info file and compared against the
 * value read back when that file is opened.
 * NOTE(review): the literal 0xffffab02 does not fit in a 32-bit signed int, so
 * the stored value is the result of the conversion to int - confirm this is
 * intended before changing the type.
 */
static const int kMagic = 0xffffab02;
/*!
* \brief decide the format from cache prefix.
* \return pair of row format, column format type of the cache prefix.
@ -89,116 +111,149 @@ inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string
return info;
}
/*!
* \brief External memory data source.
* \code
* std::unique_ptr<DataSource> source(new SimpleCSRSource(cache_prefix));
* // add data to source
* DMatrix* dmat = DMatrix::Create(std::move(source));
* \encode
/*!
 * \brief Best-effort removal of an external memory cache file.
 *
 * A failed removal is reported with a warning only; nothing is thrown.
 *
 * \param file Path of the file to remove.
 */
inline void TryDeleteCacheFile(const std::string& file) {
  const bool removed = (std::remove(file.c_str()) == 0);
  if (removed) {
    return;
  }
  LOG(WARNING) << "Couldn't remove external memory cache file " << file
               << "; you may want to remove it manually";
}
/*!
 * \brief Abort with a fatal log message if the given cache file can be opened.
 *
 * Existence is probed by opening the file for reading.
 *
 * \param file Path of the cache file that is about to be created.
 */
inline void CheckCacheFileExists(const std::string& file) {
  std::ifstream probe(file.c_str());
  const bool already_present = probe.good();
  if (!already_present) {
    return;
  }
  LOG(FATAL) << "Cache file " << file
             << " exists already; Is there another DMatrix with the same "
                "cache prefix? Otherwise please remove it manually.";
}
/**
* \brief Given a set of cache files and page type, this object iterates over batches using prefetching for improved performance. Not thread safe.
*
* \tparam PageT Type of the page returned by Value().
*/
template<typename T>
class SparsePageSource : public DataSource<T> {
template <typename PageT>
class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
public:
/*!
* \brief Create source from cache files the cache_prefix.
* \param cache_prefix The prefix of cache we want to solve.
*/
explicit SparsePageSource(const std::string& cache_info,
const std::string& page_type) noexcept(false)
explicit ExternalMemoryPrefetcher(const CacheInfo& info) noexcept(false)
: base_rowid_(0), page_(nullptr), clock_ptr_(0) {
// read in the info files
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
CHECK_NE(cache_shards.size(), 0U);
CHECK_NE(info.name_shards.size(), 0U);
{
std::string name_info = cache_shards[0];
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
std::unique_ptr<dmlc::Stream> finfo(
dmlc::Stream::Create(info.name_info.c_str(), "r"));
int tmagic;
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
this->info.LoadBinary(finfo.get());
}
files_.resize(cache_shards.size());
formats_.resize(cache_shards.size());
prefetchers_.resize(cache_shards.size());
files_.resize(info.name_shards.size());
formats_.resize(info.name_shards.size());
prefetchers_.resize(info.name_shards.size());
// read in the cache files.
for (size_t i = 0; i < cache_shards.size(); ++i) {
std::string name_row = cache_shards[i] + page_type;
for (size_t i = 0; i < info.name_shards.size(); ++i) {
std::string name_row = info.name_shards.at(i);
files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
std::string format;
CHECK(fi->Read(&format)) << "Invalid page format";
formats_[i].reset(CreatePageFormat<T>(format));
std::unique_ptr<SparsePageFormat<T>>& fmt = formats_[i];
formats_[i].reset(CreatePageFormat<PageT>(format));
std::unique_ptr<SparsePageFormat<PageT>>& fmt = formats_[i];
size_t fbegin = fi->Tell();
prefetchers_[i].reset(new dmlc::ThreadedIter<T>(4));
prefetchers_[i]->Init([&fi, &fmt] (T** dptr) {
prefetchers_[i].reset(new dmlc::ThreadedIter<PageT>(4));
prefetchers_[i]->Init(
[&fi, &fmt](PageT** dptr) {
if (*dptr == nullptr) {
*dptr = new T();
*dptr = new PageT();
}
return fmt->Read(*dptr, fi.get());
}, [&fi, fbegin] () { fi->Seek(fbegin); });
},
[&fi, fbegin]() { fi->Seek(fbegin); });
}
}
/*! \brief destructor */
~SparsePageSource() override {
~ExternalMemoryPrefetcher() override {
delete page_;
}
// implement Next
bool Next() override {
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
// doing clock rotation over shards.
if (page_ != nullptr) {
size_t n = prefetchers_.size();
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
}
if (prefetchers_[clock_ptr_]->Next(&page_)) {
page_->SetBaseRowId(base_rowid_);
base_rowid_ += page_->Size();
// advance clock
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
mutex_.unlock();
return true;
} else {
mutex_.unlock();
return false;
}
}
// implement BeforeFirst
void BeforeFirst() override {
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
base_rowid_ = 0;
clock_ptr_ = 0;
for (auto& p : prefetchers_) {
p->BeforeFirst();
}
mutex_.unlock();
}
// implement Value
T& Value() {
return *page_;
}
PageT& Value() { return *page_; }
const T& Value() const override {
return *page_;
}
const PageT& Value() const override { return *page_; }
private:
std::mutex mutex_;
/*! \brief number of rows */
size_t base_rowid_;
/*! \brief page currently on hold. */
PageT* page_;
/*! \brief internal clock ptr */
size_t clock_ptr_;
/*! \brief file pointer to the row blob file. */
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
/*! \brief Sparse page format file. */
std::vector<std::unique_ptr<SparsePageFormat<PageT>>> formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
};
class SparsePageSource {
public:
template <typename AdapterT>
static void CreateRowPage(AdapterT* adapter, float missing, int nthread,
SparsePageSource(AdapterT* adapter, float missing, int nthread,
const std::string& cache_info,
const size_t page_size = DMatrix::kPageSize) {
const std::string page_type = ".row.page";
auto cinfo = ParseCacheInfo(cache_info, page_type);
cache_info_ = ParseCacheInfo(cache_info, page_type);
// Warn user if old cache files
CheckCacheFileExists(cache_info_.name_info);
for (auto file : cache_info_.name_shards) {
CheckCacheFileExists(file);
}
{
SparsePageWriter<SparsePage> writer(cinfo.name_shards,
cinfo.format_shards, 6);
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
cache_info_.format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page);
page->Clear();
uint64_t inferred_num_columns = 0;
uint64_t inferred_num_rows = 0;
MetaInfo info;
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
// print every 4 sec.
@ -232,7 +287,8 @@ class SparsePageSource : public DataSource<T> {
// get group
for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
if (last_group_id == default_max ||
last_group_id != cur_group_id) {
info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
@ -300,61 +356,53 @@ class SparsePageSource : public DataSource<T> {
writer.PushWrite(std::move(page));
}
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
// Either every row has query ID or none at all
CHECK(qids.empty() || qids.size() == info.num_row_);
info.SaveBinary(fo.get());
}
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
<< cinfo.name_info;
}
/*!
* \brief Create source cache by copy content from DMatrix.
* Creates transposed column page, may be sorted or not.
* \param cache_info The cache_info of cache file location.
* \param sorted Whether columns should be pre-sorted
*/
static void CreateColumnPage(DMatrix* src,
const std::string& cache_info, bool sorted) {
const std::string page_type = sorted ? ".sorted.col.page" : ".col.page";
CreatePageFromDMatrix(src, cache_info, page_type);
LOG(INFO) << "SparsePageSource Finished writing to "
<< cache_info_.name_info;
external_prefetcher_.reset(
new ExternalMemoryPrefetcher<SparsePage>(cache_info_));
}
/*!
* \brief Check if the cache file already exists.
* \param cache_info The cache prefix of files.
* \param page_type Type of the page.
* \return Whether cache file already exists.
*/
static bool CacheExist(const std::string& cache_info,
const std::string& page_type) {
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
CHECK_NE(cache_shards.size(), 0U);
{
std::string name_info = cache_shards[0];
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r", true));
if (finfo == nullptr) return false;
~SparsePageSource() {
external_prefetcher_.reset();
TryDeleteCacheFile(cache_info_.name_info);
for (auto file : cache_info_.name_shards) {
TryDeleteCacheFile(file);
}
for (const std::string& prefix : cache_shards) {
std::string name_row = prefix + page_type;
std::unique_ptr<dmlc::Stream> frow(dmlc::Stream::Create(name_row.c_str(), "r", true));
if (frow == nullptr) return false;
}
return true;
}
/*! \brief magic number used to identify Page */
static const int kMagic = 0xffffab02;
BatchSet<SparsePage> GetBatchSet() {
auto begin_iter = BatchIterator<SparsePage>(
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SparsePage>,
SparsePage>(external_prefetcher_.get()));
return BatchSet<SparsePage>(begin_iter);
}
MetaInfo info;
private:
static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
const std::string& page_type,
std::unique_ptr<ExternalMemoryPrefetcher<SparsePage>> external_prefetcher_;
CacheInfo cache_info_;
};
class CSCPageSource {
public:
CSCPageSource(DMatrix* src, const std::string& cache_info,
const size_t page_size = DMatrix::kPageSize) {
auto cinfo = ParseCacheInfo(cache_info, page_type);
std::string page_type = ".col.page";
cache_info_ = ParseCacheInfo(cache_info, page_type);
for (auto file : cache_info_.name_shards) {
CheckCacheFileExists(file);
}
{
SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
cache_info_.format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page);
page->Clear();
@ -362,15 +410,7 @@ class SparsePageSource : public DataSource<T> {
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
for (auto& batch : src->GetBatches<SparsePage>()) {
if (page_type == ".col.page") {
page->PushCSC(batch.GetTranspose(src->Info().num_col_));
} else if (page_type == ".sorted.col.page") {
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
page->PushCSC(tmp);
page->SortRows();
} else {
LOG(FATAL) << "Unknown page type: " << page_type;
}
if (page->MemCostBytes() >= page_size) {
bytes_write += page->MemCostBytes();
@ -386,23 +426,94 @@ class SparsePageSource : public DataSource<T> {
if (page->data.Size() != 0) {
writer.PushWrite(std::move(page));
}
LOG(INFO) << "CSCPageSource: Finished writing to "
<< cache_info_.name_info;
}
LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info;
external_prefetcher_.reset(
new ExternalMemoryPrefetcher<CSCPage>(cache_info_));
}
/*! \brief number of rows */
size_t base_rowid_;
/*! \brief page currently on hold. */
T* page_;
/*! \brief internal clock ptr */
size_t clock_ptr_;
/*! \brief file pointer to the row blob file. */
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
/*! \brief Sparse page format file. */
std::vector<std::unique_ptr<SparsePageFormat<T>>> formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<T>>> prefetchers_;
~CSCPageSource() {
external_prefetcher_.reset();
for (auto file : cache_info_.name_shards) {
TryDeleteCacheFile(file);
}
}
BatchSet<CSCPage> GetBatchSet() {
auto begin_iter = BatchIterator<CSCPage>(
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<CSCPage>, CSCPage>(
external_prefetcher_.get()));
return BatchSet<CSCPage>(begin_iter);
}
private:
std::unique_ptr<ExternalMemoryPrefetcher<CSCPage>> external_prefetcher_;
CacheInfo cache_info_;
};
/*!
 * \brief Writes sorted column pages for a DMatrix to cache shard files and
 *        serves them back through an ExternalMemoryPrefetcher.
 *
 * The constructor refuses to run if any shard file already exists, writes the
 * pages, and then opens a prefetcher over the shards. The destructor releases
 * the prefetcher and then tries to delete the shard files.
 */
class SortedCSCPageSource {
 public:
  /*!
   * \brief Build the sorted column page cache from the row batches of \p src.
   * \param src        Matrix whose SparsePage batches are transposed.
   * \param cache_info Cache location string, interpreted by ParseCacheInfo.
   * \param page_size  A page is flushed to disk once its MemCostBytes()
   *                   reaches this value.
   */
  SortedCSCPageSource(DMatrix* src, const std::string& cache_info,
                      const size_t page_size = DMatrix::kPageSize) {
    const std::string page_type = ".sorted.col.page";
    cache_info_ = ParseCacheInfo(cache_info, page_type);
    // Bind by reference: there is no need to copy each shard name.
    for (const auto& file : cache_info_.name_shards) {
      CheckCacheFileExists(file);
    }
    {
      // The writer lives in its own scope so it is destroyed before the
      // prefetcher below opens the same shard files.
      SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
                                          cache_info_.format_shards, 6);
      std::shared_ptr<SparsePage> page;
      writer.Alloc(&page);
      page->Clear();

      size_t bytes_write = 0;
      double tstart = dmlc::GetTime();
      for (auto& batch : src->GetBatches<SparsePage>()) {
        SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
        page->PushCSC(tmp);
        page->SortRows();

        if (page->MemCostBytes() >= page_size) {
          bytes_write += page->MemCostBytes();
          writer.PushWrite(std::move(page));
          writer.Alloc(&page);
          page->Clear();
          double tdiff = dmlc::GetTime() - tstart;
          LOG(INFO) << "Writing to " << cache_info << " in "
                    << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
                    << (bytes_write >> 20UL) << " written";
        }
      }
      // Flush whatever is left in the final, partially filled page.
      if (page->data.Size() != 0) {
        writer.PushWrite(std::move(page));
      }
      LOG(INFO) << "SortedCSCPageSource: Finished writing to "
                << cache_info_.name_info;
    }
    external_prefetcher_.reset(
        new ExternalMemoryPrefetcher<SortedCSCPage>(cache_info_));
  }

  /*! \brief Release the prefetcher, then try to delete every shard file. */
  ~SortedCSCPageSource() {
    // Drop the prefetcher first so it is gone before the files are removed.
    external_prefetcher_.reset();
    for (const auto& file : cache_info_.name_shards) {
      TryDeleteCacheFile(file);
    }
  }

  /*!
   * \brief Create a batch set over the cached pages.
   * \return BatchSet whose begin iterator wraps the internal prefetcher.
   */
  BatchSet<SortedCSCPage> GetBatchSet() {
    auto begin_iter = BatchIterator<SortedCSCPage>(
        new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SortedCSCPage>,
                                    SortedCSCPage>(external_prefetcher_.get()));
    return BatchSet<SortedCSCPage>(begin_iter);
  }

 private:
  /*! \brief Reader over the shard files written by the constructor. */
  std::unique_ptr<ExternalMemoryPrefetcher<SortedCSCPage>> external_prefetcher_;
  /*! \brief Parsed cache description (shard names and formats). */
  CacheInfo cache_info_;
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_

View File

@ -1,10 +1,10 @@
// Copyright by Contributors
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include "../../../src/data/sparse_page_dmatrix.h"
#include "../../../src/data/adapter.h"
#include "../helpers.h"
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/sparse_page_dmatrix.h"
#include "../helpers.h"
using namespace xgboost; // NOLINT
@ -12,7 +12,7 @@ TEST(SparsePageDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", false, false);
std::cout << tmp_file << std::endl;
EXPECT_TRUE(FileExists(tmp_file + ".cache"));
@ -44,21 +44,21 @@ TEST(SparsePageDMatrix, ColAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", true, false);
xgboost::DMatrix *dmat =
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
EXPECT_EQ(dmat->GetColDensity(0), 1);
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
// Loop over the batches and assert the data is as expected
for (auto const& col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
EXPECT_EQ(col_batch[1].size(), 1);
}
// Loop over the batches and assert the data is as expected
for (auto const& col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
EXPECT_EQ(col_batch[1].size(), 1);
@ -70,25 +70,61 @@ TEST(SparsePageDMatrix, ColAccess) {
EXPECT_TRUE(FileExists(tmp_file + ".cache.sorted.col.page"));
delete dmat;
EXPECT_FALSE(FileExists(tmp_file + ".cache"));
EXPECT_FALSE(FileExists(tmp_file + ".cache.row.page"));
EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
EXPECT_FALSE(FileExists(tmp_file + ".cache.sorted.col.page"));
}
// Creating a second DMatrix on a cache prefix that is already in use by a live
// DMatrix is expected to throw.
TEST(SparsePageDMatrix, ExistingCacheFile) {
  dmlc::TemporaryDirectory scratch_dir;
  std::string cache_path = scratch_dir.path + "/big.libsvm";
  // Keep the first matrix alive so its cache is still present below.
  std::unique_ptr<xgboost::DMatrix> first =
      xgboost::CreateSparsePageDMatrix(12, 64, cache_path);
  EXPECT_ANY_THROW({
    std::unique_ptr<xgboost::DMatrix> second =
        xgboost::CreateSparsePageDMatrix(12, 64, cache_path);
  });
}
#if defined(_OPENMP)
// Drives the row-batch iterator of a single DMatrix from many OpenMP threads
// and expects at least one of them to see an exception.
// NOTE(review): the exception presumably comes from the prefetcher's
// try_lock CHECK - confirm against ExternalMemoryPrefetcher.
TEST(SparsePageDMatrix, ThreadSafetyException) {
  dmlc::TemporaryDirectory tmpdir;
  std::string filename = tmpdir.path + "/test";
  std::unique_ptr<xgboost::DMatrix> dmat =
      xgboost::CreateSparsePageDMatrix(12, 64, filename);

  bool exception = false;
  const int threads = 1000;
  // Every thread sets its own private copy of `exception`; OpenMP ORs the
  // copies together after the loop. This avoids concurrent unsynchronised
  // writes to one shared bool, which would be a data race.
#pragma omp parallel for reduction(||: exception)
  for (auto i = 0; i < threads; i++) {
    try {
      auto iter = dmat->GetBatches<SparsePage>().begin();
      ++iter;
    } catch (...) {
      // Exceptions must not escape the parallel region; just record them.
      exception = true;
    }
  }
  EXPECT_TRUE(exception);
}
#endif
// Multi-batches access
TEST(SparsePageDMatrix, ColAccessBatches) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
// Create multiple sparse pages
std::unique_ptr<xgboost::DMatrix> dmat {
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)
};
std::unique_ptr<xgboost::DMatrix> dmat{
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
auto n_threads = omp_get_max_threads();
omp_set_num_threads(16);
for (auto const& page : dmat->GetBatches<xgboost::CSCPage>()) {
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
ASSERT_EQ(dmat->Info().num_col_, page.Size());
}
omp_set_num_threads(n_threads);
}
TEST(SparsePageDMatrix, Empty) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
@ -96,35 +132,41 @@ TEST(SparsePageDMatrix, Empty) {
std::vector<unsigned> feature_idx = {};
std::vector<size_t> row_ptr = {};
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0);
data::SparsePageDMatrix dmat(&csr_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
{
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(),
data.data(), 0, 0, 0);
data::SparsePageDMatrix dmat(
&csr_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
EXPECT_EQ(dmat.Info().num_row_, 0);
EXPECT_EQ(dmat.Info().num_col_, 0);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
}
{
data::DenseAdapter dense_adapter(nullptr, 0, 0);
data::SparsePageDMatrix dmat2(&dense_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
data::SparsePageDMatrix dmat2(
&dense_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
EXPECT_EQ(dmat2.Info().num_row_, 0);
EXPECT_EQ(dmat2.Info().num_col_, 0);
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
}
{
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
data::SparsePageDMatrix dmat3(&csc_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
data::SparsePageDMatrix dmat3(
&csc_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
EXPECT_EQ(dmat3.Info().num_row_, 0);
EXPECT_EQ(dmat3.Info().num_col_, 0);
for (auto &batch : dmat3.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
}
}
TEST(SparsePageDMatrix, MissingData) {
@ -134,12 +176,14 @@ TEST(SparsePageDMatrix, MissingData) {
std::vector<unsigned> feature_idx = {0, 1, 0};
std::vector<size_t> row_ptr = {0, 2, 3};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2);
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
3, 2);
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
const std::string tmp_file2 = tempdir.path + "/simple2.libsvm";
data::SparsePageDMatrix dmat2(&adapter, 1.0, 1,tmp_file2);
data::SparsePageDMatrix dmat2(&adapter, 1.0, 1, tmp_file2);
EXPECT_EQ(dmat2.Info().num_nonzero_, 1);
}
@ -150,8 +194,10 @@ TEST(SparsePageDMatrix, EmptyRow) {
std::vector<unsigned> feature_idx = {0, 1};
std::vector<size_t> row_ptr = {0, 2, 2};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2);
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
2, 2);
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
EXPECT_EQ(dmat.Info().num_row_, 2);
EXPECT_EQ(dmat.Info().num_col_, 2);
@ -173,9 +219,8 @@ TEST(SparsePageDMatrix, FromDense) {
for (auto &batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for(auto j = 0ull; j < inst.size(); j++)
{
EXPECT_EQ(inst[j].fvalue, data[i*n+j]);
for (auto j = 0ull; j < inst.size(); j++) {
EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
EXPECT_EQ(inst[j].index, j);
}
}
@ -215,7 +260,7 @@ TEST(SparsePageDMatrix, FromCSC) {
TEST(SparsePageDMatrix, FromFile) {
std::string filename = "test.libsvm";
CreateBigTestData(filename,20);
CreateBigTestData(filename, 20);
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
data::FileAdapter adapter(parser.get());