Refactor SparsePageSource, delete cache files after use (#5321)
* Refactor sparse page source * Delete temporary cache files * Log fatal if cache exists * Log fatal if multiple threads used with prefetcher
This commit is contained in:
parent
b2b2c4e231
commit
bc96ceb8b2
@ -45,7 +45,7 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
|
||||
dh::BulkAllocator ba_;
|
||||
/*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
|
||||
EllpackInfo ellpack_info_;
|
||||
std::unique_ptr<SparsePageSource<EllpackPage>> source_;
|
||||
std::unique_ptr<ExternalMemoryPrefetcher<EllpackPage>> source_;
|
||||
std::string cache_info_;
|
||||
};
|
||||
|
||||
@ -98,11 +98,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
|
||||
WriteEllpackPages(dmat, cache_info);
|
||||
monitor_.StopCuda("WriteEllpackPages");
|
||||
|
||||
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
|
||||
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
|
||||
ParseCacheInfo(cache_info_, kPageType_)));
|
||||
}
|
||||
|
||||
void EllpackPageSourceImpl::BeforeFirst() {
|
||||
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
|
||||
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
|
||||
ParseCacheInfo(cache_info_, kPageType_)));
|
||||
source_->BeforeFirst();
|
||||
}
|
||||
|
||||
|
||||
@ -23,58 +23,24 @@ const MetaInfo& SparsePageDMatrix::Info() const {
|
||||
return row_source_->info;
|
||||
}
|
||||
|
||||
template<typename S, typename T>
|
||||
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
|
||||
public:
|
||||
explicit SparseBatchIteratorImpl(S* source) : source_(source) {
|
||||
CHECK(source_ != nullptr);
|
||||
}
|
||||
T& operator*() override { return source_->Value(); }
|
||||
const T& operator*() const override { return source_->Value(); }
|
||||
void operator++() override { at_end_ = !source_->Next(); }
|
||||
bool AtEnd() const override { return at_end_; }
|
||||
|
||||
private:
|
||||
S* source_{nullptr};
|
||||
bool at_end_{ false };
|
||||
};
|
||||
|
||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
||||
auto cast = dynamic_cast<SparsePageSource<SparsePage>*>(row_source_.get());
|
||||
CHECK(cast);
|
||||
cast->BeforeFirst();
|
||||
cast->Next();
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
new SparseBatchIteratorImpl<SparsePageSource<SparsePage>, SparsePage>(cast));
|
||||
return BatchSet<SparsePage>(begin_iter);
|
||||
return row_source_->GetBatchSet();
|
||||
}
|
||||
|
||||
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
||||
// Lazily instantiate
|
||||
if (!column_source_) {
|
||||
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, false);
|
||||
column_source_.reset(new SparsePageSource<CSCPage>(cache_info_, ".col.page"));
|
||||
column_source_.reset(new CSCPageSource(this, cache_info_));
|
||||
}
|
||||
column_source_->BeforeFirst();
|
||||
column_source_->Next();
|
||||
auto begin_iter = BatchIterator<CSCPage>(
|
||||
new SparseBatchIteratorImpl<SparsePageSource<CSCPage>, CSCPage>(column_source_.get()));
|
||||
return BatchSet<CSCPage>(begin_iter);
|
||||
return column_source_->GetBatchSet();
|
||||
}
|
||||
|
||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||
// Lazily instantiate
|
||||
if (!sorted_column_source_) {
|
||||
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, true);
|
||||
sorted_column_source_.reset(
|
||||
new SparsePageSource<SortedCSCPage>(cache_info_, ".sorted.col.page"));
|
||||
sorted_column_source_.reset(new SortedCSCPageSource(this, cache_info_));
|
||||
}
|
||||
sorted_column_source_->BeforeFirst();
|
||||
sorted_column_source_->Next();
|
||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||
new SparseBatchIteratorImpl<SparsePageSource<SortedCSCPage>, SortedCSCPage>(
|
||||
sorted_column_source_.get()));
|
||||
return BatchSet<SortedCSCPage>(begin_iter);
|
||||
return sorted_column_source_->GetBatchSet();
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||
|
||||
@ -22,24 +22,15 @@ namespace data {
|
||||
// Used for external memory.
|
||||
class SparsePageDMatrix : public DMatrix {
|
||||
public:
|
||||
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
|
||||
std::string cache_info)
|
||||
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
|
||||
|
||||
template <typename AdapterT>
|
||||
explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||
const std::string& cache_prefix,
|
||||
size_t page_size = kPageSize)
|
||||
: cache_info_(std::move(cache_prefix)) {
|
||||
if (!data::SparsePageSource<SparsePage>::CacheExist(cache_prefix,
|
||||
".row.page")) {
|
||||
data::SparsePageSource<SparsePage>::CreateRowPage(
|
||||
adapter, missing, nthread, cache_prefix, page_size);
|
||||
}
|
||||
row_source_.reset(
|
||||
new data::SparsePageSource<SparsePage>(cache_prefix, ".row.page"));
|
||||
row_source_.reset(new data::SparsePageSource(adapter, missing, nthread,
|
||||
cache_prefix, page_size));
|
||||
}
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
~SparsePageDMatrix() override = default;
|
||||
|
||||
MetaInfo& Info() override;
|
||||
@ -57,9 +48,9 @@ class SparsePageDMatrix : public DMatrix {
|
||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||
|
||||
// source data pointers.
|
||||
std::unique_ptr<DataSource<SparsePage>> row_source_;
|
||||
std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
|
||||
std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
|
||||
std::unique_ptr<SparsePageSource> row_source_;
|
||||
std::unique_ptr<CSCPageSource> column_source_;
|
||||
std::unique_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||
std::unique_ptr<EllpackPageSource> ellpack_source_;
|
||||
// saved batch param
|
||||
BatchParam batch_param_;
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
@ -24,6 +25,7 @@
|
||||
#include "adapter.h"
|
||||
#include "sparse_page_writer.h"
|
||||
#include "../common/common.h"
|
||||
#include <xgboost/data.h>
|
||||
|
||||
namespace {
|
||||
|
||||
@ -49,6 +51,26 @@ GetCacheShards(const std::string& cache_info) {
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
template<typename S, typename T>
|
||||
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
|
||||
public:
|
||||
explicit SparseBatchIteratorImpl(S* source) : source_(source) {
|
||||
CHECK(source_ != nullptr);
|
||||
source_->BeforeFirst();
|
||||
source_->Next();
|
||||
}
|
||||
T& operator*() override { return source_->Value(); }
|
||||
const T& operator*() const override { return source_->Value(); }
|
||||
void operator++() override { at_end_ = !source_->Next(); }
|
||||
bool AtEnd() const override { return at_end_; }
|
||||
|
||||
private:
|
||||
S* source_{nullptr};
|
||||
bool at_end_{ false };
|
||||
};
|
||||
|
||||
/*! \brief magic number used to identify Page */
|
||||
static const int kMagic = 0xffffab02;
|
||||
/*!
|
||||
* \brief decide the format from cache prefix.
|
||||
* \return pair of row format, column format type of the cache prefix.
|
||||
@ -89,116 +111,149 @@ inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string
|
||||
return info;
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief External memory data source.
|
||||
* \code
|
||||
* std::unique_ptr<DataSource> source(new SimpleCSRSource(cache_prefix));
|
||||
* // add data to source
|
||||
* DMatrix* dmat = DMatrix::Create(std::move(source));
|
||||
* \encode
|
||||
*/
|
||||
template<typename T>
|
||||
class SparsePageSource : public DataSource<T> {
|
||||
public:
|
||||
/*!
|
||||
* \brief Create source from cache files the cache_prefix.
|
||||
* \param cache_prefix The prefix of cache we want to solve.
|
||||
inline void TryDeleteCacheFile(const std::string& file) {
|
||||
if (std::remove(file.c_str()) != 0) {
|
||||
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
||||
<< "; you may want to remove it manually";
|
||||
}
|
||||
}
|
||||
|
||||
inline void CheckCacheFileExists(const std::string& file) {
|
||||
std::ifstream f(file.c_str());
|
||||
if (f.good()) {
|
||||
LOG(FATAL) << "Cache file " << file
|
||||
<< " exists already; Is there another DMatrix with the same "
|
||||
"cache prefix? Otherwise please remove it manually.";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Given a set of cache files and page type, this object iterates over batches using prefetching for improved performance. Not thread safe.
|
||||
*
|
||||
* \tparam PageT Type of the page t.
|
||||
*/
|
||||
explicit SparsePageSource(const std::string& cache_info,
|
||||
const std::string& page_type) noexcept(false)
|
||||
template <typename PageT>
|
||||
class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
|
||||
public:
|
||||
explicit ExternalMemoryPrefetcher(const CacheInfo& info) noexcept(false)
|
||||
: base_rowid_(0), page_(nullptr), clock_ptr_(0) {
|
||||
// read in the info files
|
||||
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
|
||||
CHECK_NE(cache_shards.size(), 0U);
|
||||
CHECK_NE(info.name_shards.size(), 0U);
|
||||
{
|
||||
std::string name_info = cache_shards[0];
|
||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
|
||||
std::unique_ptr<dmlc::Stream> finfo(
|
||||
dmlc::Stream::Create(info.name_info.c_str(), "r"));
|
||||
int tmagic;
|
||||
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
|
||||
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
|
||||
this->info.LoadBinary(finfo.get());
|
||||
}
|
||||
files_.resize(cache_shards.size());
|
||||
formats_.resize(cache_shards.size());
|
||||
prefetchers_.resize(cache_shards.size());
|
||||
files_.resize(info.name_shards.size());
|
||||
formats_.resize(info.name_shards.size());
|
||||
prefetchers_.resize(info.name_shards.size());
|
||||
|
||||
// read in the cache files.
|
||||
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
||||
std::string name_row = cache_shards[i] + page_type;
|
||||
for (size_t i = 0; i < info.name_shards.size(); ++i) {
|
||||
std::string name_row = info.name_shards.at(i);
|
||||
files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
|
||||
std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
|
||||
std::string format;
|
||||
CHECK(fi->Read(&format)) << "Invalid page format";
|
||||
formats_[i].reset(CreatePageFormat<T>(format));
|
||||
std::unique_ptr<SparsePageFormat<T>>& fmt = formats_[i];
|
||||
formats_[i].reset(CreatePageFormat<PageT>(format));
|
||||
std::unique_ptr<SparsePageFormat<PageT>>& fmt = formats_[i];
|
||||
size_t fbegin = fi->Tell();
|
||||
prefetchers_[i].reset(new dmlc::ThreadedIter<T>(4));
|
||||
prefetchers_[i]->Init([&fi, &fmt] (T** dptr) {
|
||||
if (*dptr == nullptr) {
|
||||
*dptr = new T();
|
||||
}
|
||||
return fmt->Read(*dptr, fi.get());
|
||||
}, [&fi, fbegin] () { fi->Seek(fbegin); });
|
||||
prefetchers_[i].reset(new dmlc::ThreadedIter<PageT>(4));
|
||||
prefetchers_[i]->Init(
|
||||
[&fi, &fmt](PageT** dptr) {
|
||||
if (*dptr == nullptr) {
|
||||
*dptr = new PageT();
|
||||
}
|
||||
return fmt->Read(*dptr, fi.get());
|
||||
},
|
||||
[&fi, fbegin]() { fi->Seek(fbegin); });
|
||||
}
|
||||
}
|
||||
|
||||
/*! \brief destructor */
|
||||
~SparsePageSource() override {
|
||||
~ExternalMemoryPrefetcher() override {
|
||||
delete page_;
|
||||
}
|
||||
|
||||
// implement Next
|
||||
bool Next() override {
|
||||
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
|
||||
// doing clock rotation over shards.
|
||||
if (page_ != nullptr) {
|
||||
size_t n = prefetchers_.size();
|
||||
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
|
||||
}
|
||||
|
||||
if (prefetchers_[clock_ptr_]->Next(&page_)) {
|
||||
page_->SetBaseRowId(base_rowid_);
|
||||
base_rowid_ += page_->Size();
|
||||
// advance clock
|
||||
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
|
||||
mutex_.unlock();
|
||||
return true;
|
||||
} else {
|
||||
mutex_.unlock();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// implement BeforeFirst
|
||||
void BeforeFirst() override {
|
||||
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
|
||||
base_rowid_ = 0;
|
||||
clock_ptr_ = 0;
|
||||
for (auto& p : prefetchers_) {
|
||||
p->BeforeFirst();
|
||||
}
|
||||
mutex_.unlock();
|
||||
}
|
||||
|
||||
// implement Value
|
||||
T& Value() {
|
||||
return *page_;
|
||||
}
|
||||
PageT& Value() { return *page_; }
|
||||
|
||||
const T& Value() const override {
|
||||
return *page_;
|
||||
}
|
||||
const PageT& Value() const override { return *page_; }
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
/*! \brief number of rows */
|
||||
size_t base_rowid_;
|
||||
/*! \brief page currently on hold. */
|
||||
PageT* page_;
|
||||
/*! \brief internal clock ptr */
|
||||
size_t clock_ptr_;
|
||||
/*! \brief file pointer to the row blob file. */
|
||||
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
|
||||
/*! \brief Sparse page format file. */
|
||||
std::vector<std::unique_ptr<SparsePageFormat<PageT>>> formats_;
|
||||
/*! \brief internal prefetcher. */
|
||||
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
|
||||
};
|
||||
|
||||
class SparsePageSource {
|
||||
public:
|
||||
template <typename AdapterT>
|
||||
static void CreateRowPage(AdapterT* adapter, float missing, int nthread,
|
||||
const std::string& cache_info,
|
||||
const size_t page_size = DMatrix::kPageSize) {
|
||||
SparsePageSource(AdapterT* adapter, float missing, int nthread,
|
||||
const std::string& cache_info,
|
||||
const size_t page_size = DMatrix::kPageSize) {
|
||||
const std::string page_type = ".row.page";
|
||||
auto cinfo = ParseCacheInfo(cache_info, page_type);
|
||||
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||
|
||||
// Warn user if old cache files
|
||||
CheckCacheFileExists(cache_info_.name_info);
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
CheckCacheFileExists(file);
|
||||
}
|
||||
|
||||
{
|
||||
SparsePageWriter<SparsePage> writer(cinfo.name_shards,
|
||||
cinfo.format_shards, 6);
|
||||
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||
cache_info_.format_shards, 6);
|
||||
std::shared_ptr<SparsePage> page;
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
|
||||
uint64_t inferred_num_columns = 0;
|
||||
uint64_t inferred_num_rows = 0;
|
||||
MetaInfo info;
|
||||
size_t bytes_write = 0;
|
||||
double tstart = dmlc::GetTime();
|
||||
// print every 4 sec.
|
||||
@ -232,7 +287,8 @@ class SparsePageSource : public DataSource<T> {
|
||||
// get group
|
||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
||||
const uint64_t cur_group_id = batch.Qid()[i];
|
||||
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
||||
if (last_group_id == default_max ||
|
||||
last_group_id != cur_group_id) {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
last_group_id = cur_group_id;
|
||||
@ -300,61 +356,53 @@ class SparsePageSource : public DataSource<T> {
|
||||
writer.PushWrite(std::move(page));
|
||||
}
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
|
||||
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
|
||||
int tmagic = kMagic;
|
||||
fo->Write(&tmagic, sizeof(tmagic));
|
||||
// Either every row has query ID or none at all
|
||||
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||
info.SaveBinary(fo.get());
|
||||
}
|
||||
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
|
||||
<< cinfo.name_info;
|
||||
}
|
||||
/*!
|
||||
* \brief Create source cache by copy content from DMatrix.
|
||||
* Creates transposed column page, may be sorted or not.
|
||||
* \param cache_info The cache_info of cache file location.
|
||||
* \param sorted Whether columns should be pre-sorted
|
||||
*/
|
||||
static void CreateColumnPage(DMatrix* src,
|
||||
const std::string& cache_info, bool sorted) {
|
||||
const std::string page_type = sorted ? ".sorted.col.page" : ".col.page";
|
||||
CreatePageFromDMatrix(src, cache_info, page_type);
|
||||
LOG(INFO) << "SparsePageSource Finished writing to "
|
||||
<< cache_info_.name_info;
|
||||
|
||||
external_prefetcher_.reset(
|
||||
new ExternalMemoryPrefetcher<SparsePage>(cache_info_));
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Check if the cache file already exists.
|
||||
* \param cache_info The cache prefix of files.
|
||||
* \param page_type Type of the page.
|
||||
* \return Whether cache file already exists.
|
||||
*/
|
||||
static bool CacheExist(const std::string& cache_info,
|
||||
const std::string& page_type) {
|
||||
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
|
||||
CHECK_NE(cache_shards.size(), 0U);
|
||||
{
|
||||
std::string name_info = cache_shards[0];
|
||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r", true));
|
||||
if (finfo == nullptr) return false;
|
||||
~SparsePageSource() {
|
||||
external_prefetcher_.reset();
|
||||
TryDeleteCacheFile(cache_info_.name_info);
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
TryDeleteCacheFile(file);
|
||||
}
|
||||
for (const std::string& prefix : cache_shards) {
|
||||
std::string name_row = prefix + page_type;
|
||||
std::unique_ptr<dmlc::Stream> frow(dmlc::Stream::Create(name_row.c_str(), "r", true));
|
||||
if (frow == nullptr) return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/*! \brief magic number used to identify Page */
|
||||
static const int kMagic = 0xffffab02;
|
||||
BatchSet<SparsePage> GetBatchSet() {
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SparsePage>,
|
||||
SparsePage>(external_prefetcher_.get()));
|
||||
return BatchSet<SparsePage>(begin_iter);
|
||||
}
|
||||
MetaInfo info;
|
||||
|
||||
private:
|
||||
static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
|
||||
const std::string& page_type,
|
||||
const size_t page_size = DMatrix::kPageSize) {
|
||||
auto cinfo = ParseCacheInfo(cache_info, page_type);
|
||||
std::unique_ptr<ExternalMemoryPrefetcher<SparsePage>> external_prefetcher_;
|
||||
CacheInfo cache_info_;
|
||||
};
|
||||
|
||||
class CSCPageSource {
|
||||
public:
|
||||
CSCPageSource(DMatrix* src, const std::string& cache_info,
|
||||
const size_t page_size = DMatrix::kPageSize) {
|
||||
std::string page_type = ".col.page";
|
||||
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
CheckCacheFileExists(file);
|
||||
}
|
||||
{
|
||||
SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
|
||||
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||
cache_info_.format_shards, 6);
|
||||
std::shared_ptr<SparsePage> page;
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
@ -362,15 +410,7 @@ class SparsePageSource : public DataSource<T> {
|
||||
size_t bytes_write = 0;
|
||||
double tstart = dmlc::GetTime();
|
||||
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||
if (page_type == ".col.page") {
|
||||
page->PushCSC(batch.GetTranspose(src->Info().num_col_));
|
||||
} else if (page_type == ".sorted.col.page") {
|
||||
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
|
||||
page->PushCSC(tmp);
|
||||
page->SortRows();
|
||||
} else {
|
||||
LOG(FATAL) << "Unknown page type: " << page_type;
|
||||
}
|
||||
page->PushCSC(batch.GetTranspose(src->Info().num_col_));
|
||||
|
||||
if (page->MemCostBytes() >= page_size) {
|
||||
bytes_write += page->MemCostBytes();
|
||||
@ -386,23 +426,94 @@ class SparsePageSource : public DataSource<T> {
|
||||
if (page->data.Size() != 0) {
|
||||
writer.PushWrite(std::move(page));
|
||||
}
|
||||
LOG(INFO) << "CSCPageSource: Finished writing to "
|
||||
<< cache_info_.name_info;
|
||||
}
|
||||
LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info;
|
||||
external_prefetcher_.reset(
|
||||
new ExternalMemoryPrefetcher<CSCPage>(cache_info_));
|
||||
}
|
||||
|
||||
/*! \brief number of rows */
|
||||
size_t base_rowid_;
|
||||
/*! \brief page currently on hold. */
|
||||
T* page_;
|
||||
/*! \brief internal clock ptr */
|
||||
size_t clock_ptr_;
|
||||
/*! \brief file pointer to the row blob file. */
|
||||
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
|
||||
/*! \brief Sparse page format file. */
|
||||
std::vector<std::unique_ptr<SparsePageFormat<T>>> formats_;
|
||||
/*! \brief internal prefetcher. */
|
||||
std::vector<std::unique_ptr<dmlc::ThreadedIter<T>>> prefetchers_;
|
||||
~CSCPageSource() {
|
||||
external_prefetcher_.reset();
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
TryDeleteCacheFile(file);
|
||||
}
|
||||
}
|
||||
|
||||
BatchSet<CSCPage> GetBatchSet() {
|
||||
auto begin_iter = BatchIterator<CSCPage>(
|
||||
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<CSCPage>, CSCPage>(
|
||||
external_prefetcher_.get()));
|
||||
return BatchSet<CSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ExternalMemoryPrefetcher<CSCPage>> external_prefetcher_;
|
||||
CacheInfo cache_info_;
|
||||
};
|
||||
|
||||
class SortedCSCPageSource {
|
||||
public:
|
||||
SortedCSCPageSource(DMatrix* src, const std::string& cache_info,
|
||||
const size_t page_size = DMatrix::kPageSize) {
|
||||
std::string page_type = ".sorted.col.page";
|
||||
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
CheckCacheFileExists(file);
|
||||
}
|
||||
{
|
||||
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||
cache_info_.format_shards, 6);
|
||||
std::shared_ptr<SparsePage> page;
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
|
||||
size_t bytes_write = 0;
|
||||
double tstart = dmlc::GetTime();
|
||||
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
|
||||
page->PushCSC(tmp);
|
||||
page->SortRows();
|
||||
|
||||
if (page->MemCostBytes() >= page_size) {
|
||||
bytes_write += page->MemCostBytes();
|
||||
writer.PushWrite(std::move(page));
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
double tdiff = dmlc::GetTime() - tstart;
|
||||
LOG(INFO) << "Writing to " << cache_info << " in "
|
||||
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
||||
<< (bytes_write >> 20UL) << " written";
|
||||
}
|
||||
}
|
||||
if (page->data.Size() != 0) {
|
||||
writer.PushWrite(std::move(page));
|
||||
}
|
||||
LOG(INFO) << "SortedCSCPageSource: Finished writing to "
|
||||
<< cache_info_.name_info;
|
||||
}
|
||||
external_prefetcher_.reset(
|
||||
new ExternalMemoryPrefetcher<SortedCSCPage>(cache_info_));
|
||||
}
|
||||
~SortedCSCPageSource() {
|
||||
external_prefetcher_.reset();
|
||||
for (auto file : cache_info_.name_shards) {
|
||||
TryDeleteCacheFile(file);
|
||||
}
|
||||
}
|
||||
|
||||
BatchSet<SortedCSCPage> GetBatchSet() {
|
||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SortedCSCPage>,
|
||||
SortedCSCPage>(external_prefetcher_.get()));
|
||||
return BatchSet<SortedCSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ExternalMemoryPrefetcher<SortedCSCPage>> external_prefetcher_;
|
||||
CacheInfo cache_info_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
// Copyright by Contributors
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../helpers.h"
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@ -12,8 +12,8 @@ TEST(SparsePageDMatrix, MetaInfo) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", false, false);
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", false, false);
|
||||
std::cout << tmp_file << std::endl;
|
||||
EXPECT_TRUE(FileExists(tmp_file + ".cache"));
|
||||
|
||||
@ -44,21 +44,21 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", true, false);
|
||||
xgboost::DMatrix *dmat =
|
||||
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
|
||||
|
||||
EXPECT_EQ(dmat->GetColDensity(0), 1);
|
||||
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
for (auto const& col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::SortedCSCPage>()) {
|
||||
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
||||
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
||||
EXPECT_EQ(col_batch[1].size(), 1);
|
||||
}
|
||||
|
||||
// Loop over the batches and assert the data is as expected
|
||||
for (auto const& col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
for (auto const &col_batch : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
EXPECT_EQ(col_batch.Size(), dmat->Info().num_col_);
|
||||
EXPECT_EQ(col_batch[1][0].fvalue, 10.0f);
|
||||
EXPECT_EQ(col_batch[1].size(), 1);
|
||||
@ -70,25 +70,61 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
EXPECT_TRUE(FileExists(tmp_file + ".cache.sorted.col.page"));
|
||||
|
||||
delete dmat;
|
||||
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache"));
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache.row.page"));
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache.sorted.col.page"));
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, ExistingCacheFile) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
EXPECT_ANY_THROW({
|
||||
std::unique_ptr<xgboost::DMatrix> dmat2 =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
TEST(SparsePageDMatrix, ThreadSafetyException) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/test";
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
|
||||
bool exception = false;
|
||||
int threads = 1000;
|
||||
#pragma omp parallel for
|
||||
for (auto i = 0; i < threads; i++) {
|
||||
try {
|
||||
auto iter = dmat->GetBatches<SparsePage>().begin();
|
||||
++iter;
|
||||
} catch (...) {
|
||||
exception = true;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(exception);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Multi-batches access
|
||||
TEST(SparsePageDMatrix, ColAccessBatches) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
// Create multiple sparse pages
|
||||
std::unique_ptr<xgboost::DMatrix> dmat {
|
||||
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)
|
||||
};
|
||||
std::unique_ptr<xgboost::DMatrix> dmat{
|
||||
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
|
||||
auto n_threads = omp_get_max_threads();
|
||||
omp_set_num_threads(16);
|
||||
for (auto const& page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
ASSERT_EQ(dmat->Info().num_col_, page.Size());
|
||||
}
|
||||
omp_set_num_threads(n_threads);
|
||||
}
|
||||
|
||||
|
||||
TEST(SparsePageDMatrix, Empty) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
@ -96,34 +132,40 @@ TEST(SparsePageDMatrix, Empty) {
|
||||
std::vector<unsigned> feature_idx = {};
|
||||
std::vector<size_t> row_ptr = {};
|
||||
|
||||
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0);
|
||||
data::SparsePageDMatrix dmat(&csr_adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat.Info().num_col_, 0);
|
||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
{
|
||||
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(),
|
||||
data.data(), 0, 0, 0);
|
||||
data::SparsePageDMatrix dmat(
|
||||
&csr_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat.Info().num_col_, 0);
|
||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
data::DenseAdapter dense_adapter(nullptr, 0, 0);
|
||||
data::SparsePageDMatrix dmat2(&dense_adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
||||
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat2.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat2.Info().num_col_, 0);
|
||||
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
{
|
||||
data::DenseAdapter dense_adapter(nullptr, 0, 0);
|
||||
data::SparsePageDMatrix dmat2(
|
||||
&dense_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat2.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat2.Info().num_col_, 0);
|
||||
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
|
||||
data::SparsePageDMatrix dmat3(&csc_adapter,
|
||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
||||
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat3.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat3.Info().num_col_, 0);
|
||||
for (auto &batch : dmat3.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
{
|
||||
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
|
||||
data::SparsePageDMatrix dmat3(
|
||||
&csc_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
|
||||
EXPECT_EQ(dmat3.Info().num_row_, 0);
|
||||
EXPECT_EQ(dmat3.Info().num_col_, 0);
|
||||
for (auto &batch : dmat3.GetBatches<SparsePage>()) {
|
||||
EXPECT_EQ(batch.Size(), 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,12 +176,14 @@ TEST(SparsePageDMatrix, MissingData) {
|
||||
std::vector<unsigned> feature_idx = {0, 1, 0};
|
||||
std::vector<size_t> row_ptr = {0, 2, 3};
|
||||
|
||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2);
|
||||
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
|
||||
3, 2);
|
||||
data::SparsePageDMatrix dmat(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
||||
|
||||
const std::string tmp_file2 = tempdir.path + "/simple2.libsvm";
|
||||
data::SparsePageDMatrix dmat2(&adapter, 1.0, 1,tmp_file2);
|
||||
data::SparsePageDMatrix dmat2(&adapter, 1.0, 1, tmp_file2);
|
||||
EXPECT_EQ(dmat2.Info().num_nonzero_, 1);
|
||||
}
|
||||
|
||||
@ -150,8 +194,10 @@ TEST(SparsePageDMatrix, EmptyRow) {
|
||||
std::vector<unsigned> feature_idx = {0, 1};
|
||||
std::vector<size_t> row_ptr = {0, 2, 2};
|
||||
|
||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2);
|
||||
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
|
||||
2, 2);
|
||||
data::SparsePageDMatrix dmat(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
||||
EXPECT_EQ(dmat.Info().num_row_, 2);
|
||||
EXPECT_EQ(dmat.Info().num_col_, 2);
|
||||
@ -173,9 +219,8 @@ TEST(SparsePageDMatrix, FromDense) {
|
||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||
for (auto i = 0ull; i < batch.Size(); i++) {
|
||||
auto inst = batch[i];
|
||||
for(auto j = 0ull; j < inst.size(); j++)
|
||||
{
|
||||
EXPECT_EQ(inst[j].fvalue, data[i*n+j]);
|
||||
for (auto j = 0ull; j < inst.size(); j++) {
|
||||
EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
|
||||
EXPECT_EQ(inst[j].index, j);
|
||||
}
|
||||
}
|
||||
@ -215,9 +260,9 @@ TEST(SparsePageDMatrix, FromCSC) {
|
||||
|
||||
TEST(SparsePageDMatrix, FromFile) {
|
||||
std::string filename = "test.libsvm";
|
||||
CreateBigTestData(filename,20);
|
||||
CreateBigTestData(filename, 20);
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
|
||||
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user