diff --git a/src/io/io.cpp b/src/io/io.cpp index e8b9ce337..e413b2799 100644 --- a/src/io/io.cpp +++ b/src/io/io.cpp @@ -37,8 +37,7 @@ DataMatrix* LoadDataMatrix(const char *fname, bool silent, bool savebuffer) { } void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { - if (!strcmp(fname + strlen(fname) - 5, ".page")) { - + if (!strcmp(fname + strlen(fname) - 5, ".page")) { DMatrixPage::Save(fname, dmat, silent); return; } @@ -46,7 +45,9 @@ void SaveDataMatrix(const DataMatrix &dmat, const char *fname, bool silent) { const DMatrixSimple *p_dmat = static_cast(&dmat); p_dmat->SaveBinary(fname, silent); } else { - utils::Error("not implemented"); + DMatrixSimple smat; + smat.CopyFrom(dmat); + smat.SaveBinary(fname, silent); } } diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 4d13a0bc5..2701fb3b3 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -32,7 +32,7 @@ struct RowBatchPage { const size_t dsize = row.length * sizeof(RowBatch::Entry); if (FreeBytes() < dsize+ sizeof(int)) return false; row_ptr(Size() + 1) = row_ptr(Size()) + row.length; - memcpy(data_ptr(Size()) , row.data, dsize); + memcpy(data_ptr(row_ptr(Size())) , row.data, dsize); ++ data_[0]; return true; } @@ -48,13 +48,18 @@ struct RowBatchPage { batch.data_ptr = this->data_ptr(0); batch.size = static_cast(this->Size()); std::vector &rptr = *p_rptr; - rptr.resize(this->Size()+1); + rptr.resize(this->Size() + 1); for (size_t i = 0; i < rptr.size(); ++i) { rptr[i] = static_cast(this->row_ptr(i)); } batch.ind_ptr = &rptr[0]; return batch; } + /*! \brief get i-th row from the batch */ + inline RowBatch::Inst operator[](size_t i) { + return RowBatch::Inst(data_ptr(0) + row_ptr(i), + static_cast(row_ptr(i+1) - row_ptr(i))); + } /*! * \brief clear the page, cleanup the content */ @@ -77,7 +82,7 @@ struct RowBatchPage { return data_[0]; } /*! \brief page size 64 MB */ - static const size_t kPageSize = 64 << 18; + static const size_t kPageSize = 64 << 8; private: /*! \return number of elements */ @@ -112,7 +117,6 @@ class ThreadRowPageIterator: public utils::IIterator { itr.BeforeFirst(); isend_ = false; base_rowid_ = 0; - utils::Assert(this->LoadNextPage(), "ThreadRowPageIterator"); } virtual bool Next(void) { if(!this->LoadNextPage()) return false; diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index 47be8a41a..8d7064bdd 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -44,8 +44,8 @@ class DMatrixSimple : public DataMatrix { } /*! \brief copy content data from source matrix */ inline void CopyFrom(const DataMatrix &src) { - this->info = src.info; this->Clear(); + this->info = src.info; // clone data content in thos matrix utils::IIterator *iter = src.fmat()->RowIterator(); iter->BeforeFirst(); diff --git a/src/io/simple_fmatrix-inl.hpp b/src/io/simple_fmatrix-inl.hpp index 86763a105..f099eb1a9 100644 --- a/src/io/simple_fmatrix-inl.hpp +++ b/src/io/simple_fmatrix-inl.hpp @@ -150,7 +150,7 @@ class FMatrixS : public IFMatrix{ iter_->BeforeFirst(); while (iter_->Next()) { const RowBatch &batch = iter_->Value(); - for (size_t i = 0; i < batch.size; ++i) { + for (size_t i = 0; i < batch.size; ++i) { if (pkeep == 1.0f || random::SampleBinary(pkeep)) { buffered_rowset_.push_back(static_cast(batch.base_rowid+i)); RowBatch::Inst inst = batch[i];