Dmatrix refactor stage 1 (#3301)

* Use sparse page as singular CSR matrix representation

* Simplify dmatrix methods

* Reduce statefulness of batch iterators

* BREAKING CHANGE: Remove prob_buffer_row parameter. Users are instead recommended to sample their dataset as a preprocessing step before using XGBoost.
This commit is contained in:
Rory Mitchell
2018-06-07 10:25:58 +12:00
committed by GitHub
parent 286dccb8e8
commit a96039141a
47 changed files with 650 additions and 1036 deletions

View File

@@ -6,7 +6,7 @@
#include <xgboost/logging.h>
#include <dmlc/registry.h>
#include <cstring>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
#include "./simple_dmatrix.h"
#include "./simple_csr_source.h"
#include "../common/common.h"
@@ -278,8 +278,7 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource>&& source,
} // namespace xgboost
namespace xgboost {
namespace data {
SparsePage::Format* SparsePage::Format::Create(const std::string& name) {
data::SparsePageFormat* data::SparsePageFormat::Create(const std::string& name) {
auto *e = ::dmlc::Registry< ::xgboost::data::SparsePageFormatReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown format type " << name;
@@ -288,7 +287,7 @@ SparsePage::Format* SparsePage::Format::Create(const std::string& name) {
}
std::pair<std::string, std::string>
SparsePage::Format::DecideFormat(const std::string& cache_prefix) {
data::SparsePageFormat::DecideFormat(const std::string& cache_prefix) {
size_t pos = cache_prefix.rfind(".fmt-");
if (pos != std::string::npos) {
@@ -305,6 +304,7 @@ SparsePage::Format::DecideFormat(const std::string& cache_prefix) {
}
}
namespace data {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(sparse_page_raw_format);
} // namespace data

View File

@@ -10,24 +10,18 @@ namespace xgboost {
namespace data {
void SimpleCSRSource::Clear() {
row_data_.clear();
row_ptr_.resize(1);
row_ptr_[0] = 0;
page_.Clear();
this->info.Clear();
}
void SimpleCSRSource::CopyFrom(DMatrix* src) {
this->Clear();
this->info = src->Info();
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
auto iter = src->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
RowBatch::Inst inst = batch[i];
row_data_.insert(row_data_.end(), inst.data, inst.data + inst.length);
row_ptr_.push_back(row_ptr_.back() + inst.length);
}
const auto &batch = iter->Value();
page_.Push(batch);
}
}
@@ -53,16 +47,16 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
row_data_.emplace_back(index, fvalue);
page_.data.emplace_back(index, fvalue);
this->info.num_col_ = std::max(this->info.num_col_,
static_cast<uint64_t>(index + 1));
}
size_t top = row_ptr_.size();
size_t top = page_.offset.size();
for (size_t i = 0; i < batch.size; ++i) {
row_ptr_.push_back(row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0]);
page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
}
}
this->info.num_nonzero_ = static_cast<uint64_t>(row_data_.size());
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
}
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
@@ -70,16 +64,16 @@ void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format";
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
info.LoadBinary(fi);
fi->Read(&row_ptr_);
fi->Read(&row_data_);
fi->Read(&page_.offset);
fi->Read(&page_.data);
}
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
info.SaveBinary(fo);
fo->Write(row_ptr_);
fo->Write(row_data_);
fo->Write(page_.offset);
fo->Write(page_.data);
}
void SimpleCSRSource::BeforeFirst() {
@@ -89,15 +83,11 @@ void SimpleCSRSource::BeforeFirst() {
bool SimpleCSRSource::Next() {
if (!at_first_) return false;
at_first_ = false;
batch_.size = row_ptr_.size() - 1;
batch_.base_rowid = 0;
batch_.ind_ptr = dmlc::BeginPtr(row_ptr_);
batch_.data_ptr = dmlc::BeginPtr(row_data_);
return true;
}
const RowBatch& SimpleCSRSource::Value() const {
return batch_;
const SparsePage& SimpleCSRSource::Value() const {
return page_;
}
} // namespace data

View File

@@ -29,13 +29,9 @@ class SimpleCSRSource : public DataSource {
public:
// public data members
// MetaInfo info; // inherited from DataSource
/*! \brief row pointer of CSR sparse storage */
std::vector<size_t> row_ptr_;
/*! \brief data in the CSR sparse storage */
std::vector<RowBatch::Entry> row_data_;
// functions
SparsePage page_;
/*! \brief default constructor */
SimpleCSRSource() : row_ptr_(1, 0) {}
SimpleCSRSource() = default;
/*! \brief destructor */
~SimpleCSRSource() override = default;
/*! \brief clear the data structure */
@@ -66,15 +62,13 @@ class SimpleCSRSource : public DataSource {
// implement BeforeFirst
void BeforeFirst() override;
// implement Value
const RowBatch &Value() const override;
const SparsePage &Value() const override;
/*! \brief magic number used to identify SimpleCSRSource */
static const int kMagic = 0xffffab01;
private:
/*! \brief internal variable, used to support iterator interface */
bool at_first_{true};
/*! \brief */
RowBatch batch_;
};
} // namespace data
} // namespace xgboost

View File

@@ -16,107 +16,50 @@ namespace xgboost {
namespace data {
bool SimpleDMatrix::ColBatchIter::Next() {
if (data_ptr_ >= cpages_.size()) return false;
data_ptr_ += 1;
SparsePage* pcol = cpages_[data_ptr_ - 1].get();
batch_.size = col_index_.size();
col_data_.resize(col_index_.size(), SparseBatch::Inst(nullptr, 0));
for (size_t i = 0; i < col_data_.size(); ++i) {
const bst_uint ridx = col_index_[i];
col_data_[i] = SparseBatch::Inst
(dmlc::BeginPtr(pcol->data) + pcol->offset[ridx],
static_cast<bst_uint>(pcol->offset[ridx + 1] - pcol->offset[ridx]));
}
batch_.col_index = dmlc::BeginPtr(col_index_);
batch_.col_data = dmlc::BeginPtr(col_data_);
if (data_ >= 1) return false;
data_ += 1;
return true;
}
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator() {
size_t ncol = this->Info().num_col_;
col_iter_.col_index_.resize(ncol);
for (size_t i = 0; i < ncol; ++i) {
col_iter_.col_index_[i] = static_cast<bst_uint>(i);
}
dmlc::DataIter<SparsePage>* SimpleDMatrix::ColIterator() {
col_iter_.BeforeFirst();
return &col_iter_;
}
dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator(const std::vector<bst_uint>&fset) {
size_t ncol = this->Info().num_col_;
col_iter_.col_index_.resize(0);
for (auto fidx : fset) {
if (fidx < ncol) col_iter_.col_index_.push_back(fidx);
}
col_iter_.BeforeFirst();
return &col_iter_;
}
void SimpleDMatrix::InitColAccess(const std::vector<bool> &enabled,
float pkeep,
size_t max_row_perbatch, bool sorted) {
void SimpleDMatrix::InitColAccess(
size_t max_row_perbatch, bool sorted) {
if (this->HaveColAccess(sorted)) return;
col_iter_.sorted_ = sorted;
col_iter_.cpages_.clear();
if (Info().num_row_ < max_row_perbatch) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeOneBatch(enabled, pkeep, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
} else {
this->MakeManyBatch(enabled, pkeep, max_row_perbatch, sorted);
}
// setup col-size
col_size_.resize(Info().num_col_);
std::fill(col_size_.begin(), col_size_.end(), 0);
for (auto & cpage : col_iter_.cpages_) {
SparsePage *pcol = cpage.get();
for (size_t j = 0; j < pcol->Size(); ++j) {
col_size_[j] += pcol->offset[j + 1] - pcol->offset[j];
}
}
col_iter_.column_page_.reset(new SparsePage());
this->MakeOneBatch(col_iter_.column_page_.get(), sorted);
}
// internal function to make one batch from row iter.
void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled, float pkeep,
SparsePage* pcol, bool sorted) {
void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) {
// clear rowset
buffered_rowset_.Clear();
// bit map
const int nthread = omp_get_max_threads();
std::vector<bool> bmap;
pcol->Clear();
common::ParallelGroupBuilder<SparseBatch::Entry>
common::ParallelGroupBuilder<Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(Info().num_col_, nthread);
// start working
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
auto iter = this->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
bmap.resize(bmap.size() + batch.size, true);
std::bernoulli_distribution coin_flip(pkeep);
auto& rnd = common::GlobalRandom();
long batch_size = static_cast<long>(batch.size); // NOLINT(*)
const auto& batch = iter->Value();
long batch_size = static_cast<long>(batch.Size()); // NOLINT(*)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
} else {
bmap[i] = false;
}
buffered_rowset_.PushBack(ridx);
}
#pragma omp parallel for schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) {
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]) {
builder.AddBudget(inst[j].index, tid);
}
}
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index, tid);
}
}
}
@@ -124,20 +67,16 @@ void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled, float pkeep,
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch& batch = iter->Value();
auto batch = iter->Value();
#pragma omp parallel for schedule(static)
for (long i = 0; i < static_cast<long>(batch.size); ++i) { // NOLINT(*)
for (long i = 0; i < static_cast<long>(batch.Size()); ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (bmap[ridx]) {
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
if (enabled[inst[j].index]) {
builder.Push(inst[j].index,
SparseBatch::Entry(static_cast<bst_uint>(batch.base_rowid+i),
inst[j].fvalue), tid);
}
}
auto inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.Push(
inst[j].index,
Entry(static_cast<bst_uint>(batch.base_rowid + i), inst[j].fvalue),
tid);
}
}
}
@@ -152,102 +91,14 @@ void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled, float pkeep,
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
}
}
}
}
void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted) {
size_t btop = 0;
std::bernoulli_distribution coin_flip(pkeep);
auto& rnd = common::GlobalRandom();
buffered_rowset_.Clear();
// internal temp cache
SparsePage tmp; tmp.Clear();
// start working
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
tmp.Push(batch[i]);
}
if (tmp.Size() >= max_row_perbatch) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
btop = buffered_rowset_.Size();
tmp.Clear();
}
}
}
if (tmp.Size() != 0) {
std::unique_ptr<SparsePage> page(new SparsePage());
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
col_iter_.cpages_.push_back(std::move(page));
}
}
// make column page from subset of rowbatchs
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
size_t buffer_begin,
const std::vector<bool>& enabled,
SparsePage* pcol, bool sorted) {
const int nthread = std::min(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 2, 1));
pcol->Clear();
common::ParallelGroupBuilder<SparseBatch::Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(Info().num_col_, nthread);
bst_omp_uint ndata = static_cast<bst_uint>(batch.size);
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
const SparseBatch::Entry &e = inst[j];
if (enabled[e.index]) {
builder.AddBudget(e.index, tid);
}
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static) num_threads(nthread)
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
const SparseBatch::Entry &e = inst[j];
builder.Push(
e.index,
SparseBatch::Entry(buffered_rowset_[i + buffer_begin], e.fvalue),
tid);
}
}
CHECK_EQ(pcol->Size(), Info().num_col_);
// sort columns
if (sorted) {
auto ncol = static_cast<bst_omp_uint>(pcol->Size());
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
Entry::CmpValue);
}
}
}
}
bool SimpleDMatrix::SingleColBlock() const {
return col_iter_.cpages_.size() <= 1;
return true;
}
} // namespace data
} // namespace xgboost

