current progress

This commit is contained in:
tqchen
2015-04-15 22:28:43 -07:00
parent e8f6f3b541
commit a514340c96
5 changed files with 381 additions and 165 deletions

View File

@@ -1,7 +1,7 @@
#ifndef XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_
#define XGBOOST_IO_PAGE_ROW_ITER_INL_HPP_
#ifndef XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
#define XGBOOST_IO_PAGE_DMATRIX_INL_HPP_
/*!
* \file page_row_iter-inl.hpp
* \file page_dmatrix-inl.hpp
* row iterator based on sparse page
* \author Tianqi Chen
*/
@@ -10,97 +10,11 @@
#include "../utils/iterator.h"
#include "../utils/thread_buffer.h"
#include "./simple_fmatrix-inl.hpp"
#include "./sparse_batch_page.h"
#include "./page_fmatrix-inl.hpp"
namespace xgboost {
namespace io {
/*! \brief page structure that can be used to store a rowbatch */
struct RowBatchPage {
public:
explicit RowBatchPage(size_t page_size) : kPageSize(page_size) {
data_ = new int[kPageSize];
utils::Assert(data_ != NULL, "fail to allocate row batch page");
this->Clear();
}
~RowBatchPage(void) {
if (data_ != NULL) delete [] data_;
}
/*!
* \brief Push one row into page
* \param row an instance row
* \return false or true to push into
*/
inline bool PushRow(const RowBatch::Inst &row) {
const size_t dsize = row.length * sizeof(RowBatch::Entry);
if (FreeBytes() < dsize+ sizeof(int)) return false;
row_ptr(Size() + 1) = row_ptr(Size()) + row.length;
memcpy(data_ptr(row_ptr(Size())) , row.data, dsize);
++data_[0];
return true;
}
/*!
* \brief get a row batch representation from the page
* \param p_rptr a temporal space that can be used to provide
* ind_ptr storage for RowBatch
* \return a new RowBatch object
*/
inline RowBatch GetRowBatch(std::vector<size_t> *p_rptr, size_t base_rowid) {
RowBatch batch;
batch.base_rowid = base_rowid;
batch.data_ptr = this->data_ptr(0);
batch.size = static_cast<size_t>(this->Size());
std::vector<size_t> &rptr = *p_rptr;
rptr.resize(this->Size() + 1);
for (size_t i = 0; i < rptr.size(); ++i) {
rptr[i] = static_cast<size_t>(this->row_ptr(static_cast<int>(i)));
}
batch.ind_ptr = &rptr[0];
return batch;
}
/*! \brief get i-th row from the batch */
inline RowBatch::Inst operator[](int i) {
return RowBatch::Inst(data_ptr(0) + row_ptr(i),
static_cast<bst_uint>(row_ptr(i+1) - row_ptr(i)));
}
/*!
* \brief clear the page, cleanup the content
*/
inline void Clear(void) {
memset(&data_[0], 0, sizeof(int) * kPageSize);
}
/*!
* \brief load one page form instream
* \return true if loading is successful
*/
inline bool Load(utils::IStream &fi) {
return fi.Read(&data_[0], sizeof(int) * kPageSize) != 0;
}
/*! \brief save one page into outstream */
inline void Save(utils::IStream &fo) {
fo.Write(&data_[0], sizeof(int) * kPageSize);
}
/*! \return number of elements */
inline int Size(void) const {
return data_[0];
}
protected:
/*! \return number of elements */
inline size_t FreeBytes(void) {
return (kPageSize - (Size() + 2)) * sizeof(int) -
row_ptr(Size()) * sizeof(RowBatch::Entry);
}
/*! \brief equivalent row pointer at i */
inline int& row_ptr(int i) {
return data_[kPageSize - i - 1];
}
inline RowBatch::Entry* data_ptr(int i) {
return (RowBatch::Entry*)(&data_[1]) + i;
}
// content of data
int *data_;
// page size
const size_t kPageSize;
};
/*! \brief thread buffer iterator */
class ThreadRowPageIterator: public utils::IIterator<RowBatch> {
public:
@@ -118,7 +32,10 @@ class ThreadRowPageIterator: public utils::IIterator<RowBatch> {
}
virtual bool Next(void) {
if (!itr.Next(page_)) return false;
out_ = page_->GetRowBatch(&tmp_ptr_, base_rowid_);
out_.base_rowid = base_rowid_;
out_.ind_ptr = BeginPtr(page_->offset);
out_.data_ptr = BeginPtr(page_->data);
out_.size = page_->offset.size() - 1;
base_rowid_ += out_.size;
return true;
}
@@ -127,76 +44,18 @@ class ThreadRowPageIterator: public utils::IIterator<RowBatch> {
}
/*! \brief load and initialize the iterator with fi */
inline void Load(const utils::FileStream &fi) {
itr.get_factory().SetFile(fi);
itr.get_factory().SetFile(fi, 0);
itr.Init();
this->BeforeFirst();
}
/*!
* \brief save a row iterator to output stream, in row iterator format
*/
inline static void Save(utils::IIterator<RowBatch> *iter,
utils::IStream &fo) {
RowBatchPage page(kPageSize);
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
if (!page.PushRow(batch[i])) {
page.Save(fo);
page.Clear();
utils::Check(page.PushRow(batch[i]), "row is too big");
}
}
}
if (page.Size() != 0) page.Save(fo);
}
/*! \brief page size 64 MB */
static const size_t kPageSize = 64 << 18;
private:
// base row id
size_t base_rowid_;
// temporal ptr
std::vector<size_t> tmp_ptr_;
// output data
RowBatch out_;
// page pointer type
typedef RowBatchPage* PagePtr;
// loader factory for page
struct Factory {
public:
size_t file_begin_;
utils::FileStream fi;
Factory(void) {}
inline void SetFile(const utils::FileStream &fi) {
this->fi = fi;
file_begin_ = this->fi.Tell();
}
inline bool Init(void) {
return true;
}
inline void SetParam(const char *name, const char *val) {}
inline bool LoadNext(PagePtr &val) {
return val->Load(fi);
}
inline PagePtr Create(void) {
PagePtr a = new RowBatchPage(kPageSize);
return a;
}
inline void FreeSpace(PagePtr &a) {
delete a;
}
inline void Destroy(void) {
fi.Close();
}
inline void BeforeFirst(void) {
fi.Seek(file_begin_);
}
};
protected:
PagePtr page_;
utils::ThreadBuffer<PagePtr, Factory> itr;
SparsePage *page_;
utils::ThreadBuffer<SparsePage*, SparsePageFactory> itr;
};
/*! \brief data matrix using page */
@@ -247,8 +106,20 @@ class DMatrixPageBase : public DataMatrix {
mat.info.SaveBinary(fs);
fs.Close();
fname += ".row.blob";
utils::IIterator<RowBatch> *iter = mat.fmat()->RowIterator();
utils::FileStream fbin(utils::FopenCheck(fname.c_str(), "wb"));
ThreadRowPageIterator::Save(mat.fmat()->RowIterator(), fbin);
SparsePage page;
iter->BeforeFirst();
while (iter->Next()) {
const RowBatch &batch = iter->Value();
for (size_t i = 0; i < batch.size; ++i) {
page.Push(batch[i]);
if (page.MemCostBytes() >= kPageSize) {
page.Save(&fbin); page.Clear();
}
}
}
if (page.data.size() != 0) page.Save(&fbin);
fbin.Close();
if (!silent) {
utils::Printf("DMatrixPage: %lux%lu is saved to %s\n",
@@ -268,7 +139,7 @@ class DMatrixPageBase : public DataMatrix {
}
std::string fname_row = std::string(cache_file) + ".row.blob";
utils::FileStream fo(utils::FopenCheck(fname_row.c_str(), "wb"));
RowBatchPage page(ThreadRowPageIterator::kPageSize);
SparsePage page;
dmlc::InputSplit *in =
dmlc::InputSplit::Create(uri, rank, npart);
std::string line;
@@ -286,10 +157,9 @@ class DMatrixPageBase : public DataMatrix {
feats.push_back(e);
}
RowBatch::Inst row(BeginPtr(feats), feats.size());
if (!page.PushRow(row)) {
page.Save(fo);
page.Clear();
utils::Check(page.PushRow(row), "row is too big");
page.Push(row);
if (page.MemCostBytes() >= kPageSize) {
page.Save(&fo); page.Clear();
}
for (size_t i = 0; i < feats.size(); ++i) {
info.info.num_col = std::max(info.info.num_col,
@@ -298,8 +168,8 @@ class DMatrixPageBase : public DataMatrix {
this->info.labels.push_back(label);
info.info.num_row += 1;
}
if (page.Size() != 0) {
page.Save(fo);
if (page.data.size() != 0) {
page.Save(&fo);
}
delete in;
fo.Close();
@@ -319,7 +189,8 @@ class DMatrixPageBase : public DataMatrix {
}
/*! \brief magic number used to identify DMatrix */
static const int kMagic = TKMagic;
/*! \brief page size 64 MB */
static const size_t kPageSize = 64 << 18;
protected:
/*! \brief row iterator */
ThreadRowPageIterator *iter_;