still buggy

This commit is contained in:
tqchen 2014-09-02 17:18:17 -07:00
parent a89e3063e6
commit 226d26d40c
2 changed files with 19 additions and 11 deletions

View File

@ -132,6 +132,7 @@ class CSCMatrixManager {
"invalid column buffer format"); "invalid column buffer format");
p_page->col_data.push_back(ColBatch::Inst(p_data, len)); p_page->col_data.push_back(ColBatch::Inst(p_data, len));
p_page->col_index.push_back(cidx); p_page->col_index.push_back(cidx);
return true;
} }
// the following are in memory auxiliary data structure // the following are in memory auxiliary data structure
/*! \brief top of reader position */ /*! \brief top of reader position */
@ -159,6 +160,7 @@ class ThreadColPageIterator : public utils::IIterator<ColBatch> {
float page_ratio, bool silent) { float page_ratio, bool silent) {
itr_.SetParam("buffer_size", "2"); itr_.SetParam("buffer_size", "2");
itr_.get_factory().Setup(fi, page_ratio); itr_.get_factory().Setup(fi, page_ratio);
itr_.Init();
if (!silent) { if (!silent) {
utils::Printf("ThreadColPageIterator: finish initialzing, %u columns\n", utils::Printf("ThreadColPageIterator: finish initialzing, %u columns\n",
static_cast<unsigned>(col_ptr().size() - 1)); static_cast<unsigned>(col_ptr().size() - 1));
@ -239,8 +241,11 @@ class FMatrixPage : public IFMatrix {
} }
virtual void InitColAccess(float pkeep = 1.0f) { virtual void InitColAccess(float pkeep = 1.0f) {
if (this->HaveColAccess()) return; if (this->HaveColAccess()) return;
this->InitColData(pkeep, fname_cbuffer_.c_str(), if (!this->LoadColData()) {
64 << 20, 5); this->InitColData(pkeep, fname_cbuffer_.c_str(),
64 << 20, 5);
utils::Check(this->LoadColData(), "fail to read in column data");
}
} }
/*! /*!
* \brief get the row iterator associated with FMatrix * \brief get the row iterator associated with FMatrix

View File

@ -6,6 +6,7 @@
* \author Tianqi Chen * \author Tianqi Chen
*/ */
#include <vector> #include <vector>
#include <utility>
#include <algorithm> #include <algorithm>
#include "./io.h" #include "./io.h"
#include "./utils.h" #include "./utils.h"
@ -166,8 +167,8 @@ struct SparseCSRFileBuilder {
buffer_rptr.resize(rptr.size()); buffer_rptr.resize(rptr.size());
buffer_temp.reserve(buffer_size); buffer_temp.reserve(buffer_size);
buffer_data.resize(buffer_size); buffer_data.resize(buffer_size);
saved_offset.clear(); saved_offset = rptr;
saved_offset.resize(rptr.size() - 1, 0); saved_offset.resize(rptr.size() - 1);
this->ClearBuffer(); this->ClearBuffer();
} }
/*! \brief step 4: push element into buffer */ /*! \brief step 4: push element into buffer */
@ -176,6 +177,7 @@ struct SparseCSRFileBuilder {
this->WriteBuffer(); this->WriteBuffer();
this->ClearBuffer(); this->ClearBuffer();
} }
buffer_rptr[row_id + 1] += 1;
buffer_temp.push_back(std::make_pair(row_id, col_id)); buffer_temp.push_back(std::make_pair(row_id, col_id));
} }
/*! \brief finalize the construction */ /*! \brief finalize the construction */
@ -190,14 +192,14 @@ struct SparseCSRFileBuilder {
inline void SortRows(Comparator comp, size_t step) { inline void SortRows(Comparator comp, size_t step) {
for (size_t i = 0; i < rptr.size() - 1; i += step) { for (size_t i = 0; i < rptr.size() - 1; i += step) {
bst_omp_uint begin = static_cast<bst_omp_uint>(i); bst_omp_uint begin = static_cast<bst_omp_uint>(i);
bst_omp_uint end = static_cast<bst_omp_uint>(std::min(rptr.size(), i + step)); bst_omp_uint end = static_cast<bst_omp_uint>(std::min(rptr.size() - 1, i + step));
if (rptr[end] != rptr[begin]) { if (rptr[end] != rptr[begin]) {
fo->Seek(begin_data + rptr[begin] * sizeof(IndexType)); fo->Seek(begin_data + rptr[begin] * sizeof(IndexType));
buffer_data.resize(rptr[end] - rptr[begin]); buffer_data.resize(rptr[end] - rptr[begin]);
fo->Read(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); fo->Read(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType));
// do parallel sorting // do parallel sorting
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (bst_omp_uint j = begin; j < end; ++j){ for (bst_omp_uint j = begin; j < end; ++j) {
std::sort(&buffer_data[0] + rptr[j] - rptr[begin], std::sort(&buffer_data[0] + rptr[j] - rptr[begin],
&buffer_data[0] + rptr[j+1] - rptr[begin], &buffer_data[0] + rptr[j+1] - rptr[begin],
comp); comp);
@ -206,6 +208,7 @@ struct SparseCSRFileBuilder {
fo->Write(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType)); fo->Write(BeginPtr(buffer_data), (rptr[end] - rptr[begin]) * sizeof(IndexType));
} }
} }
printf("CSV::begin_dat=%lu\n", begin_data);
} }
protected: protected:
inline void WriteBuffer(void) { inline void WriteBuffer(void) {
@ -220,11 +223,11 @@ struct SparseCSRFileBuilder {
buffer_data[rp++] = buffer_temp[i].second; buffer_data[rp++] = buffer_temp[i].second;
} }
// write out // write out
for (size_t i = 0; i < buffer_rptr.size(); ++i) { for (size_t i = 0; i < buffer_rptr.size() - 1; ++i) {
size_t nelem = buffer_rptr[i+1] - buffer_rptr[i]; size_t nelem = buffer_rptr[i+1] - buffer_rptr[i];
if (nelem != 0) { if (nelem != 0) {
utils::Assert(saved_offset[i] < rptr[i+1], "data exceed bound"); utils::Assert(saved_offset[i] + nelem <= rptr[i+1], "data exceed bound");
fo->Seek((rptr[i] + saved_offset[i]) * sizeof(IndexType) + begin_data); fo->Seek(saved_offset[i] * sizeof(IndexType) + begin_data);
fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType)); fo->Write(&buffer_data[0] + buffer_rptr[i], nelem * sizeof(IndexType));
saved_offset[i] += nelem; saved_offset[i] += nelem;
} }