fast loader

This commit is contained in:
tqchen 2015-04-17 23:02:30 -07:00
parent 6d9cb3a2fa
commit 5dfab4ba70
5 changed files with 92 additions and 42 deletions

View File

@ -13,6 +13,7 @@
#include "../utils/omp.h" #include "../utils/omp.h"
#include "../utils/utils.h" #include "../utils/utils.h"
#include "../sync/sync.h" #include "../sync/sync.h"
#include "../utils/thread_buffer.h"
#include "./sparse_batch_page.h" #include "./sparse_batch_page.h"
namespace xgboost { namespace xgboost {
@ -29,13 +30,19 @@ struct LibSVMPage : public SparsePage {
/*! /*!
* \brief libsvm parser that parses the input lines * \brief libsvm parser that parses the input lines
* and returns rows in input data * and returns rows in input data
* factory that is used by the ThreadBuffer template
*/ */
class LibSVMParser : public utils::IIterator<LibSVMPage> { class LibSVMPageFactory {
public: public:
explicit LibSVMParser(dmlc::InputSplit *source, explicit LibSVMPageFactory()
int nthread) : bytes_read_(0), at_head_(true) {
: bytes_read_(0), at_head_(true), }
data_ptr_(0), data_end_(0), source_(source) { inline bool Init(void) {
return true;
}
inline void Setup(dmlc::InputSplit *source,
int nthread) {
source_ = source;
int maxthread; int maxthread;
#pragma omp parallel #pragma omp parallel
{ {
@ -44,34 +51,28 @@ class LibSVMParser : public utils::IIterator<LibSVMPage> {
maxthread = std::max(maxthread / 2, 1); maxthread = std::max(maxthread / 2, 1);
nthread_ = std::min(maxthread, nthread); nthread_ = std::min(maxthread, nthread);
} }
virtual ~LibSVMParser() { inline void SetParam(const char *name, const char *val) {}
inline bool LoadNext(std::vector<LibSVMPage> *data) {
return FillData(data);
}
inline void FreeSpace(std::vector<LibSVMPage> *a) {
delete a;
}
inline std::vector<LibSVMPage> *Create(void) {
return new std::vector<LibSVMPage>();
}
inline void BeforeFirst(void) {
utils::Assert(at_head_, "cannot call beforefirst");
}
inline void Destroy(void) {
delete source_; delete source_;
} }
virtual void BeforeFirst(void) {
utils::Assert(at_head_, "cannot call BeforeFirst");
}
virtual const LibSVMPage &Value(void) const {
return data_[data_ptr_ - 1];
}
virtual bool Next(void) {
while (true) {
while (data_ptr_ < data_end_) {
data_ptr_ += 1;
if (data_[data_ptr_ - 1].Size() != 0) {
return true;
}
}
if (!FillData()) break;
data_ptr_ = 0; data_end_ = data_.size();
}
return false;
}
inline size_t bytes_read(void) const { inline size_t bytes_read(void) const {
return bytes_read_; return bytes_read_;
} }
protected: protected:
inline bool FillData() { inline bool FillData(std::vector<LibSVMPage> *data) {
dmlc::InputSplit::Blob chunk; dmlc::InputSplit::Blob chunk;
if (!source_->NextChunk(&chunk)) return false; if (!source_->NextChunk(&chunk)) return false;
int nthread; int nthread;
@ -80,7 +81,7 @@ class LibSVMParser : public utils::IIterator<LibSVMPage> {
nthread = omp_get_num_threads(); nthread = omp_get_num_threads();
} }
// reserve space for data // reserve space for data
data_.resize(nthread); data->resize(nthread);
bytes_read_ += chunk.size; bytes_read_ += chunk.size;
utils::Assert(chunk.size != 0, "LibSVMParser.FileData"); utils::Assert(chunk.size != 0, "LibSVMParser.FileData");
char *head = reinterpret_cast<char*>(chunk.dptr); char *head = reinterpret_cast<char*>(chunk.dptr);
@ -98,9 +99,8 @@ class LibSVMParser : public utils::IIterator<LibSVMPage> {
} else { } else {
pend = BackFindEndLine(head + send, head); pend = BackFindEndLine(head + send, head);
} }
ParseBlock(pbegin, pend, &data_[tid]); ParseBlock(pbegin, pend, &(*data)[tid]);
} }
data_ptr_ = 0;
return true; return true;
} }
/*! /*!
@ -156,13 +156,54 @@ class LibSVMParser : public utils::IIterator<LibSVMPage> {
size_t bytes_read_; size_t bytes_read_;
// at beginning, at end of stream // at beginning, at end of stream
bool at_head_; bool at_head_;
// pointer to begin and end of data
size_t data_ptr_, data_end_;
// source split that provides the data // source split that provides the data
dmlc::InputSplit *source_; dmlc::InputSplit *source_;
// internal data
std::vector<LibSVMPage> data_;
}; };
/*!
 * \brief iterator over the non-empty LibSVMPage objects of an input split;
 *  batches of pages are obtained through a utils::ThreadBuffer whose
 *  elements are produced by LibSVMPageFactory
 */
class LibSVMParser : public utils::IIterator<LibSVMPage> {
 public:
  /*!
   * \brief configure the buffer and its factory, then initialize the buffer
   * \param source input split passed to the factory through Setup;
   *  NOTE(review): LibSVMPageFactory::Destroy deletes it, presumably
   *  invoked by ThreadBuffer on teardown - confirm
   * \param nthread thread count forwarded to the factory's Setup
   */
  explicit LibSVMParser(dmlc::InputSplit *source,
                        int nthread)
      : at_end_(false), data_ptr_(0), data_(NULL) {
    itr.SetParam("buffer_size", "2");
    itr.get_factory().Setup(source, nthread);
    itr.Init();
  }
  /*!
   * \brief rewind the underlying buffer and forget the local cursor, so
   *  that the following Next() fetches a fresh batch instead of reusing
   *  a batch pointer (or an end-of-stream flag) from before the rewind
   */
  virtual void BeforeFirst(void) {
    itr.BeforeFirst();
    at_end_ = false;
    data_ptr_ = 0;
    data_ = NULL;
  }
  /*!
   * \brief advance to the next page whose Size() is non-zero
   * \return true if such a page exists and can be read with Value(),
   *  false once the buffer reports that no more batches are available
   */
  virtual bool Next(void) {
    if (at_end_) return false;
    while (true) {
      // consume what is left of the current batch, skipping empty pages
      while (data_ != NULL && data_ptr_ < data_->size()) {
        data_ptr_ += 1;
        if ((*data_)[data_ptr_ - 1].Size() != 0) {
          return true;
        }
      }
      // current batch is missing or exhausted: ask the buffer for another
      if (!itr.Next(data_)) {
        at_end_ = true;
        return false;
      }
      data_ptr_ = 0;
    }
  }
  /*!
   * \brief page selected by the most recent successful Next();
   *  must not be called before Next() has returned true
   */
  virtual const LibSVMPage &Value(void) const {
    // data_ptr_ is one past the current page, see Next()
    return (*data_)[data_ptr_ - 1];
  }
  /*! \return the byte counter reported by the factory */
  inline size_t bytes_read(void) const {
    return itr.get_factory().bytes_read();
  }

 private:
  // set once itr.Next has reported that no more batches are available
  bool at_end_;
  // index one past the current page inside *data_
  size_t data_ptr_;
  // batch most recently handed out by itr; NULL before the first fetch
  std::vector<LibSVMPage> *data_;
  // buffer that yields batches built by LibSVMPageFactory
  utils::ThreadBuffer<std::vector<LibSVMPage>*, LibSVMPageFactory> itr;
};
} // namespace io } // namespace io
} // namespace xgboost } // namespace xgboost
#endif // XGBOOST_IO_LIBSVM_PARSER_H_ #endif // XGBOOST_IO_LIBSVM_PARSER_H_

View File

@ -149,7 +149,7 @@ class DMatrixPageBase : public DataMatrix {
size_t bytes_write = 0; size_t bytes_write = 0;
double tstart = rabit::utils::GetTime(); double tstart = rabit::utils::GetTime();
LibSVMParser parser( LibSVMParser parser(
dmlc::InputSplit::Create(uri, rank, npart, "text"), 4); dmlc::InputSplit::Create(uri, rank, npart, "text"), 16);
info.Clear(); info.Clear();
while (parser.Next()) { while (parser.Next()) {
const LibSVMPage &batch = parser.Value(); const LibSVMPage &batch = parser.Value();
@ -159,7 +159,7 @@ class DMatrixPageBase : public DataMatrix {
std::memcpy(BeginPtr(info.labels) + nlabel, std::memcpy(BeginPtr(info.labels) + nlabel,
BeginPtr(batch.label), BeginPtr(batch.label),
batch.label.size() * sizeof(float)); batch.label.size() * sizeof(float));
} }
page.Push(batch); page.Push(batch);
for (size_t i = 0; i < batch.data.size(); ++i) { for (size_t i = 0; i < batch.data.size(); ++i) {
info.info.num_col = std::max(info.info.num_col, info.info.num_col = std::max(info.info.num_col,
@ -171,7 +171,7 @@ class DMatrixPageBase : public DataMatrix {
page.Clear(); page.Clear();
double tdiff = rabit::utils::GetTime() - tstart; double tdiff = rabit::utils::GetTime() - tstart;
if (!silent) { if (!silent) {
utils::Printf("Writting to %s in %g MB/s, %g MB written\n", utils::Printf("Writting to %s in %g MB/s, %lu MB written\n",
cache_file, (bytes_write >> 20UL) / tdiff, cache_file, (bytes_write >> 20UL) / tdiff,
(bytes_write >> 20UL)); (bytes_write >> 20UL));
} }

View File

@ -173,6 +173,8 @@ class FMatrixPage : public IFMatrix {
std::fill(col_size_.begin(), col_size_.end(), 0); std::fill(col_size_.begin(), col_size_.end(), 0);
utils::FileStream fo; utils::FileStream fo;
fo = utils::FileStream(utils::FopenCheck(col_data_name_.c_str(), "wb")); fo = utils::FileStream(utils::FopenCheck(col_data_name_.c_str(), "wb"));
size_t bytes_write = 0;
double tstart = rabit::utils::GetTime();
// start working // start working
iter_->BeforeFirst(); iter_->BeforeFirst();
while (iter_->Next()) { while (iter_->Next()) {
@ -183,10 +185,17 @@ class FMatrixPage : public IFMatrix {
buffered_rowset_.push_back(ridx); buffered_rowset_.push_back(ridx);
prow.Push(batch[i]); prow.Push(batch[i]);
if (prow.MemCostBytes() >= kPageSize) { if (prow.MemCostBytes() >= kPageSize) {
bytes_write += prow.MemCostBytes();
this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop, this->PushColPage(prow, BeginPtr(buffered_rowset_) + btop,
enabled, &pcol, &fo); enabled, &pcol, &fo);
btop += prow.Size(); btop += prow.Size();
prow.Clear(); prow.Clear();
double tdiff = rabit::utils::GetTime() - tstart;
utils::Printf("Writting to %s in %g MB/s, %lu MB written\n",
col_data_name_.c_str(),
(bytes_write >> 20UL) / tdiff,
(bytes_write >> 20UL));
} }
} }
} }

View File

@ -93,7 +93,7 @@ class DMatrixSimple : public DataMatrix {
npart = rabit::GetWorldSize(); npart = rabit::GetWorldSize();
} }
LibSVMParser parser( LibSVMParser parser(
dmlc::InputSplit::Create(uri, rank, npart, "text"), 4); dmlc::InputSplit::Create(uri, rank, npart, "text"), 16);
this->Clear(); this->Clear();
while (parser.Next()) { while (parser.Next()) {
const LibSVMPage &batch = parser.Value(); const LibSVMPage &batch = parser.Value();

View File

@ -142,7 +142,7 @@ class SparsePage {
* \param batch the row batch * \param batch the row batch
*/ */
inline void Push(const RowBatch &batch) { inline void Push(const RowBatch &batch) {
data.resize(offset.back() + batch.size); data.resize(offset.back() + batch.ind_ptr[batch.size]);
std::memcpy(BeginPtr(data) + offset.back(), std::memcpy(BeginPtr(data) + offset.back(),
batch.data_ptr + batch.ind_ptr[0], batch.data_ptr + batch.ind_ptr[0],
sizeof(SparseBatch::Entry) * batch.ind_ptr[batch.size]); sizeof(SparseBatch::Entry) * batch.ind_ptr[batch.size]);
@ -158,13 +158,13 @@ class SparsePage {
* \param batch the row page * \param batch the row page
*/ */
inline void Push(const SparsePage &batch) { inline void Push(const SparsePage &batch) {
data.resize(offset.back() + batch.Size()); size_t top = offset.back();
std::memcpy(BeginPtr(data) + offset.back(), data.resize(top + batch.data.size());
std::memcpy(BeginPtr(data) + top,
BeginPtr(batch.data), BeginPtr(batch.data),
sizeof(SparseBatch::Entry) * batch.data.size()); sizeof(SparseBatch::Entry) * batch.data.size());
size_t top = offset.back();
size_t begin = offset.size(); size_t begin = offset.size();
offset.resize(offset.size() + batch.Size()); offset.resize(begin + batch.Size());
for (size_t i = 0; i < batch.Size(); ++i) { for (size_t i = 0; i < batch.Size(); ++i) {
offset[i + begin] = top + batch.offset[i + 1]; offset[i + begin] = top + batch.offset[i + 1];
} }