View File

@@ -12,7 +12,6 @@
#include <vector>
#include <algorithm>
#include <cstring>
#include "./sparse_batch_page.h"
namespace xgboost {
namespace data {
@@ -30,14 +29,14 @@ class SimpleDMatrix : public DMatrix {
return source_->info;
}
dmlc::DataIter<RowBatch>* RowIterator() override {
dmlc::DataIter<RowBatch>* iter = source_.get();
dmlc::DataIter<SparsePage>* RowIterator() override {
auto iter = source_.get();
iter->BeforeFirst();
return iter;
}
bool HaveColAccess(bool sorted) const override {
return col_size_.size() != 0 && col_iter_.sorted_ == sorted;
return col_iter_.sorted_ == sorted && col_iter_.column_page_!= nullptr;
}
const RowSet& BufferedRowset() const override {
@@ -45,50 +44,42 @@ class SimpleDMatrix : public DMatrix {
}
size_t GetColSize(size_t cidx) const override {
return col_size_[cidx];
auto& batch = *col_iter_.column_page_;
return batch[cidx].length;
}
float GetColDensity(size_t cidx) const override {
size_t nmiss = buffered_rowset_.Size() - col_size_[cidx];
size_t nmiss = buffered_rowset_.Size() - GetColSize(cidx);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.Size();
}
dmlc::DataIter<ColBatch>* ColIterator() override;
dmlc::DataIter<SparsePage>* ColIterator() override;
dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) override;
void InitColAccess(const std::vector<bool>& enabled,
float subsample,
size_t max_row_perbatch, bool sorted) override;
void InitColAccess(
size_t max_row_perbatch, bool sorted) override;
bool SingleColBlock() const override;
private:
// in-memory column batch iterator.
struct ColBatchIter: dmlc::DataIter<ColBatch> {
struct ColBatchIter: dmlc::DataIter<SparsePage> {
public:
ColBatchIter() = default;
void BeforeFirst() override {
data_ptr_ = 0;
data_ = 0;
}
const ColBatch &Value() const override {
return batch_;
const SparsePage &Value() const override {
return *column_page_;
}
bool Next() override;
private:
// allow SimpleDMatrix to access it.
friend class SimpleDMatrix;
// data content
std::vector<bst_uint> col_index_;
// column content
std::vector<ColBatch::Inst> col_data_;
// column sparse pages
std::vector<std::unique_ptr<SparsePage> > cpages_;
// column sparse page
std::unique_ptr<SparsePage> column_page_;
// data pointer
size_t data_ptr_{0};
// temporal space for batch
ColBatch batch_;
size_t data_{0};
// Is column sorted?
bool sorted_{false};
};
@@ -99,22 +90,10 @@ class SimpleDMatrix : public DMatrix {
ColBatchIter col_iter_;
// list of row index that are buffered.
RowSet buffered_rowset_;
/*! \brief sizeof column data */
std::vector<size_t> col_size_;
// internal function to make one batch from row iter.
void MakeOneBatch(const std::vector<bool>& enabled,
float pkeep,
SparsePage *pcol, bool sorted);
void MakeManyBatch(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted);
void MakeColPage(const RowBatch& batch,
size_t buffer_begin,
const std::vector<bool>& enabled,
SparsePage* pcol, bool sorted);
void MakeOneBatch(
SparsePage *pcol, bool sorted);
};
} // namespace data
} // namespace xgboost

View File

@@ -1,255 +0,0 @@
/*!
* Copyright (c) 2014 by Contributors
* \file sparse_batch_page.h
* content holder of sparse batch that can be saved to disk
* the representation can be effectively
* use in external memory computation
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
#define XGBOOST_DATA_SPARSE_BATCH_PAGE_H_
#include <xgboost/data.h>
#include <dmlc/io.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <memory>
#include <functional>
#if DMLC_ENABLE_STD_THREAD
#include <dmlc/concurrency.h>
#include <thread>
#endif
namespace xgboost {
namespace data {
/*!
 * \brief in-memory storage unit of sparse batch
 *
 * Segments are stored back to back in `data`; segment i occupies
 * data[offset[i] .. offset[i + 1]). After construction or Clear(),
 * `offset` holds a single leading 0, so Size() is offset.size() - 1.
 */
class SparsePage {
 public:
  /*! \brief Format of the sparse page. */
  class Format;
  /*! \brief Writer to write the sparse page to files. */
  class Writer;
  /*! \brief minimum index of all index, used as hint for compression. */
  bst_uint min_index;
  /*! \brief offset of the segments */
  std::vector<size_t> offset;
  /*! \brief the data of the segments */
  std::vector<SparseBatch::Entry> data;
  /*! \brief constructor, leaves the page in the cleared (empty) state */
  SparsePage() {
    this->Clear();
  }
  /*! \return number of instance in the page */
  inline size_t Size() const {
    return offset.size() - 1;
  }
  /*! \return estimation of memory cost of this page */
  inline size_t MemCostBytes() const {
    return offset.size() * sizeof(size_t) + data.size() * sizeof(SparseBatch::Entry);
  }
  /*! \brief clear the page: no entries, and a single 0 in offset */
  inline void Clear() {
    min_index = 0;
    offset.clear();
    offset.push_back(0);
    data.clear();
  }
  /*!
   * \brief Push row batch into the page
   * \param batch the row batch
   */
  inline void Push(const RowBatch &batch) {
    // Nothing to append; also avoids touching ind_ptr of an empty batch.
    if (batch.size == 0) return;
    // The batch may be a view that does not start at entry 0, so the number
    // of entries to copy is the span between the first and last row pointer,
    // not the absolute value of the last row pointer.
    const size_t first = batch.ind_ptr[0];
    const size_t num_entries = batch.ind_ptr[batch.size] - first;
    const size_t top = offset.back();
    data.resize(top + num_entries);
    if (num_entries != 0) {
      std::memcpy(dmlc::BeginPtr(data) + top,
                  batch.data_ptr + first,
                  sizeof(SparseBatch::Entry) * num_entries);
    }
    const size_t begin = offset.size();
    offset.resize(begin + batch.size);
    for (size_t i = 0; i < batch.size; ++i) {
      // Rebase the batch-local row pointers onto the end of this page.
      offset[i + begin] = top + batch.ind_ptr[i + 1] - first;
    }
  }
  /*!
   * \brief Push row block into the page.
   * \param batch the row batch.
   */
  inline void Push(const dmlc::RowBlock<uint32_t>& batch) {
    data.reserve(data.size() + batch.offset[batch.size] - batch.offset[0]);
    offset.reserve(offset.size() + batch.size);
    CHECK(batch.index != nullptr);
    for (size_t i = 0; i < batch.size; ++i) {
      offset.push_back(offset.back() + batch.offset[i + 1] - batch.offset[i]);
    }
    for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
      uint32_t index = batch.index[i];
      // A block without explicit values is treated as all-ones.
      bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
      data.emplace_back(index, fvalue);
    }
    CHECK_EQ(offset.back(), data.size());
  }
  /*!
   * \brief Push a sparse page
   * \param batch the row page
   */
  inline void Push(const SparsePage &batch) {
    const size_t top = offset.back();
    const size_t num_entries = batch.data.size();
    data.resize(top + num_entries);
    // Skip the copy for an empty page: the source pointer may be null then.
    if (num_entries != 0) {
      std::memcpy(dmlc::BeginPtr(data) + top,
                  dmlc::BeginPtr(batch.data),
                  sizeof(SparseBatch::Entry) * num_entries);
    }
    const size_t begin = offset.size();
    offset.resize(begin + batch.Size());
    for (size_t i = 0; i < batch.Size(); ++i) {
      offset[i + begin] = top + batch.offset[i + 1];
    }
  }
  /*!
   * \brief Push one instance into page
   * \param inst an instance row
   */
  inline void Push(const SparseBatch::Inst &inst) {
    offset.push_back(offset.back() + inst.length);
    const size_t begin = data.size();
    data.resize(begin + inst.length);
    if (inst.length != 0) {
      std::memcpy(dmlc::BeginPtr(data) + begin, inst.data,
                  sizeof(SparseBatch::Entry) * inst.length);
    }
  }
  /*!
   * \param base_rowid base_rowid of the data
   * \return row batch representation of the page; it points into this
   *         page's offset and data vectors rather than copying them
   */
  inline RowBatch GetRowBatch(size_t base_rowid) const {
    RowBatch out;
    out.base_rowid = base_rowid;
    out.ind_ptr = dmlc::BeginPtr(offset);
    out.data_ptr = dmlc::BeginPtr(data);
    out.size = offset.size() - 1;
    return out;
  }
};
/*!
* \brief Format specification of SparsePage.
*
* Abstract interface for reading a SparsePage from a stream and writing
* one to a stream.
*/
class SparsePage::Format {
public:
/*! \brief virtual destructor */
virtual ~Format() = default;
/*!
* \brief Load all the segments into page, advance fi to end of the block.
* \param page The page to read the data into.
* \param fi the input stream of the file
* \return true if the loading was successful, false if end of file was reached
*/
virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
/*!
* \brief read only the segments we are interested in, advance fi to end of the block.
* \param page The page to load the data into.
* \param fi the input stream of the file
* \param sorted_index_set sorted index of segments we are interested in
* \return true if the loading was successful, false if end of file was reached
*/
virtual bool Read(SparsePage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) = 0;
/*!
* \brief save the page to fo.
* \param page the page to be written
* \param fo output stream
*/
virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
/*!
* \brief Create sparse page of format.
* \param name name of the format to create
* \return The created format functors.
*/
static Format* Create(const std::string& name);
/*!
* \brief decide the format from cache prefix.
* \param cache_prefix the cache prefix to inspect
* \return pair of row format, column format type of the cache prefix.
*/
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
};
#if DMLC_ENABLE_STD_THREAD
/*!
* \brief A threaded writer to write sparse batch page to sharded files.
*
* Only declared when DMLC_ENABLE_STD_THREAD is set. Typical use, as far as
* this header shows: Alloc() a page, fill it, then hand it to PushWrite().
*/
class SparsePage::Writer {
public:
/*!
* \brief constructor
* \param name_shards name of shard files.
* \param format_shards format of each shard.
* \param extra_buffer_capacity Extra buffer capacity before block.
*/
explicit Writer(
const std::vector<std::string>& name_shards,
const std::vector<std::string>& format_shards,
size_t extra_buffer_capacity);
/*! \brief destructor, will close the files automatically */
~Writer();
/*!
* \brief Push a write job to the writer.
* This function won't block,
* writing is done by another thread inside writer.
* \param page The page to be written
*/
void PushWrite(std::shared_ptr<SparsePage>&& page);
/*!
* \brief Allocate a page to store results.
* This function can block when the writer is too slow and buffer pages
* have not yet been recycled.
* \param out_page Used to store the allocated pages.
*/
void Alloc(std::shared_ptr<SparsePage>* out_page);
private:
/*!
* \brief number of allocated pages
* NOTE(review): the name suggests this counts buffers still free to hand
* out rather than pages already allocated; confirm against the implementation.
*/
size_t num_free_buffer_;
/*!
* \brief clock_pointer
* NOTE(review): presumably the index of the next shard/worker to use in
* round-robin order; confirm against the implementation.
*/
size_t clock_ptr_;
/*! \brief writer threads */
std::vector<std::unique_ptr<std::thread> > workers_;
/*! \brief recycler queue: blocking queue of pages */
dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
/*!
* \brief vector of blocking page queues
* NOTE(review): presumably one queue per worker thread; confirm against the implementation.
*/
std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
};
#endif // DMLC_ENABLE_STD_THREAD
/*!
* \brief Registry entry for sparse page format.
*
* The registered body is a std::function taking no arguments and
* returning a SparsePage::Format pointer.
*/
struct SparsePageFormatReg
: public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
std::function<SparsePage::Format* ()> > {
};
/*!
* \brief Macro to register sparse page format.
*
* \code
* // example of registering a sparse page format
* XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
* .describe("Raw binary data format.")
* .set_body([]() {
* return new RawFormat();
* });
* \endcode
*/
#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_BATCH_PAGE_H_

View File

@@ -29,8 +29,8 @@ SparsePageDMatrix::ColPageIter::ColPageIter(
dmlc::SeekStream* fi = files_[i].get();
std::string format;
CHECK(fi->Read(&format)) << "Invalid page format";
formats_[i].reset(SparsePage::Format::Create(format));
SparsePage::Format* fmt = formats_[i].get();
formats_[i].reset(SparsePageFormat::Create(format));
SparsePageFormat* fmt = formats_[i].get();
size_t fbegin = fi->Tell();
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
prefetchers_[i]->Init([this, fi, fmt] (SparsePage** dptr) {
@@ -61,15 +61,6 @@ bool SparsePageDMatrix::ColPageIter::Next() {
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
}
if (prefetchers_[clock_ptr_]->Next(&page_)) {
out_.col_index = dmlc::BeginPtr(index_set_);
col_data_.resize(page_->offset.size() - 1, SparseBatch::Inst(nullptr, 0));
for (size_t i = 0; i < col_data_.size(); ++i) {
col_data_[i] = SparseBatch::Inst
(dmlc::BeginPtr(page_->data) + page_->offset[i],
static_cast<bst_uint>(page_->offset[i + 1] - page_->offset[i]));
}
out_.col_data = dmlc::BeginPtr(col_data_);
out_.size = col_data_.size();
// advance clock
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
return true;
@@ -85,40 +76,22 @@ void SparsePageDMatrix::ColPageIter::BeforeFirst() {
}
}
void SparsePageDMatrix::ColPageIter::Init(const std::vector<bst_uint>& index_set,
bool load_all) {
void SparsePageDMatrix::ColPageIter::Init(
const std::vector<bst_uint>& index_set) {
set_index_set_ = index_set;
set_load_all_ = load_all;
set_load_all_ = true;
std::sort(set_index_set_.begin(), set_index_set_.end());
this->BeforeFirst();
}
dmlc::DataIter<ColBatch>* SparsePageDMatrix::ColIterator() {
dmlc::DataIter<SparsePage>* SparsePageDMatrix::ColIterator() {
CHECK(col_iter_ != nullptr);
std::vector<bst_uint> col_index;
size_t ncol = this->Info().num_col_;
for (size_t i = 0; i < ncol; ++i) {
col_index.push_back(static_cast<bst_uint>(i));
}
col_iter_->Init(col_index, true);
std::iota(col_index.begin(), col_index.end(), bst_uint(0));
col_iter_->Init(col_index);
return col_iter_.get();
}
dmlc::DataIter<ColBatch>* SparsePageDMatrix::
ColIterator(const std::vector<bst_uint>& fset) {
CHECK(col_iter_ != nullptr);
std::vector<bst_uint> col_index;
size_t ncol = this->Info().num_col_;
for (auto fidx : fset) {
if (fidx < ncol) {
col_index.push_back(fidx);
}
}
col_iter_->Init(col_index, false);
return col_iter_.get();
}
bool SparsePageDMatrix::TryInitColData(bool sorted) {
// load meta data.
std::vector<std::string> cache_shards = common::Split(cache_info_, ':');
@@ -145,9 +118,8 @@ bool SparsePageDMatrix::TryInitColData(bool sorted) {
return true;
}
void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
float pkeep,
size_t max_row_perbatch, bool sorted) {
void SparsePageDMatrix::InitColAccess(
size_t max_row_perbatch, bool sorted) {
if (HaveColAccess(sorted)) return;
if (TryInitColData(sorted)) return;
const MetaInfo& info = this->Info();
@@ -157,11 +129,9 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
buffered_rowset_.Clear();
col_size_.resize(info.num_col_);
std::fill(col_size_.begin(), col_size_.end(), 0);
dmlc::DataIter<RowBatch>* iter = this->RowIterator();
std::bernoulli_distribution coin_flip(pkeep);
auto iter = this->RowIterator();
size_t batch_ptr = 0, batch_top = 0;
SparsePage tmp;
auto& rnd = common::GlobalRandom();
// function to create the page.
auto make_col_batch = [&] (
@@ -169,9 +139,9 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
size_t begin,
SparsePage *pcol) {
pcol->Clear();
pcol->min_index = buffered_rowset_[begin];
pcol->base_rowid = buffered_rowset_[begin];
const int nthread = std::max(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 1, 1));
common::ParallelGroupBuilder<SparseBatch::Entry>
common::ParallelGroupBuilder<Entry>
builder(&pcol->offset, &pcol->data);
builder.InitBudget(info.num_col_, nthread);
bst_omp_uint ndata = static_cast<bst_uint>(prow.Size());
@@ -179,10 +149,8 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
if (enabled[e.index]) {
builder.AddBudget(e.index, tid);
}
const auto e = prow.data[j];
builder.AddBudget(e.index, tid);
}
}
builder.InitStorage();
@@ -190,9 +158,9 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
for (bst_omp_uint i = 0; i < ndata; ++i) {
int tid = omp_get_thread_num();
for (size_t j = prow.offset[i]; j < prow.offset[i+1]; ++j) {
const SparseBatch::Entry &e = prow.data[j];
const Entry &e = prow.data[j];
builder.Push(e.index,
SparseBatch::Entry(buffered_rowset_[i + begin], e.fvalue),
Entry(buffered_rowset_[i + begin], e.fvalue),
tid);
}
}
@@ -205,7 +173,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
if (pcol->offset[i] < pcol->offset[i + 1]) {
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
SparseBatch::Entry::CmpValue);
Entry::CmpValue);
}
}
}
@@ -217,14 +185,12 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
while (true) {
if (batch_ptr != batch_top) {
const RowBatch& batch = iter->Value();
CHECK_EQ(batch_top, batch.size);
auto batch = iter->Value();
CHECK_EQ(batch_top, batch.Size());
for (size_t i = batch_ptr; i < batch_top; ++i) {
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
if (pkeep == 1.0f || coin_flip(rnd)) {
buffered_rowset_.PushBack(ridx);
tmp.Push(batch[i]);
}
buffered_rowset_.PushBack(ridx);
tmp.Push(batch[i]);
if (tmp.Size() >= max_row_perbatch ||
tmp.MemCostBytes() >= kPageSize) {
@@ -237,7 +203,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
}
if (!iter->Next()) break;
batch_ptr = 0;
batch_top = iter->Value().size;
batch_top = iter->Value().Size();
}
if (tmp.Size() != 0) {
@@ -252,11 +218,11 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".col.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).second);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).second);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();

View File

@@ -14,8 +14,8 @@
#include <vector>
#include <algorithm>
#include <string>
#include "./sparse_batch_page.h"
#include "../common/common.h"
#include "./sparse_page_writer.h"
namespace xgboost {
namespace data {
@@ -35,8 +35,8 @@ class SparsePageDMatrix : public DMatrix {
return source_->info;
}
dmlc::DataIter<RowBatch>* RowIterator() override {
dmlc::DataIter<RowBatch>* iter = source_.get();
dmlc::DataIter<SparsePage>* RowIterator() override {
auto iter = source_.get();
iter->BeforeFirst();
return iter;
}
@@ -62,13 +62,10 @@ class SparsePageDMatrix : public DMatrix {
return false;
}
dmlc::DataIter<ColBatch>* ColIterator() override;
dmlc::DataIter<SparsePage>* ColIterator() override;
dmlc::DataIter<ColBatch>* ColIterator(const std::vector<bst_uint>& fset) override;
void InitColAccess(const std::vector<bool>& enabled,
float subsample,
size_t max_row_perbatch, bool sorted) override;
void InitColAccess(
size_t max_row_perbatch, bool sorted) override;
/*! \brief page size 256 MB */
static const size_t kPageSize = 256UL << 20UL;
@@ -77,17 +74,17 @@ class SparsePageDMatrix : public DMatrix {
private:
// declare the column batch iter.
class ColPageIter : public dmlc::DataIter<ColBatch> {
class ColPageIter : public dmlc::DataIter<SparsePage> {
public:
explicit ColPageIter(std::vector<std::unique_ptr<dmlc::SeekStream> >&& files);
~ColPageIter() override;
void BeforeFirst() override;
const ColBatch &Value() const override {
return out_;
const SparsePage &Value() const override {
return *page_;
}
bool Next() override;
// initialize the column iterator with the specified index set.
void Init(const std::vector<bst_uint>& index_set, bool load_all);
void Init(const std::vector<bst_uint>& index_set);
// If the column features are sorted
bool sorted;
@@ -99,7 +96,7 @@ class SparsePageDMatrix : public DMatrix {
// data file pointer.
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
// page format.
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
std::vector<std::unique_ptr<SparsePageFormat> > formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
// The index set to be loaded.
@@ -108,10 +105,6 @@ class SparsePageDMatrix : public DMatrix {
std::vector<bst_uint> set_index_set_;
// whether to load data dataset.
bool set_load_all_, load_all_;
// temporal space for batch
ColBatch out_;
// the pointer data.
std::vector<SparseBatch::Inst> col_data_;
};
/*!
* \brief Try to initialize column data.

View File

@@ -5,14 +5,14 @@
*/
#include <xgboost/data.h>
#include <dmlc/registry.h>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
namespace xgboost {
namespace data {
DMLC_REGISTRY_FILE_TAG(sparse_page_raw_format);
class SparsePageRawFormat : public SparsePage::Format {
class SparsePageRawFormat : public SparsePageFormat {
public:
bool Read(SparsePage* page, dmlc::SeekStream* fi) override {
if (!fi->Read(&(page->offset))) return false;
@@ -20,8 +20,8 @@ class SparsePageRawFormat : public SparsePage::Format {
page->data.resize(page->offset.back());
if (page->data.size() != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data),
(page->data).size() * sizeof(SparseBatch::Entry)),
(page->data).size() * sizeof(SparseBatch::Entry))
(page->data).size() * sizeof(Entry)),
(page->data).size() * sizeof(Entry))
<< "Invalid SparsePage file";
}
return true;
@@ -47,7 +47,7 @@ class SparsePageRawFormat : public SparsePage::Format {
bst_uint fid = sorted_index_set[i];
if (disk_offset_[fid] != curr_offset) {
CHECK_GT(disk_offset_[fid], curr_offset);
fi->Seek(begin + disk_offset_[fid] * sizeof(SparseBatch::Entry));
fi->Seek(begin + disk_offset_[fid] * sizeof(Entry));
curr_offset = disk_offset_[fid];
}
size_t j, size_to_read = 0;
@@ -61,8 +61,8 @@ class SparsePageRawFormat : public SparsePage::Format {
if (size_to_read != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(page->data) + page->offset[i],
size_to_read * sizeof(SparseBatch::Entry)),
size_to_read * sizeof(SparseBatch::Entry))
size_to_read * sizeof(Entry)),
size_to_read * sizeof(Entry))
<< "Invalid SparsePage file";
curr_offset += size_to_read;
}
@@ -70,7 +70,7 @@ class SparsePageRawFormat : public SparsePage::Format {
}
// seek to end of record
if (curr_offset != disk_offset_.back()) {
fi->Seek(begin + disk_offset_.back() * sizeof(SparseBatch::Entry));
fi->Seek(begin + disk_offset_.back() * sizeof(Entry));
}
return true;
}
@@ -80,7 +80,7 @@ class SparsePageRawFormat : public SparsePage::Format {
CHECK_EQ(page.offset.back(), page.data.size());
fo->Write(page.offset);
if (page.data.size() != 0) {
fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(SparseBatch::Entry));
fo->Write(dmlc::BeginPtr(page.data), page.data.size() * sizeof(Entry));
}
}

View File

@@ -37,8 +37,8 @@ SparsePageSource::SparsePageSource(const std::string& cache_info)
dmlc::SeekStream* fi = files_[i].get();
std::string format;
CHECK(fi->Read(&format)) << "Invalid page format";
formats_[i].reset(SparsePage::Format::Create(format));
SparsePage::Format* fmt = formats_[i].get();
formats_[i].reset(SparsePageFormat::Create(format));
SparsePageFormat* fmt = formats_[i].get();
size_t fbegin = fi->Tell();
prefetchers_[i].reset(new dmlc::ThreadedIter<SparsePage>(4));
prefetchers_[i]->Init([fi, fmt] (SparsePage** dptr) {
@@ -61,8 +61,8 @@ bool SparsePageSource::Next() {
prefetchers_[(clock_ptr_ + n - 1) % n]->Recycle(&page_);
}
if (prefetchers_[clock_ptr_]->Next(&page_)) {
batch_ = page_->GetRowBatch(base_rowid_);
base_rowid_ += batch_.size;
page_->base_rowid = base_rowid_;
base_rowid_ += page_->Size();
// advance clock
clock_ptr_ = (clock_ptr_ + 1) % prefetchers_.size();
return true;
@@ -79,8 +79,8 @@ void SparsePageSource::BeforeFirst() {
}
}
const RowBatch& SparsePageSource::Value() const {
return batch_;
const SparsePage& SparsePageSource::Value() const {
return *page_;
}
bool SparsePageSource::CacheExist(const std::string& cache_info) {
@@ -108,10 +108,10 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();
@@ -176,17 +176,17 @@ void SparsePageSource::Create(DMatrix* src,
std::vector<std::string> name_shards, format_shards;
for (const std::string& prefix : cache_shards) {
name_shards.push_back(prefix + ".row.page");
format_shards.push_back(SparsePage::Format::DecideFormat(prefix).first);
format_shards.push_back(SparsePageFormat::DecideFormat(prefix).first);
}
{
SparsePage::Writer writer(name_shards, format_shards, 6);
SparsePageWriter writer(name_shards, format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();
MetaInfo info = src->Info();
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
auto iter = src->RowIterator();
while (iter->Next()) {
page->Push(iter->Value());

View File

@@ -13,7 +13,7 @@
#include <vector>
#include <algorithm>
#include <string>
#include "./sparse_batch_page.h"
#include "sparse_page_writer.h"
namespace xgboost {
namespace data {
@@ -39,7 +39,7 @@ class SparsePageSource : public DataSource {
// implement BeforeFirst
void BeforeFirst() override;
// implement Value
const RowBatch& Value() const override;
const SparsePage& Value() const override;
/*!
* \brief Create source by taking data from parser.
* \param src source parser.
@@ -67,8 +67,6 @@ class SparsePageSource : public DataSource {
private:
/*! \brief number of rows */
size_t base_rowid_;
/*! \brief temp data. */
RowBatch batch_;
/*! \brief page currently on hold. */
SparsePage *page_;
/*! \brief internal clock ptr */
@@ -76,7 +74,7 @@ class SparsePageSource : public DataSource {
/*! \brief file pointer to the row blob file. */
std::vector<std::unique_ptr<dmlc::SeekStream> > files_;
/*! \brief Sparse page format file. */
std::vector<std::unique_ptr<SparsePage::Format> > formats_;
std::vector<std::unique_ptr<SparsePageFormat> > formats_;
/*! \brief internal prefetcher. */
std::vector<std::unique_ptr<dmlc::ThreadedIter<SparsePage> > > prefetchers_;
};

View File

@@ -5,13 +5,13 @@
*/
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include "./sparse_batch_page.h"
#include "./sparse_page_writer.h"
#if DMLC_ENABLE_STD_THREAD
namespace xgboost {
namespace data {
SparsePage::Writer::Writer(
SparsePageWriter::SparsePageWriter(
const std::vector<std::string>& name_shards,
const std::vector<std::string>& format_shards,
size_t extra_buffer_capacity)
@@ -29,8 +29,8 @@ SparsePage::Writer::Writer(
[this, name_shard, format_shard, wqueue] () {
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(name_shard.c_str(), "w"));
std::unique_ptr<SparsePage::Format> fmt(
SparsePage::Format::Create(format_shard));
std::unique_ptr<SparsePageFormat> fmt(
SparsePageFormat::Create(format_shard));
fo->Write(format_shard);
std::shared_ptr<SparsePage> page;
while (wqueue->Pop(&page)) {
@@ -44,7 +44,7 @@ SparsePage::Writer::Writer(
}
}
SparsePage::Writer::~Writer() {
SparsePageWriter::~SparsePageWriter() {
for (auto& queue : qworkers_) {
// use nullptr to signal termination.
std::shared_ptr<SparsePage> sig(nullptr);
@@ -55,12 +55,12 @@ SparsePage::Writer::~Writer() {
}
}
void SparsePage::Writer::PushWrite(std::shared_ptr<SparsePage>&& page) {
void SparsePageWriter::PushWrite(std::shared_ptr<SparsePage>&& page) {
qworkers_[clock_ptr_].Push(std::move(page));
clock_ptr_ = (clock_ptr_ + 1) % workers_.size();
}
void SparsePage::Writer::Alloc(std::shared_ptr<SparsePage>* out_page) {
void SparsePageWriter::Alloc(std::shared_ptr<SparsePage>* out_page) {
CHECK(*out_page == nullptr);
if (num_free_buffer_ != 0) {
out_page->reset(new SparsePage());

View File

@@ -0,0 +1,139 @@
/*!
* Copyright (c) 2014 by Contributors
* \file sparse_page_writer.h
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
#define XGBOOST_DATA_SPARSE_PAGE_WRITER_H_
#include <xgboost/data.h>
#include <dmlc/io.h>
#include <vector>
#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <memory>
#include <functional>
#if DMLC_ENABLE_STD_THREAD
#include <dmlc/concurrency.h>
#include <thread>
#endif
namespace xgboost {
namespace data {
/*!
* \brief Format specification of SparsePage.
*
* Abstract interface for reading a SparsePage from a seekable stream and
* writing a SparsePage to a stream. Concrete formats are obtained by their
* registered name through Create().
*/
class SparsePageFormat {
public:
/*! \brief virtual destructor */
virtual ~SparsePageFormat() = default;
/*!
* \brief Load all the segments into page, advance fi to end of the block.
* \param page The page to read the data into.
* \param fi the input stream of the file
* \return true if the loading was successful, false if end of file was reached
*/
virtual bool Read(SparsePage* page, dmlc::SeekStream* fi) = 0;
/*!
* \brief Read only the segments we are interested in, advance fi to end of the block.
* \param page The page to load the data into.
* \param fi the input stream of the file
* \param sorted_index_set sorted indices of the segments we are interested in
* \return true if the loading was successful, false if end of file was reached
*/
virtual bool Read(SparsePage* page,
dmlc::SeekStream* fi,
const std::vector<bst_uint>& sorted_index_set) = 0;
/*!
* \brief Save the content of a page to fo.
* \param page The page to be written.
* \param fo output stream
*/
virtual void Write(const SparsePage& page, dmlc::Stream* fo) = 0;
/*!
* \brief Create a sparse page format from its registered name.
* \param name Name the format was registered under.
* \return The created format functor, as a raw pointer; the callers in
*         this code base wrap it in a std::unique_ptr.
*/
static SparsePageFormat* Create(const std::string& name);
/*!
* \brief Decide the format from cache prefix.
* \param cache_prefix The cache prefix to inspect.
* \return pair of (row format, column format) type names for the cache prefix.
*/
static std::pair<std::string, std::string> DecideFormat(const std::string& cache_prefix);
};
#if DMLC_ENABLE_STD_THREAD
/*!
 * \brief A threaded writer to write sparse batch page to sharded files.
 *
 * Each shard is served by its own worker thread and job queue. The worker
 * lambdas capture `this`, so an instance must stay at a fixed address for
 * its whole lifetime: copying is explicitly disabled (and, because a
 * destructor and copy operations are user-declared, no move operations are
 * generated either).
 */
class SparsePageWriter {
 public:
  /*!
   * \brief constructor
   * \param name_shards name of shard files.
   * \param format_shards format of each shard.
   * \param extra_buffer_capacity Extra buffer capacity before block.
   */
  explicit SparsePageWriter(
      const std::vector<std::string>& name_shards,
      const std::vector<std::string>& format_shards,
      size_t extra_buffer_capacity);
  /*!
   * \brief destructor, will close the files automatically.
   *  Sends a null page to every worker queue as the termination signal.
   */
  ~SparsePageWriter();
  /*! \brief not copyable: worker threads hold a pointer to this object. */
  SparsePageWriter(const SparsePageWriter&) = delete;
  /*! \brief not copy-assignable: worker threads hold a pointer to this object. */
  SparsePageWriter& operator=(const SparsePageWriter&) = delete;
  /*!
   * \brief Push a write job to the writer.
   *  This function won't block,
   *  writing is done by another thread inside writer.
   *  Jobs are handed to the worker queues in round-robin order.
   * \param page The page to be written
   */
  void PushWrite(std::shared_ptr<SparsePage>&& page);
  /*!
   * \brief Allocate a page to store results.
   *  This function can block when the writer is too slow and buffer pages
   *  have not yet been recycled.
   * \param out_page Used to store the allocated pages; must hold a null
   *  pointer on entry.
   */
  void Alloc(std::shared_ptr<SparsePage>* out_page);

 private:
  /*!
   * \brief remaining budget of pages that Alloc() may still newly allocate;
   *  Alloc() creates a fresh page while this is non-zero.
   */
  size_t num_free_buffer_;
  /*! \brief round-robin index of the worker queue that receives the next job */
  size_t clock_ptr_;
  /*! \brief writer threads */
  std::vector<std::unique_ptr<std::thread> > workers_;
  /*! \brief recycler queue */
  dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > qrecycle_;
  /*! \brief job queues, one per worker thread */
  std::vector<dmlc::ConcurrentBlockingQueue<std::shared_ptr<SparsePage> > > qworkers_;
};
#endif // DMLC_ENABLE_STD_THREAD
/*!
* \brief Registry entry for sparse page format.
*
* The body of each entry is a zero-argument factory returning a newly
* created SparsePageFormat*; entries are looked up by name in the dmlc
* registry (see SparsePageFormat::Create).
*/
struct SparsePageFormatReg
: public dmlc::FunctionRegEntryBase<SparsePageFormatReg,
std::function<SparsePageFormat* ()> > {
};
/*!
* \brief Macro to register sparse page format.
*
* \param Name The name the format is registered under (unquoted identifier).
*
* \code
* // example of registering a sparse page format
* XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(raw)
* .describe("Raw binary data format.")
* .set_body([]() {
* return new RawFormat();
* });
* \endcode
*/
#define XGBOOST_REGISTER_SPARSE_PAGE_FORMAT(Name) \
DMLC_REGISTRY_REGISTER(::xgboost::data::SparsePageFormatReg, SparsePageFormat, Name)
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SPARSE_PAGE_WRITER_H_