Refactor SparsePageSource, delete cache files after use (#5321)
* Refactor sparse page source
* Delete temporary cache files
* Log fatal if cache exists
* Log fatal if multiple threads are used with the prefetcher
This commit is contained in:
parent
b2b2c4e231
commit
bc96ceb8b2
@ -45,7 +45,7 @@ class EllpackPageSourceImpl : public DataSource<EllpackPage> {
|
|||||||
dh::BulkAllocator ba_;
|
dh::BulkAllocator ba_;
|
||||||
/*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
|
/*! \brief The EllpackInfo, with the underlying GPU memory shared by all pages. */
|
||||||
EllpackInfo ellpack_info_;
|
EllpackInfo ellpack_info_;
|
||||||
std::unique_ptr<SparsePageSource<EllpackPage>> source_;
|
std::unique_ptr<ExternalMemoryPrefetcher<EllpackPage>> source_;
|
||||||
std::string cache_info_;
|
std::string cache_info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -98,11 +98,13 @@ EllpackPageSourceImpl::EllpackPageSourceImpl(DMatrix* dmat,
|
|||||||
WriteEllpackPages(dmat, cache_info);
|
WriteEllpackPages(dmat, cache_info);
|
||||||
monitor_.StopCuda("WriteEllpackPages");
|
monitor_.StopCuda("WriteEllpackPages");
|
||||||
|
|
||||||
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
|
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
|
||||||
|
ParseCacheInfo(cache_info_, kPageType_)));
|
||||||
}
|
}
|
||||||
|
|
||||||
void EllpackPageSourceImpl::BeforeFirst() {
|
void EllpackPageSourceImpl::BeforeFirst() {
|
||||||
source_.reset(new SparsePageSource<EllpackPage>(cache_info_, kPageType_));
|
source_.reset(new ExternalMemoryPrefetcher<EllpackPage>(
|
||||||
|
ParseCacheInfo(cache_info_, kPageType_)));
|
||||||
source_->BeforeFirst();
|
source_->BeforeFirst();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -23,58 +23,24 @@ const MetaInfo& SparsePageDMatrix::Info() const {
|
|||||||
return row_source_->info;
|
return row_source_->info;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename S, typename T>
|
|
||||||
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
|
|
||||||
public:
|
|
||||||
explicit SparseBatchIteratorImpl(S* source) : source_(source) {
|
|
||||||
CHECK(source_ != nullptr);
|
|
||||||
}
|
|
||||||
T& operator*() override { return source_->Value(); }
|
|
||||||
const T& operator*() const override { return source_->Value(); }
|
|
||||||
void operator++() override { at_end_ = !source_->Next(); }
|
|
||||||
bool AtEnd() const override { return at_end_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
S* source_{nullptr};
|
|
||||||
bool at_end_{ false };
|
|
||||||
};
|
|
||||||
|
|
||||||
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
|
||||||
auto cast = dynamic_cast<SparsePageSource<SparsePage>*>(row_source_.get());
|
return row_source_->GetBatchSet();
|
||||||
CHECK(cast);
|
|
||||||
cast->BeforeFirst();
|
|
||||||
cast->Next();
|
|
||||||
auto begin_iter = BatchIterator<SparsePage>(
|
|
||||||
new SparseBatchIteratorImpl<SparsePageSource<SparsePage>, SparsePage>(cast));
|
|
||||||
return BatchSet<SparsePage>(begin_iter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
|
||||||
// Lazily instantiate
|
// Lazily instantiate
|
||||||
if (!column_source_) {
|
if (!column_source_) {
|
||||||
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, false);
|
column_source_.reset(new CSCPageSource(this, cache_info_));
|
||||||
column_source_.reset(new SparsePageSource<CSCPage>(cache_info_, ".col.page"));
|
|
||||||
}
|
}
|
||||||
column_source_->BeforeFirst();
|
return column_source_->GetBatchSet();
|
||||||
column_source_->Next();
|
|
||||||
auto begin_iter = BatchIterator<CSCPage>(
|
|
||||||
new SparseBatchIteratorImpl<SparsePageSource<CSCPage>, CSCPage>(column_source_.get()));
|
|
||||||
return BatchSet<CSCPage>(begin_iter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
|
||||||
// Lazily instantiate
|
// Lazily instantiate
|
||||||
if (!sorted_column_source_) {
|
if (!sorted_column_source_) {
|
||||||
SparsePageSource<SparsePage>::CreateColumnPage(this, cache_info_, true);
|
sorted_column_source_.reset(new SortedCSCPageSource(this, cache_info_));
|
||||||
sorted_column_source_.reset(
|
|
||||||
new SparsePageSource<SortedCSCPage>(cache_info_, ".sorted.col.page"));
|
|
||||||
}
|
}
|
||||||
sorted_column_source_->BeforeFirst();
|
return sorted_column_source_->GetBatchSet();
|
||||||
sorted_column_source_->Next();
|
|
||||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
|
||||||
new SparseBatchIteratorImpl<SparsePageSource<SortedCSCPage>, SortedCSCPage>(
|
|
||||||
sorted_column_source_.get()));
|
|
||||||
return BatchSet<SortedCSCPage>(begin_iter);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
|
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
|
||||||
|
|||||||
@ -22,22 +22,13 @@ namespace data {
|
|||||||
// Used for external memory.
|
// Used for external memory.
|
||||||
class SparsePageDMatrix : public DMatrix {
|
class SparsePageDMatrix : public DMatrix {
|
||||||
public:
|
public:
|
||||||
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
|
|
||||||
std::string cache_info)
|
|
||||||
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
|
|
||||||
|
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
|
explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
|
||||||
const std::string& cache_prefix,
|
const std::string& cache_prefix,
|
||||||
size_t page_size = kPageSize)
|
size_t page_size = kPageSize)
|
||||||
: cache_info_(std::move(cache_prefix)) {
|
: cache_info_(std::move(cache_prefix)) {
|
||||||
if (!data::SparsePageSource<SparsePage>::CacheExist(cache_prefix,
|
row_source_.reset(new data::SparsePageSource(adapter, missing, nthread,
|
||||||
".row.page")) {
|
cache_prefix, page_size));
|
||||||
data::SparsePageSource<SparsePage>::CreateRowPage(
|
|
||||||
adapter, missing, nthread, cache_prefix, page_size);
|
|
||||||
}
|
|
||||||
row_source_.reset(
|
|
||||||
new data::SparsePageSource<SparsePage>(cache_prefix, ".row.page"));
|
|
||||||
}
|
}
|
||||||
// Set number of threads but keep old value so we can reset it after
|
// Set number of threads but keep old value so we can reset it after
|
||||||
~SparsePageDMatrix() override = default;
|
~SparsePageDMatrix() override = default;
|
||||||
@ -57,9 +48,9 @@ class SparsePageDMatrix : public DMatrix {
|
|||||||
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
|
||||||
|
|
||||||
// source data pointers.
|
// source data pointers.
|
||||||
std::unique_ptr<DataSource<SparsePage>> row_source_;
|
std::unique_ptr<SparsePageSource> row_source_;
|
||||||
std::unique_ptr<SparsePageSource<CSCPage>> column_source_;
|
std::unique_ptr<CSCPageSource> column_source_;
|
||||||
std::unique_ptr<SparsePageSource<SortedCSCPage>> sorted_column_source_;
|
std::unique_ptr<SortedCSCPageSource> sorted_column_source_;
|
||||||
std::unique_ptr<EllpackPageSource> ellpack_source_;
|
std::unique_ptr<EllpackPageSource> ellpack_source_;
|
||||||
// saved batch param
|
// saved batch param
|
||||||
BatchParam batch_param_;
|
BatchParam batch_param_;
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
|
||||||
#include "xgboost/base.h"
|
#include "xgboost/base.h"
|
||||||
#include "xgboost/data.h"
|
#include "xgboost/data.h"
|
||||||
@ -24,6 +25,7 @@
|
|||||||
#include "adapter.h"
|
#include "adapter.h"
|
||||||
#include "sparse_page_writer.h"
|
#include "sparse_page_writer.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
@ -49,6 +51,26 @@ GetCacheShards(const std::string& cache_info) {
|
|||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
|
template<typename S, typename T>
|
||||||
|
class SparseBatchIteratorImpl : public BatchIteratorImpl<T> {
|
||||||
|
public:
|
||||||
|
explicit SparseBatchIteratorImpl(S* source) : source_(source) {
|
||||||
|
CHECK(source_ != nullptr);
|
||||||
|
source_->BeforeFirst();
|
||||||
|
source_->Next();
|
||||||
|
}
|
||||||
|
T& operator*() override { return source_->Value(); }
|
||||||
|
const T& operator*() const override { return source_->Value(); }
|
||||||
|
void operator++() override { at_end_ = !source_->Next(); }
|
||||||
|
bool AtEnd() const override { return at_end_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
S* source_{nullptr};
|
||||||
|
bool at_end_{ false };
|
||||||
|
};
|
||||||
|
|
||||||
|
/*! \brief magic number used to identify Page */
|
||||||
|
static const int kMagic = 0xffffab02;
|
||||||
/*!
|
/*!
|
||||||
* \brief decide the format from cache prefix.
|
* \brief decide the format from cache prefix.
|
||||||
* \return pair of row format, column format type of the cache prefix.
|
* \return pair of row format, column format type of the cache prefix.
|
||||||
@ -89,116 +111,149 @@ inline CacheInfo ParseCacheInfo(const std::string& cache_info, const std::string
|
|||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
inline void TryDeleteCacheFile(const std::string& file) {
|
||||||
* \brief External memory data source.
|
if (std::remove(file.c_str()) != 0) {
|
||||||
* \code
|
LOG(WARNING) << "Couldn't remove external memory cache file " << file
|
||||||
* std::unique_ptr<DataSource> source(new SimpleCSRSource(cache_prefix));
|
<< "; you may want to remove it manually";
|
||||||
* // add data to source
|
}
|
||||||
* DMatrix* dmat = DMatrix::Create(std::move(source));
|
}
|
||||||
* \encode
|
|
||||||
|
inline void CheckCacheFileExists(const std::string& file) {
|
||||||
|
std::ifstream f(file.c_str());
|
||||||
|
if (f.good()) {
|
||||||
|
LOG(FATAL) << "Cache file " << file
|
||||||
|
<< " exists already; Is there another DMatrix with the same "
|
||||||
|
"cache prefix? Otherwise please remove it manually.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Given a set of cache files and page type, this object iterates over batches using prefetching for improved performance. Not thread safe.
|
||||||
|
*
|
||||||
|
* \tparam PageT Type of the page produced by this iterator.
|
||||||
*/
|
*/
|
||||||
template<typename T>
|
template <typename PageT>
|
||||||
class SparsePageSource : public DataSource<T> {
|
class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
|
||||||
public:
|
public:
|
||||||
/*!
|
explicit ExternalMemoryPrefetcher(const CacheInfo& info) noexcept(false)
|
||||||
* \brief Create source from cache files the cache_prefix.
|
|
||||||
* \param cache_prefix The prefix of cache we want to solve.
|
|
||||||
*/
|
|
||||||
explicit SparsePageSource(const std::string& cache_info,
|
|
||||||
const std::string& page_type) noexcept(false)
|
|
||||||
: base_rowid_(0), page_(nullptr), clock_ptr_(0) {
|
: base_rowid_(0), page_(nullptr), clock_ptr_(0) {
|
||||||
// read in the info files
|
// read in the info files
|
||||||
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
|
CHECK_NE(info.name_shards.size(), 0U);
|
||||||
CHECK_NE(cache_shards.size(), 0U);
|
|
||||||
{
|
{
|
||||||
std::string name_info = cache_shards[0];
|
std::unique_ptr<dmlc::Stream> finfo(
|
||||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r"));
|
dmlc::Stream::Create(info.name_info.c_str(), "r"));
|
||||||
int tmagic;
|
int tmagic;
|
||||||
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
|
CHECK_EQ(finfo->Read(&tmagic, sizeof(tmagic)), sizeof(tmagic));
|
||||||
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
|
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
|
||||||
this->info.LoadBinary(finfo.get());
|
|
||||||
}
|
}
|
||||||
files_.resize(cache_shards.size());
|
files_.resize(info.name_shards.size());
|
||||||
formats_.resize(cache_shards.size());
|
formats_.resize(info.name_shards.size());
|
||||||
prefetchers_.resize(cache_shards.size());
|
prefetchers_.resize(info.name_shards.size());
|
||||||
|
|
||||||
// read in the cache files.
|
// read in the cache files.
|
||||||
for (size_t i = 0; i < cache_shards.size(); ++i) {
|
for (size_t i = 0; i < info.name_shards.size(); ++i) {
|
||||||
std::string name_row = cache_shards[i] + page_type;
|
std::string name_row = info.name_shards.at(i);
|
||||||
files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
|
files_[i].reset(dmlc::SeekStream::CreateForRead(name_row.c_str()));
|
||||||
std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
|
std::unique_ptr<dmlc::SeekStream>& fi = files_[i];
|
||||||
std::string format;
|
std::string format;
|
||||||
CHECK(fi->Read(&format)) << "Invalid page format";
|
CHECK(fi->Read(&format)) << "Invalid page format";
|
||||||
formats_[i].reset(CreatePageFormat<T>(format));
|
formats_[i].reset(CreatePageFormat<PageT>(format));
|
||||||
std::unique_ptr<SparsePageFormat<T>>& fmt = formats_[i];
|
std::unique_ptr<SparsePageFormat<PageT>>& fmt = formats_[i];
|
||||||
size_t fbegin = fi->Tell();
|
size_t fbegin = fi->Tell();
|
||||||
prefetchers_[i].reset(new dmlc::ThreadedIter<T>(4));
|
prefetchers_[i].reset(new dmlc::ThreadedIter<PageT>(4));
|
||||||
prefetchers_[i]->Init([&fi, &fmt] (T** dptr) {
|
prefetchers_[i]->Init(
|
||||||
|
[&fi, &fmt](PageT** dptr) {
|
||||||
if (*dptr == nullptr) {
|
if (*dptr == nullptr) {
|
||||||
*dptr = new T();
|
*dptr = new PageT();
|
||||||
}
|
}
|
||||||
return fmt->Read(*dptr, fi.get());
|
return fmt->Read(*dptr, fi.get());
|
||||||
}, [&fi, fbegin] () { fi->Seek(fbegin); });
|
},
|
||||||
|
[&fi, fbegin]() { fi->Seek(fbegin); });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief destructor */
|
/*! \brief destructor */
|
||||||
~SparsePageSource() override {
|
~ExternalMemoryPrefetcher() override {
|
||||||
delete page_;
|
delete page_;
|
||||||
}
|
}
|
||||||
|
|
||||||
// implement Next
|
// implement Next
|
||||||
bool Next() override {
|
bool Next() override {
|
||||||
|
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
|
||||||
// doing clock rotation over shards.
|
// doing clock rotation over shards.
|
||||||
if (page_ != nullptr) {
|
if (page_ != nullptr) {
|
||||||
size_t n = prefetchers_.size();
|
size_t n = prefetchers_.size();
|
||||||
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
|
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (prefetchers_[clock_ptr_]->Next(&page_)) {
|
if (prefetchers_[clock_ptr_]->Next(&page_)) {
|
||||||
page_->SetBaseRowId(base_rowid_);
|
page_->SetBaseRowId(base_rowid_);
|
||||||
base_rowid_ += page_->Size();
|
base_rowid_ += page_->Size();
|
||||||
// advance clock
|
// advance clock
|
||||||
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
|
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
|
||||||
|
mutex_.unlock();
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
|
mutex_.unlock();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// implement BeforeFirst
|
// implement BeforeFirst
|
||||||
void BeforeFirst() override {
|
void BeforeFirst() override {
|
||||||
|
CHECK(mutex_.try_lock()) << "Multiple threads attempting to use prefetcher";
|
||||||
base_rowid_ = 0;
|
base_rowid_ = 0;
|
||||||
clock_ptr_ = 0;
|
clock_ptr_ = 0;
|
||||||
for (auto& p : prefetchers_) {
|
for (auto& p : prefetchers_) {
|
||||||
p->BeforeFirst();
|
p->BeforeFirst();
|
||||||
}
|
}
|
||||||
|
mutex_.unlock();
|
||||||
}
|
}
|
||||||
|
|
||||||
// implement Value
|
// implement Value
|
||||||
T& Value() {
|
PageT& Value() { return *page_; }
|
||||||
return *page_;
|
|
||||||
}
|
|
||||||
|
|
||||||
const T& Value() const override {
|
const PageT& Value() const override { return *page_; }
|
||||||
return *page_;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mutex_;
|
||||||
|
/*! \brief base row id assigned to the next page (sum of the sizes of pages read so far) */
|
||||||
|
size_t base_rowid_;
|
||||||
|
/*! \brief page currently on hold. */
|
||||||
|
PageT* page_;
|
||||||
|
/*! \brief internal clock ptr */
|
||||||
|
size_t clock_ptr_;
|
||||||
|
/*! \brief file pointer to the row blob file. */
|
||||||
|
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
|
||||||
|
/*! \brief Sparse page format file. */
|
||||||
|
std::vector<std::unique_ptr<SparsePageFormat<PageT>>> formats_;
|
||||||
|
/*! \brief internal prefetcher. */
|
||||||
|
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SparsePageSource {
|
||||||
|
public:
|
||||||
template <typename AdapterT>
|
template <typename AdapterT>
|
||||||
static void CreateRowPage(AdapterT* adapter, float missing, int nthread,
|
SparsePageSource(AdapterT* adapter, float missing, int nthread,
|
||||||
const std::string& cache_info,
|
const std::string& cache_info,
|
||||||
const size_t page_size = DMatrix::kPageSize) {
|
const size_t page_size = DMatrix::kPageSize) {
|
||||||
const std::string page_type = ".row.page";
|
const std::string page_type = ".row.page";
|
||||||
auto cinfo = ParseCacheInfo(cache_info, page_type);
|
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||||
|
|
||||||
|
// Fail (LOG(FATAL)) if any of the cache files already exists
|
||||||
|
CheckCacheFileExists(cache_info_.name_info);
|
||||||
|
for (auto file : cache_info_.name_shards) {
|
||||||
|
CheckCacheFileExists(file);
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
SparsePageWriter<SparsePage> writer(cinfo.name_shards,
|
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||||
cinfo.format_shards, 6);
|
cache_info_.format_shards, 6);
|
||||||
std::shared_ptr<SparsePage> page;
|
std::shared_ptr<SparsePage> page;
|
||||||
writer.Alloc(&page);
|
writer.Alloc(&page);
|
||||||
page->Clear();
|
page->Clear();
|
||||||
|
|
||||||
uint64_t inferred_num_columns = 0;
|
uint64_t inferred_num_columns = 0;
|
||||||
uint64_t inferred_num_rows = 0;
|
uint64_t inferred_num_rows = 0;
|
||||||
MetaInfo info;
|
|
||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
// print every 4 sec.
|
// print every 4 sec.
|
||||||
@ -232,7 +287,8 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
// get group
|
// get group
|
||||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
for (size_t i = 0; i < batch.Size(); ++i) {
|
||||||
const uint64_t cur_group_id = batch.Qid()[i];
|
const uint64_t cur_group_id = batch.Qid()[i];
|
||||||
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
if (last_group_id == default_max ||
|
||||||
|
last_group_id != cur_group_id) {
|
||||||
info.group_ptr_.push_back(group_size);
|
info.group_ptr_.push_back(group_size);
|
||||||
}
|
}
|
||||||
last_group_id = cur_group_id;
|
last_group_id = cur_group_id;
|
||||||
@ -300,61 +356,53 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
writer.PushWrite(std::move(page));
|
writer.PushWrite(std::move(page));
|
||||||
}
|
}
|
||||||
std::unique_ptr<dmlc::Stream> fo(
|
std::unique_ptr<dmlc::Stream> fo(
|
||||||
dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
|
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
|
||||||
int tmagic = kMagic;
|
int tmagic = kMagic;
|
||||||
fo->Write(&tmagic, sizeof(tmagic));
|
fo->Write(&tmagic, sizeof(tmagic));
|
||||||
// Either every row has query ID or none at all
|
// Either every row has query ID or none at all
|
||||||
CHECK(qids.empty() || qids.size() == info.num_row_);
|
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||||
info.SaveBinary(fo.get());
|
info.SaveBinary(fo.get());
|
||||||
}
|
}
|
||||||
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
|
LOG(INFO) << "SparsePageSource Finished writing to "
|
||||||
<< cinfo.name_info;
|
<< cache_info_.name_info;
|
||||||
}
|
|
||||||
/*!
|
external_prefetcher_.reset(
|
||||||
* \brief Create source cache by copy content from DMatrix.
|
new ExternalMemoryPrefetcher<SparsePage>(cache_info_));
|
||||||
* Creates transposed column page, may be sorted or not.
|
|
||||||
* \param cache_info The cache_info of cache file location.
|
|
||||||
* \param sorted Whether columns should be pre-sorted
|
|
||||||
*/
|
|
||||||
static void CreateColumnPage(DMatrix* src,
|
|
||||||
const std::string& cache_info, bool sorted) {
|
|
||||||
const std::string page_type = sorted ? ".sorted.col.page" : ".col.page";
|
|
||||||
CreatePageFromDMatrix(src, cache_info, page_type);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*!
|
~SparsePageSource() {
|
||||||
* \brief Check if the cache file already exists.
|
external_prefetcher_.reset();
|
||||||
* \param cache_info The cache prefix of files.
|
TryDeleteCacheFile(cache_info_.name_info);
|
||||||
* \param page_type Type of the page.
|
for (auto file : cache_info_.name_shards) {
|
||||||
* \return Whether cache file already exists.
|
TryDeleteCacheFile(file);
|
||||||
*/
|
|
||||||
static bool CacheExist(const std::string& cache_info,
|
|
||||||
const std::string& page_type) {
|
|
||||||
std::vector<std::string> cache_shards = GetCacheShards(cache_info);
|
|
||||||
CHECK_NE(cache_shards.size(), 0U);
|
|
||||||
{
|
|
||||||
std::string name_info = cache_shards[0];
|
|
||||||
std::unique_ptr<dmlc::Stream> finfo(dmlc::Stream::Create(name_info.c_str(), "r", true));
|
|
||||||
if (finfo == nullptr) return false;
|
|
||||||
}
|
}
|
||||||
for (const std::string& prefix : cache_shards) {
|
|
||||||
std::string name_row = prefix + page_type;
|
|
||||||
std::unique_ptr<dmlc::Stream> frow(dmlc::Stream::Create(name_row.c_str(), "r", true));
|
|
||||||
if (frow == nullptr) return false;
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief magic number used to identify Page */
|
BatchSet<SparsePage> GetBatchSet() {
|
||||||
static const int kMagic = 0xffffab02;
|
auto begin_iter = BatchIterator<SparsePage>(
|
||||||
|
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SparsePage>,
|
||||||
|
SparsePage>(external_prefetcher_.get()));
|
||||||
|
return BatchSet<SparsePage>(begin_iter);
|
||||||
|
}
|
||||||
|
MetaInfo info;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info,
|
std::unique_ptr<ExternalMemoryPrefetcher<SparsePage>> external_prefetcher_;
|
||||||
const std::string& page_type,
|
CacheInfo cache_info_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class CSCPageSource {
|
||||||
|
public:
|
||||||
|
CSCPageSource(DMatrix* src, const std::string& cache_info,
|
||||||
const size_t page_size = DMatrix::kPageSize) {
|
const size_t page_size = DMatrix::kPageSize) {
|
||||||
auto cinfo = ParseCacheInfo(cache_info, page_type);
|
std::string page_type = ".col.page";
|
||||||
|
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||||
|
for (auto file : cache_info_.name_shards) {
|
||||||
|
CheckCacheFileExists(file);
|
||||||
|
}
|
||||||
{
|
{
|
||||||
SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
|
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||||
|
cache_info_.format_shards, 6);
|
||||||
std::shared_ptr<SparsePage> page;
|
std::shared_ptr<SparsePage> page;
|
||||||
writer.Alloc(&page);
|
writer.Alloc(&page);
|
||||||
page->Clear();
|
page->Clear();
|
||||||
@ -362,15 +410,7 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
size_t bytes_write = 0;
|
size_t bytes_write = 0;
|
||||||
double tstart = dmlc::GetTime();
|
double tstart = dmlc::GetTime();
|
||||||
for (auto& batch : src->GetBatches<SparsePage>()) {
|
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||||
if (page_type == ".col.page") {
|
|
||||||
page->PushCSC(batch.GetTranspose(src->Info().num_col_));
|
page->PushCSC(batch.GetTranspose(src->Info().num_col_));
|
||||||
} else if (page_type == ".sorted.col.page") {
|
|
||||||
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
|
|
||||||
page->PushCSC(tmp);
|
|
||||||
page->SortRows();
|
|
||||||
} else {
|
|
||||||
LOG(FATAL) << "Unknown page type: " << page_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (page->MemCostBytes() >= page_size) {
|
if (page->MemCostBytes() >= page_size) {
|
||||||
bytes_write += page->MemCostBytes();
|
bytes_write += page->MemCostBytes();
|
||||||
@ -386,23 +426,94 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
if (page->data.Size() != 0) {
|
if (page->data.Size() != 0) {
|
||||||
writer.PushWrite(std::move(page));
|
writer.PushWrite(std::move(page));
|
||||||
}
|
}
|
||||||
|
LOG(INFO) << "CSCPageSource: Finished writing to "
|
||||||
|
<< cache_info_.name_info;
|
||||||
}
|
}
|
||||||
LOG(INFO) << "SparsePageSource: Finished writing to " << cinfo.name_info;
|
external_prefetcher_.reset(
|
||||||
|
new ExternalMemoryPrefetcher<CSCPage>(cache_info_));
|
||||||
}
|
}
|
||||||
|
|
||||||
/*! \brief number of rows */
|
~CSCPageSource() {
|
||||||
size_t base_rowid_;
|
external_prefetcher_.reset();
|
||||||
/*! \brief page currently on hold. */
|
for (auto file : cache_info_.name_shards) {
|
||||||
T* page_;
|
TryDeleteCacheFile(file);
|
||||||
/*! \brief internal clock ptr */
|
}
|
||||||
size_t clock_ptr_;
|
}
|
||||||
/*! \brief file pointer to the row blob file. */
|
|
||||||
std::vector<std::unique_ptr<dmlc::SeekStream>> files_;
|
BatchSet<CSCPage> GetBatchSet() {
|
||||||
/*! \brief Sparse page format file. */
|
auto begin_iter = BatchIterator<CSCPage>(
|
||||||
std::vector<std::unique_ptr<SparsePageFormat<T>>> formats_;
|
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<CSCPage>, CSCPage>(
|
||||||
/*! \brief internal prefetcher. */
|
external_prefetcher_.get()));
|
||||||
std::vector<std::unique_ptr<dmlc::ThreadedIter<T>>> prefetchers_;
|
return BatchSet<CSCPage>(begin_iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ExternalMemoryPrefetcher<CSCPage>> external_prefetcher_;
|
||||||
|
CacheInfo cache_info_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SortedCSCPageSource {
|
||||||
|
public:
|
||||||
|
SortedCSCPageSource(DMatrix* src, const std::string& cache_info,
|
||||||
|
const size_t page_size = DMatrix::kPageSize) {
|
||||||
|
std::string page_type = ".sorted.col.page";
|
||||||
|
cache_info_ = ParseCacheInfo(cache_info, page_type);
|
||||||
|
for (auto file : cache_info_.name_shards) {
|
||||||
|
CheckCacheFileExists(file);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||||
|
cache_info_.format_shards, 6);
|
||||||
|
std::shared_ptr<SparsePage> page;
|
||||||
|
writer.Alloc(&page);
|
||||||
|
page->Clear();
|
||||||
|
|
||||||
|
size_t bytes_write = 0;
|
||||||
|
double tstart = dmlc::GetTime();
|
||||||
|
for (auto& batch : src->GetBatches<SparsePage>()) {
|
||||||
|
SparsePage tmp = batch.GetTranspose(src->Info().num_col_);
|
||||||
|
page->PushCSC(tmp);
|
||||||
|
page->SortRows();
|
||||||
|
|
||||||
|
if (page->MemCostBytes() >= page_size) {
|
||||||
|
bytes_write += page->MemCostBytes();
|
||||||
|
writer.PushWrite(std::move(page));
|
||||||
|
writer.Alloc(&page);
|
||||||
|
page->Clear();
|
||||||
|
double tdiff = dmlc::GetTime() - tstart;
|
||||||
|
LOG(INFO) << "Writing to " << cache_info << " in "
|
||||||
|
<< ((bytes_write >> 20UL) / tdiff) << " MB/s, "
|
||||||
|
<< (bytes_write >> 20UL) << " written";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (page->data.Size() != 0) {
|
||||||
|
writer.PushWrite(std::move(page));
|
||||||
|
}
|
||||||
|
LOG(INFO) << "SortedCSCPageSource: Finished writing to "
|
||||||
|
<< cache_info_.name_info;
|
||||||
|
}
|
||||||
|
external_prefetcher_.reset(
|
||||||
|
new ExternalMemoryPrefetcher<SortedCSCPage>(cache_info_));
|
||||||
|
}
|
||||||
|
~SortedCSCPageSource() {
|
||||||
|
external_prefetcher_.reset();
|
||||||
|
for (auto file : cache_info_.name_shards) {
|
||||||
|
TryDeleteCacheFile(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BatchSet<SortedCSCPage> GetBatchSet() {
|
||||||
|
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||||
|
new SparseBatchIteratorImpl<ExternalMemoryPrefetcher<SortedCSCPage>,
|
||||||
|
SortedCSCPage>(external_prefetcher_.get()));
|
||||||
|
return BatchSet<SortedCSCPage>(begin_iter);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ExternalMemoryPrefetcher<SortedCSCPage>> external_prefetcher_;
|
||||||
|
CacheInfo cache_info_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace data
|
} // namespace data
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
#endif // XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
// Copyright by Contributors
|
// Copyright by Contributors
|
||||||
#include <dmlc/filesystem.h>
|
#include <dmlc/filesystem.h>
|
||||||
#include <xgboost/data.h>
|
|
||||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
|
||||||
#include "../../../src/data/adapter.h"
|
|
||||||
#include "../helpers.h"
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include <xgboost/data.h>
|
||||||
|
#include "../../../src/data/adapter.h"
|
||||||
|
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||||
|
#include "../helpers.h"
|
||||||
|
|
||||||
using namespace xgboost; // NOLINT
|
using namespace xgboost; // NOLINT
|
||||||
|
|
||||||
@ -44,8 +44,8 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
CreateSimpleTestData(tmp_file);
|
CreateSimpleTestData(tmp_file);
|
||||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(
|
xgboost::DMatrix *dmat =
|
||||||
tmp_file + "#" + tmp_file + ".cache", true, false);
|
xgboost::DMatrix::Load(tmp_file + "#" + tmp_file + ".cache", true, false);
|
||||||
|
|
||||||
EXPECT_EQ(dmat->GetColDensity(0), 1);
|
EXPECT_EQ(dmat->GetColDensity(0), 1);
|
||||||
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
|
EXPECT_EQ(dmat->GetColDensity(1), 0.5);
|
||||||
@ -70,16 +70,53 @@ TEST(SparsePageDMatrix, ColAccess) {
|
|||||||
EXPECT_TRUE(FileExists(tmp_file + ".cache.sorted.col.page"));
|
EXPECT_TRUE(FileExists(tmp_file + ".cache.sorted.col.page"));
|
||||||
|
|
||||||
delete dmat;
|
delete dmat;
|
||||||
|
|
||||||
|
EXPECT_FALSE(FileExists(tmp_file + ".cache"));
|
||||||
|
EXPECT_FALSE(FileExists(tmp_file + ".cache.row.page"));
|
||||||
|
EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
|
||||||
|
EXPECT_FALSE(FileExists(tmp_file + ".cache.sorted.col.page"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(SparsePageDMatrix, ExistingCacheFile) {
|
||||||
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
|
std::string filename = tmpdir.path + "/big.libsvm";
|
||||||
|
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||||
|
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||||
|
EXPECT_ANY_THROW({
|
||||||
|
std::unique_ptr<xgboost::DMatrix> dmat2 =
|
||||||
|
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
#if defined(_OPENMP)
// Iterates the SparsePage batches of one shared external-memory DMatrix from
// many OpenMP loop iterations at once and expects at least one of them to
// throw.
// NOTE(review): presumably the exception originates from the "multiple threads
// used with prefetcher" fatal check named in the commit message -- confirm.
TEST(SparsePageDMatrix, ThreadSafetyException) {
  dmlc::TemporaryDirectory tmpdir;
  std::string filename = tmpdir.path + "/test";
  std::unique_ptr<xgboost::DMatrix> dmat =
      xgboost::CreateSparsePageDMatrix(12, 64, filename);

  bool exception = false;
  int threads = 1000;
#pragma omp parallel for
  for (auto i = 0; i < threads; i++) {
    try {
      auto iter = dmat->GetBatches<SparsePage>().begin();
      ++iter;
    } catch (...) {
      // Several threads may land here simultaneously; serialise the store so
      // the shared flag is not written concurrently (unsynchronised writes to
      // a plain bool are a data race).
#pragma omp critical
      exception = true;
    }
  }
  // The parallel region has ended (implicit barrier), so this read is ordered
  // after every write above.
  EXPECT_TRUE(exception);
}
#endif
|
||||||
|
|
||||||
// Multi-batches access
|
// Multi-batches access
|
||||||
TEST(SparsePageDMatrix, ColAccessBatches) {
|
TEST(SparsePageDMatrix, ColAccessBatches) {
|
||||||
dmlc::TemporaryDirectory tmpdir;
|
dmlc::TemporaryDirectory tmpdir;
|
||||||
std::string filename = tmpdir.path + "/big.libsvm";
|
std::string filename = tmpdir.path + "/big.libsvm";
|
||||||
// Create multiple sparse pages
|
// Create multiple sparse pages
|
||||||
std::unique_ptr<xgboost::DMatrix> dmat{
|
std::unique_ptr<xgboost::DMatrix> dmat{
|
||||||
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)
|
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
|
||||||
};
|
|
||||||
auto n_threads = omp_get_max_threads();
|
auto n_threads = omp_get_max_threads();
|
||||||
omp_set_num_threads(16);
|
omp_set_num_threads(16);
|
||||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||||
@ -88,7 +125,6 @@ TEST(SparsePageDMatrix, ColAccessBatches) {
|
|||||||
omp_set_num_threads(n_threads);
|
omp_set_num_threads(n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, Empty) {
|
TEST(SparsePageDMatrix, Empty) {
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||||
@ -96,29 +132,34 @@ TEST(SparsePageDMatrix, Empty) {
|
|||||||
std::vector<unsigned> feature_idx = {};
|
std::vector<unsigned> feature_idx = {};
|
||||||
std::vector<size_t> row_ptr = {};
|
std::vector<size_t> row_ptr = {};
|
||||||
|
|
||||||
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0);
|
{
|
||||||
data::SparsePageDMatrix dmat(&csr_adapter,
|
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(),
|
||||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
data.data(), 0, 0, 0);
|
||||||
|
data::SparsePageDMatrix dmat(
|
||||||
|
&csr_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||||
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
|
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
|
||||||
EXPECT_EQ(dmat.Info().num_row_, 0);
|
EXPECT_EQ(dmat.Info().num_row_, 0);
|
||||||
EXPECT_EQ(dmat.Info().num_col_, 0);
|
EXPECT_EQ(dmat.Info().num_col_, 0);
|
||||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||||
EXPECT_EQ(batch.Size(), 0);
|
EXPECT_EQ(batch.Size(), 0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
data::DenseAdapter dense_adapter(nullptr, 0, 0);
|
data::DenseAdapter dense_adapter(nullptr, 0, 0);
|
||||||
data::SparsePageDMatrix dmat2(&dense_adapter,
|
data::SparsePageDMatrix dmat2(
|
||||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
&dense_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||||
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
|
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
|
||||||
EXPECT_EQ(dmat2.Info().num_row_, 0);
|
EXPECT_EQ(dmat2.Info().num_row_, 0);
|
||||||
EXPECT_EQ(dmat2.Info().num_col_, 0);
|
EXPECT_EQ(dmat2.Info().num_col_, 0);
|
||||||
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
|
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
|
||||||
EXPECT_EQ(batch.Size(), 0);
|
EXPECT_EQ(batch.Size(), 0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
|
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
|
||||||
data::SparsePageDMatrix dmat3(&csc_adapter,
|
data::SparsePageDMatrix dmat3(
|
||||||
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
&csc_adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||||
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
|
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
|
||||||
EXPECT_EQ(dmat3.Info().num_row_, 0);
|
EXPECT_EQ(dmat3.Info().num_row_, 0);
|
||||||
EXPECT_EQ(dmat3.Info().num_col_, 0);
|
EXPECT_EQ(dmat3.Info().num_col_, 0);
|
||||||
@ -126,6 +167,7 @@ TEST(SparsePageDMatrix, Empty) {
|
|||||||
EXPECT_EQ(batch.Size(), 0);
|
EXPECT_EQ(batch.Size(), 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST(SparsePageDMatrix, MissingData) {
|
TEST(SparsePageDMatrix, MissingData) {
|
||||||
dmlc::TemporaryDirectory tempdir;
|
dmlc::TemporaryDirectory tempdir;
|
||||||
@ -134,8 +176,10 @@ TEST(SparsePageDMatrix, MissingData) {
|
|||||||
std::vector<unsigned> feature_idx = {0, 1, 0};
|
std::vector<unsigned> feature_idx = {0, 1, 0};
|
||||||
std::vector<size_t> row_ptr = {0, 2, 3};
|
std::vector<size_t> row_ptr = {0, 2, 3};
|
||||||
|
|
||||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2);
|
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
|
||||||
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
3, 2);
|
||||||
|
data::SparsePageDMatrix dmat(
|
||||||
|
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||||
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
||||||
|
|
||||||
const std::string tmp_file2 = tempdir.path + "/simple2.libsvm";
|
const std::string tmp_file2 = tempdir.path + "/simple2.libsvm";
|
||||||
@ -150,8 +194,10 @@ TEST(SparsePageDMatrix, EmptyRow) {
|
|||||||
std::vector<unsigned> feature_idx = {0, 1};
|
std::vector<unsigned> feature_idx = {0, 1};
|
||||||
std::vector<size_t> row_ptr = {0, 2, 2};
|
std::vector<size_t> row_ptr = {0, 2, 2};
|
||||||
|
|
||||||
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2);
|
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
|
||||||
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
|
2, 2);
|
||||||
|
data::SparsePageDMatrix dmat(
|
||||||
|
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
|
||||||
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
|
||||||
EXPECT_EQ(dmat.Info().num_row_, 2);
|
EXPECT_EQ(dmat.Info().num_row_, 2);
|
||||||
EXPECT_EQ(dmat.Info().num_col_, 2);
|
EXPECT_EQ(dmat.Info().num_col_, 2);
|
||||||
@ -173,8 +219,7 @@ TEST(SparsePageDMatrix, FromDense) {
|
|||||||
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
for (auto &batch : dmat.GetBatches<SparsePage>()) {
|
||||||
for (auto i = 0ull; i < batch.Size(); i++) {
|
for (auto i = 0ull; i < batch.Size(); i++) {
|
||||||
auto inst = batch[i];
|
auto inst = batch[i];
|
||||||
for(auto j = 0ull; j < inst.size(); j++)
|
for (auto j = 0ull; j < inst.size(); j++) {
|
||||||
{
|
|
||||||
EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
|
EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
|
||||||
EXPECT_EQ(inst[j].index, j);
|
EXPECT_EQ(inst[j].index, j);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user