Dmatrix refactor stage 2 (#3395)
* DMatrix refactor 2 * Remove buffered rowset usage where possible * Transition to C++11 style iterators for row access * Transition column iterators to C++11
This commit is contained in:
@@ -12,10 +12,12 @@
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "./base.h"
|
||||
#include "../../src/common/span.h"
|
||||
#include "../../src/common/group_data.h"
|
||||
|
||||
#include "../../src/common/host_device_vector.h"
|
||||
|
||||
@@ -191,6 +193,49 @@ class SparsePage {
|
||||
data.HostVector().clear();
|
||||
}
|
||||
|
||||
/*!
 * \brief Build a transposed copy of this page, with entries grouped by their index field.
 *
 * Each entry of each row is pushed into the group selected by its `index`;
 * the pushed entry carries the global row id (`base_rowid` plus the row's
 * position inside this page) and the original `fvalue`.
 * Two OpenMP passes are made over the rows: the first registers how many
 * entries every group will receive, the second writes the entries.
 *
 * \param num_columns number of groups passed to the builder's InitBudget
 *        (presumably the number of feature columns; confirm against callers).
 * \return the newly built page.
 */
SparsePage GetTranspose(int num_columns) const {
  SparsePage result;
  common::ParallelGroupBuilder<Entry> group_builder(
      &result.offset.HostVector(), &result.data.HostVector());
  const int num_threads = omp_get_max_threads();
  group_builder.InitBudget(num_columns, num_threads);
  // A signed loop counter is used for the OpenMP loops below.
  const long num_rows = static_cast<long>(this->Size());  // NOLINT(*)

  // Pass 1: register one slot per entry in the group named by its index.
  #pragma omp parallel for schedule(static)
  for (long row = 0; row < num_rows; ++row) {  // NOLINT(*)
    const int thread_id = omp_get_thread_num();
    auto row_entries = (*this)[row];
    for (bst_uint k = 0; k < row_entries.size(); ++k) {
      group_builder.AddBudget(row_entries[k].index, thread_id);
    }
  }

  group_builder.InitStorage();

  // Pass 2: write every entry into its group, tagged with the global row id.
  #pragma omp parallel for schedule(static)
  for (long row = 0; row < num_rows; ++row) {  // NOLINT(*)
    const int thread_id = omp_get_thread_num();
    auto row_entries = (*this)[row];
    const auto global_row = static_cast<bst_uint>(this->base_rowid + row);
    for (bst_uint k = 0; k < row_entries.size(); ++k) {
      const auto& source = row_entries[k];
      group_builder.Push(source.index, Entry(global_row, source.fvalue), thread_id);
    }
  }
  return result;
}
|
||||
|
||||
void SortRows() {
|
||||
auto ncol = static_cast<bst_omp_uint>(this->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (this->offset.HostVector()[i] < this->offset.HostVector()[i + 1]) {
|
||||
std::sort(
|
||||
this->data.HostVector().begin() + this->offset.HostVector()[i],
|
||||
this->data.HostVector().begin() + this->offset.HostVector()[i + 1],
|
||||
Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Push row block into the page.
|
||||
* \param batch the row batch.
|
||||
@@ -251,6 +296,62 @@ class SparsePage {
|
||||
size_t Size() { return offset.Size() - 1; }
|
||||
};
|
||||
|
||||
class BatchIteratorImpl {
|
||||
public:
|
||||
virtual ~BatchIteratorImpl() {}
|
||||
virtual BatchIteratorImpl* Clone() = 0;
|
||||
virtual const SparsePage& operator*() const = 0;
|
||||
virtual void operator++() = 0;
|
||||
virtual bool AtEnd() const = 0;
|
||||
};
|
||||
|
||||
class BatchIterator {
|
||||
public:
|
||||
using iterator_category = std::forward_iterator_tag;
|
||||
explicit BatchIterator(BatchIteratorImpl* impl) { impl_.reset(impl); }
|
||||
|
||||
BatchIterator(const BatchIterator& other) {
|
||||
if (other.impl_) {
|
||||
impl_.reset(other.impl_->Clone());
|
||||
} else {
|
||||
impl_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
void operator++() {
|
||||
CHECK(impl_ != nullptr);
|
||||
++(*impl_);
|
||||
}
|
||||
|
||||
const SparsePage& operator*() const {
|
||||
CHECK(impl_ != nullptr);
|
||||
return *(*impl_);
|
||||
}
|
||||
|
||||
bool operator!=(const BatchIterator& rhs) const {
|
||||
CHECK(impl_ != nullptr);
|
||||
return !impl_->AtEnd();
|
||||
}
|
||||
|
||||
bool AtEnd() const {
|
||||
CHECK(impl_ != nullptr);
|
||||
return impl_->AtEnd();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<BatchIteratorImpl> impl_;
|
||||
};
|
||||
|
||||
class BatchSet {
|
||||
public:
|
||||
explicit BatchSet(BatchIterator begin_iter) : begin_iter_(begin_iter) {}
|
||||
BatchIterator begin() { return begin_iter_; }
|
||||
BatchIterator end() { return BatchIterator(nullptr); }
|
||||
|
||||
private:
|
||||
BatchIterator begin_iter_;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief This is a data structure that the user can pass to DMatrix::Create
|
||||
* to create a DMatrix for training, user can create this data structure
|
||||
@@ -320,32 +421,17 @@ class DMatrix {
|
||||
virtual MetaInfo& Info() = 0;
|
||||
/*! \brief meta information of the dataset */
|
||||
virtual const MetaInfo& Info() const = 0;
|
||||
/*!
|
||||
* \brief get the row iterator, reset to beginning position
|
||||
* \note Only either RowIterator or column Iterator can be active.
|
||||
/**
|
||||
* \brief Gets row batches. Use range based for loop over BatchSet to access individual batches.
|
||||
*/
|
||||
virtual dmlc::DataIter<SparsePage>* RowIterator() = 0;
|
||||
/*!\brief get column iterator, reset to the beginning position */
|
||||
virtual dmlc::DataIter<SparsePage>* ColIterator() = 0;
|
||||
/*!
|
||||
* \brief check if column access is supported, if not, initialize column access.
|
||||
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
|
||||
* this is a hint information that can be ignored by the implementation.
|
||||
* \param sorted If column features should be in sorted order
|
||||
* \return Number of column blocks in the column access.
|
||||
*/
|
||||
virtual void InitColAccess(size_t max_row_perbatch, bool sorted) = 0;
|
||||
virtual BatchSet GetRowBatches() = 0;
|
||||
virtual BatchSet GetSortedColumnBatches() = 0;
|
||||
virtual BatchSet GetColumnBatches() = 0;
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return whether column access is enabled */
|
||||
virtual bool HaveColAccess(bool sorted) const = 0;
|
||||
/*! \return Whether the data columns form a single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
/*! \brief get number of non-missing entries in column */
|
||||
virtual size_t GetColSize(size_t cidx) const = 0;
|
||||
/*! \brief get column density */
|
||||
virtual float GetColDensity(size_t cidx) const = 0;
|
||||
/*! \return reference of buffered rowset, in column access */
|
||||
virtual const RowSet& BufferedRowset() const = 0;
|
||||
virtual float GetColDensity(size_t cidx) = 0;
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~DMatrix() = default;
|
||||
/*!
|
||||
@@ -392,12 +478,6 @@ class DMatrix {
|
||||
*/
|
||||
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
|
||||
const std::string& cache_prefix = "");
|
||||
|
||||
private:
|
||||
// allow learner class to access this field.
|
||||
friend class LearnerImpl;
|
||||
/*! \brief public field to back ref cached matrix. */
|
||||
LearnerImpl* cache_learner_ptr_{nullptr};
|
||||
};
|
||||
|
||||
// implementation of inline functions
|
||||
|
||||
Reference in New Issue
Block a user