Refactor DMatrix to return batches of different page types (#4686)

* Use explicit template parameter for specifying page type.
This commit is contained in:
Rong Ou
2019-08-03 12:10:34 -07:00
committed by Jiaming Yuan
parent e930a8e54f
commit 6edddd7966
41 changed files with 477 additions and 470 deletions

View File

@@ -15,6 +15,7 @@
#include <numeric>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"
@@ -270,20 +271,34 @@ class SparsePage {
size_t Size() { return offset.Size() - 1; }
};
class CSCPage: public SparsePage {
public:
CSCPage() : SparsePage() {}
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
class SortedCSCPage : public SparsePage {
public:
SortedCSCPage() : SparsePage() {}
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
};
template<typename T>
class BatchIteratorImpl {
public:
virtual ~BatchIteratorImpl() {}
virtual BatchIteratorImpl* Clone() = 0;
virtual SparsePage& operator*() = 0;
virtual const SparsePage& operator*() const = 0;
virtual T& operator*() = 0;
virtual const T& operator*() const = 0;
virtual void operator++() = 0;
virtual bool AtEnd() const = 0;
};
template<typename T>
class BatchIterator {
public:
using iterator_category = std::forward_iterator_tag;
explicit BatchIterator(BatchIteratorImpl* impl) { impl_.reset(impl); }
explicit BatchIterator(BatchIteratorImpl<T>* impl) { impl_.reset(impl); }
BatchIterator(const BatchIterator& other) {
if (other.impl_) {
@@ -298,12 +313,12 @@ class BatchIterator {
++(*impl_);
}
SparsePage& operator*() {
T& operator*() {
CHECK(impl_ != nullptr);
return *(*impl_);
}
const SparsePage& operator*() const {
const T& operator*() const {
CHECK(impl_ != nullptr);
return *(*impl_);
}
@@ -319,17 +334,18 @@ class BatchIterator {
}
private:
std::unique_ptr<BatchIteratorImpl> impl_;
std::unique_ptr<BatchIteratorImpl<T>> impl_;
};
template<typename T>
class BatchSet {
public:
explicit BatchSet(BatchIterator begin_iter) : begin_iter_(begin_iter) {}
BatchIterator begin() { return begin_iter_; }
BatchIterator end() { return BatchIterator(nullptr); }
explicit BatchSet(BatchIterator<T> begin_iter) : begin_iter_(begin_iter) {}
BatchIterator<T> begin() { return begin_iter_; }
BatchIterator<T> end() { return BatchIterator<T>(nullptr); }
private:
BatchIterator begin_iter_;
BatchIterator<T> begin_iter_;
};
/*!
@@ -339,7 +355,8 @@ class BatchSet {
*
* On distributed setting, usually an customized dmlc::Parser is needed instead.
*/
class DataSource : public dmlc::DataIter<SparsePage> {
template<typename T>
class DataSource : public dmlc::DataIter<T> {
public:
/*!
* \brief Meta information about the dataset
@@ -367,11 +384,10 @@ class DMatrix {
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
/**
* \brief Gets row batches. Use range based for loop over BatchSet to access individual batches.
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
*/
virtual BatchSet GetRowBatches() = 0;
virtual BatchSet GetSortedColumnBatches() = 0;
virtual BatchSet GetColumnBatches() = 0;
template<typename T>
BatchSet<T> GetBatches();
// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
@@ -410,7 +426,7 @@ class DMatrix {
* This can be nullptr for common cases, and in-memory mode will be used.
* \return a Created DMatrix.
*/
static DMatrix* Create(std::unique_ptr<DataSource>&& source,
static DMatrix* Create(std::unique_ptr<DataSource<SparsePage>>&& source,
const std::string& cache_prefix = "");
/*!
* \brief Create a DMatrix by loading data from parser.
@@ -431,7 +447,27 @@ class DMatrix {
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
};
template<>
inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}
template<>
inline BatchSet<CSCPage> DMatrix::GetBatches() {
return GetColumnBatches();
}
template<>
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
return GetSortedColumnBatches();
}
} // namespace xgboost
namespace dmlc {