[MEM] Add rowset struct to save memory with billion level rows

This commit is contained in:
tqchen
2016-01-19 16:40:07 -08:00
parent 2230f1273f
commit 88447ca32e
9 changed files with 101 additions and 30 deletions

View File

@@ -183,6 +183,41 @@ class DataSource : public dmlc::DataIter<RowBatch> {
MetaInfo info;
};
/*!
* \brief A vector-like structure to represent set of rows.
* But saves the memory when all rows are in the set (common case in xgb)
*/
struct RowSet {
public:
/*! \return i-th row index */
inline bst_uint operator[](size_t i) const;
/*! \return the size of the set. */
inline size_t size() const;
/*! \brief push the index back to the set */
inline void push_back(bst_uint i);
/*! \brief clear the set */
inline void clear();
/*!
* \brief save rowset to file.
* \param fo The file to be saved.
*/
inline void Save(dmlc::Stream* fo) const;
/*!
* \brief Load rowset from file.
* \param fi The file to be loaded.
* \return if read is successful.
*/
inline bool Load(dmlc::Stream* fi);
/*! \brief constructor */
RowSet() : size_(0) {}
private:
/*! \brief The internal data structure of size */
uint64_t size_;
/*! \brief The internal data structure of row set if not all*/
std::vector<bst_uint> rows_;
};
/*!
* \brief Internal data structured used by XGBoost during training.
* There are two ways to create a customized DMatrix that reads in user defined-format.
@@ -235,7 +270,7 @@ class DMatrix {
/*! \brief get column density */
virtual float GetColDensity(size_t cidx) const = 0;
/*! \return reference of buffered rowset, in column access */
virtual const std::vector<bst_uint>& buffered_rowset() const = 0;
virtual const RowSet& buffered_rowset() const = 0;
/*! \brief virtual destructor */
virtual ~DMatrix() {}
/*!
@@ -290,9 +325,48 @@ class DMatrix {
LearnerImpl* cache_learner_ptr_;
};
// implementation of inline functions
inline bst_uint RowSet::operator[](size_t i) const {
return rows_.size() == 0 ? i : rows_[i];
}
inline size_t RowSet::size() const {
return size_;
}
inline void RowSet::clear() {
rows_.clear(); size_ = 0;
}
inline void RowSet::push_back(bst_uint i) {
if (rows_.size() == 0) {
if (i == size_) {
++size_; return;
} else {
rows_.resize(size_);
for (size_t i = 0; i < size_; ++i) {
rows_[i] = static_cast<bst_uint>(i);
}
}
}
rows_.push_back(i);
++size_;
}
inline void RowSet::Save(dmlc::Stream* fo) const {
fo->Write(rows_);
fo->Write(&size_, sizeof(size_));
}
inline bool RowSet::Load(dmlc::Stream* fi) {
if (!fi->Read(&rows_)) return false;
if (rows_.size() != 0) return true;
return fi->Read(&size_, sizeof(size_)) == sizeof(size_);
}
} // namespace xgboost
namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, xgboost::SparseBatch::Entry, true);
DMLC_DECLARE_TRAITS(has_saveload, xgboost::RowSet, true);
}
#endif // XGBOOST_DATA_H_