Refactor DMatrix to return batches of different page types (#4686)
* Use explicit template parameter for specifying page type.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "./base.h"
|
||||
#include "../../src/common/span.h"
|
||||
@@ -270,20 +271,34 @@ class SparsePage {
|
||||
size_t Size() { return offset.Size() - 1; }
|
||||
};
|
||||
|
||||
class CSCPage: public SparsePage {
|
||||
public:
|
||||
CSCPage() : SparsePage() {}
|
||||
explicit CSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
class SortedCSCPage : public SparsePage {
|
||||
public:
|
||||
SortedCSCPage() : SparsePage() {}
|
||||
explicit SortedCSCPage(SparsePage page) : SparsePage(std::move(page)) {}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class BatchIteratorImpl {
|
||||
public:
|
||||
virtual ~BatchIteratorImpl() {}
|
||||
virtual BatchIteratorImpl* Clone() = 0;
|
||||
virtual SparsePage& operator*() = 0;
|
||||
virtual const SparsePage& operator*() const = 0;
|
||||
virtual T& operator*() = 0;
|
||||
virtual const T& operator*() const = 0;
|
||||
virtual void operator++() = 0;
|
||||
virtual bool AtEnd() const = 0;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class BatchIterator {
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
explicit BatchIterator(BatchIteratorImpl* impl) { impl_.reset(impl); }
|
||||
explicit BatchIterator(BatchIteratorImpl<T>* impl) { impl_.reset(impl); }
|
||||
|
||||
BatchIterator(const BatchIterator& other) {
|
||||
if (other.impl_) {
|
||||
@@ -298,12 +313,12 @@ class BatchIterator {
|
||||
++(*impl_);
|
||||
}
|
||||
|
||||
SparsePage& operator*() {
|
||||
T& operator*() {
|
||||
CHECK(impl_ != nullptr);
|
||||
return *(*impl_);
|
||||
}
|
||||
|
||||
const SparsePage& operator*() const {
|
||||
const T& operator*() const {
|
||||
CHECK(impl_ != nullptr);
|
||||
return *(*impl_);
|
||||
}
|
||||
@@ -319,17 +334,18 @@ class BatchIterator {
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<BatchIteratorImpl> impl_;
|
||||
std::unique_ptr<BatchIteratorImpl<T>> impl_;
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
class BatchSet {
|
||||
public:
|
||||
explicit BatchSet(BatchIterator begin_iter) : begin_iter_(begin_iter) {}
|
||||
BatchIterator begin() { return begin_iter_; }
|
||||
BatchIterator end() { return BatchIterator(nullptr); }
|
||||
explicit BatchSet(BatchIterator<T> begin_iter) : begin_iter_(begin_iter) {}
|
||||
BatchIterator<T> begin() { return begin_iter_; }
|
||||
BatchIterator<T> end() { return BatchIterator<T>(nullptr); }
|
||||
|
||||
private:
|
||||
BatchIterator begin_iter_;
|
||||
BatchIterator<T> begin_iter_;
|
||||
};
|
||||
|
||||
/*!
|
||||
@@ -339,7 +355,8 @@ class BatchSet {
|
||||
*
|
||||
* On distributed setting, usually an customized dmlc::Parser is needed instead.
|
||||
*/
|
||||
class DataSource : public dmlc::DataIter<SparsePage> {
|
||||
template<typename T>
|
||||
class DataSource : public dmlc::DataIter<T> {
|
||||
public:
|
||||
/*!
|
||||
* \brief Meta information about the dataset
|
||||
@@ -367,11 +384,10 @@ class DMatrix {
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual const MetaInfo& Info() const = 0;
|
||||
/**
|
||||
* \brief Gets row batches. Use range based for loop over BatchSet to access individual batches.
|
||||
* \brief Gets batches. Use range based for loop over BatchSet to access individual batches.
|
||||
*/
|
||||
virtual BatchSet GetRowBatches() = 0;
|
||||
virtual BatchSet GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet GetColumnBatches() = 0;
|
||||
template<typename T>
|
||||
BatchSet<T> GetBatches();
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return Whether the data columns single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
@@ -410,7 +426,7 @@ class DMatrix {
|
||||
* This can be nullptr for common cases, and in-memory mode will be used.
|
||||
* \return a Created DMatrix.
|
||||
*/
|
||||
static DMatrix* Create(std::unique_ptr<DataSource>&& source,
|
||||
static DMatrix* Create(std::unique_ptr<DataSource<SparsePage>>&& source,
|
||||
const std::string& cache_prefix = "");
|
||||
/*!
|
||||
* \brief Create a DMatrix by loading data from parser.
|
||||
@@ -431,7 +447,27 @@ class DMatrix {
|
||||
|
||||
/*! \brief page size 32 MB */
|
||||
static const size_t kPageSize = 32UL << 20UL;
|
||||
|
||||
protected:
|
||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
|
||||
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
|
||||
};
|
||||
|
||||
template<>
|
||||
inline BatchSet<SparsePage> DMatrix::GetBatches() {
|
||||
return GetRowBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<CSCPage> DMatrix::GetBatches() {
|
||||
return GetColumnBatches();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline BatchSet<SortedCSCPage> DMatrix::GetBatches() {
|
||||
return GetSortedColumnBatches();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace dmlc {
|
||||
|
||||
Reference in New Issue
Block a user