Dmatrix refactor stage 2 (#3395)

* DMatrix refactor 2

* Remove buffered rowset usage where possible

* Transition to c++11 style iterators for row access

* Transition column iterators to C++ 11
This commit is contained in:
Rory Mitchell
2018-10-01 01:29:03 +13:00
committed by GitHub
parent b50bc2c1d4
commit 70d208d68c
36 changed files with 459 additions and 846 deletions

View File

@@ -12,10 +12,12 @@
#include <cstring>
#include <memory>
#include <numeric>
#include <algorithm>
#include <string>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"
#include "../../src/common/group_data.h"
#include "../../src/common/host_device_vector.h"
@@ -191,6 +193,49 @@ class SparsePage {
data.HostVector().clear();
}
SparsePage GetTranspose(int num_columns) const {
SparsePage transpose;
common::ParallelGroupBuilder<Entry> builder(&transpose.offset.HostVector(),
&transpose.data.HostVector());
const int nthread = omp_get_max_threads();
builder.InitBudget(num_columns, nthread);
long batch_size = static_cast<long>(this->Size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.AddBudget(inst[j].index, tid);
}
}
builder.InitStorage();
#pragma omp parallel for schedule(static)
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto inst = (*this)[i];
for (bst_uint j = 0; j < inst.size(); ++j) {
builder.Push(
inst[j].index,
Entry(static_cast<bst_uint>(this->base_rowid + i), inst[j].fvalue),
tid);
}
}
return transpose;
}
void SortRows() {
auto ncol = static_cast<bst_omp_uint>(this->Size());
#pragma omp parallel for schedule(dynamic, 1)
for (bst_omp_uint i = 0; i < ncol; ++i) {
if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
std::sort(
this->data.HostVector().begin() + this->offset.HostVector()[i],
this->data.HostVector().begin() + this->offset.HostVector()[i + 1],
Entry::CmpValue);
}
}
}
/*!
* \brief Push row block into the page.
* \param batch the row batch.
@@ -251,6 +296,62 @@ class SparsePage {
size_t Size() { return offset.Size() - 1; }
};
class BatchIteratorImpl {
public:
virtual ~BatchIteratorImpl() {}
virtual BatchIteratorImpl* Clone() = 0;
virtual const SparsePage& operator*() const = 0;
virtual void operator++() = 0;
virtual bool AtEnd() const = 0;
};
class BatchIterator {
public:
using iterator_category = std::forward_iterator_tag;
explicit BatchIterator(BatchIteratorImpl* impl) { impl_.reset(impl); }
BatchIterator(const BatchIterator& other) {
if (other.impl_) {
impl_.reset(other.impl_->Clone());
} else {
impl_.reset();
}
}
void operator++() {
CHECK(impl_ != nullptr);
++(*impl_);
}
const SparsePage& operator*() const {
CHECK(impl_ != nullptr);
return *(*impl_);
}
bool operator!=(const BatchIterator& rhs) const {
CHECK(impl_ != nullptr);
return !impl_->AtEnd();
}
bool AtEnd() const {
CHECK(impl_ != nullptr);
return impl_->AtEnd();
}
private:
std::unique_ptr<BatchIteratorImpl> impl_;
};
class BatchSet {
public:
explicit BatchSet(BatchIterator begin_iter) : begin_iter_(begin_iter) {}
BatchIterator begin() { return begin_iter_; }
BatchIterator end() { return BatchIterator(nullptr); }
private:
BatchIterator begin_iter_;
};
/*!
* \brief This is data structure that user can pass to DMatrix::Create
* to create a DMatrix for training, user can create this data structure
@@ -320,32 +421,17 @@ class DMatrix {
virtual MetaInfo& Info() = 0;
/*! \brief meta information of the dataset */
virtual const MetaInfo& Info() const = 0;
/*!
* \brief get the row iterator, reset to beginning position
* \note Only either RowIterator or column Iterator can be active.
/**
* \brief Gets row batches. Use range based for loop over BatchSet to access individual batches.
*/
virtual dmlc::DataIter<SparsePage>* RowIterator() = 0;
/*!\brief get column iterator, reset to the beginning position */
virtual dmlc::DataIter<SparsePage>* ColIterator() = 0;
/*!
* \brief check if column access is supported, if not, initialize column access.
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \param sorted If column features should be in sorted order
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
virtual BatchSet GetRowBatches() = 0;
virtual BatchSet GetSortedColumnBatches() = 0;
virtual BatchSet GetColumnBatches() = 0;
// the following are column meta data, should be able to answer them fast.
/*! \return whether column access is enabled */
virtual bool HaveColAccess(bool sorted) const = 0;
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
/*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */
virtual float GetColDensity(size_t cidx) const = 0;
/*! \return reference of buffered rowset, in column access */
virtual const RowSet& BufferedRowset() const = 0;
virtual float GetColDensity(size_t cidx) = 0;
/*! \brief virtual destructor */
virtual ~DMatrix() = default;
/*!
@@ -392,12 +478,6 @@ class DMatrix {
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const std::string& cache_prefix = "");
private:
// allow learner class to access this field.
friend class LearnerImpl;
/*! \brief public field to back ref cached matrix. */
LearnerImpl* cache_learner_ptr_{nullptr};
};
// implementation of inline functions