Dmatrix refactor stage 2 (#3395)
* DMatrix refactor 2 * Remove buffered rowset usage where possible * Transition to c++11 style iterators for row access * Transition column iterators to C++ 11
This commit is contained in:
@@ -4,103 +4,79 @@
|
||||
* \brief the input data structure for gradient boosting
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/data.h>
|
||||
#include <limits>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "./simple_dmatrix.h"
|
||||
#include <xgboost/data.h>
|
||||
#include "../common/random.h"
|
||||
#include "../common/group_data.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
MetaInfo& SimpleDMatrix::Info() { return source_->info; }
|
||||
|
||||
bool SimpleDMatrix::ColBatchIter::Next() {
|
||||
if (data_ >= 1) return false;
|
||||
data_ += 1;
|
||||
return true;
|
||||
}
|
||||
const MetaInfo& SimpleDMatrix::Info() const { return source_->info; }
|
||||
|
||||
dmlc::DataIter<SparsePage>* SimpleDMatrix::ColIterator() {
|
||||
col_iter_.BeforeFirst();
|
||||
return &col_iter_;
|
||||
}
|
||||
|
||||
void SimpleDMatrix::InitColAccess(
|
||||
size_t max_row_perbatch, bool sorted) {
|
||||
if (this->HaveColAccess(sorted)) return;
|
||||
col_iter_.sorted_ = sorted;
|
||||
col_iter_.column_page_.reset(new SparsePage());
|
||||
this->MakeOneBatch(col_iter_.column_page_.get(), sorted);
|
||||
}
|
||||
|
||||
// internal function to make one batch from row iter.
|
||||
void SimpleDMatrix::MakeOneBatch(SparsePage* pcol, bool sorted) {
|
||||
// clear rowset
|
||||
buffered_rowset_.Clear();
|
||||
// bit map
|
||||
const int nthread = omp_get_max_threads();
|
||||
pcol->Clear();
|
||||
auto& pcol_offset_vec = pcol->offset.HostVector();
|
||||
auto& pcol_data_vec = pcol->data.HostVector();
|
||||
common::ParallelGroupBuilder<Entry>
|
||||
builder(&pcol_offset_vec, &pcol_data_vec);
|
||||
builder.InitBudget(Info().num_col_, nthread);
|
||||
// start working
|
||||
auto iter = this->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const auto& batch = iter->Value();
|
||||
long batch_size = static_cast<long>(batch.Size()); // NOLINT(*)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
auto ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||
buffered_rowset_.PushBack(ridx);
|
||||
}
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < batch_size; ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = batch[i];
|
||||
for (auto& ins : inst) {
|
||||
builder.AddBudget(ins.index, tid);
|
||||
}
|
||||
}
|
||||
}
|
||||
builder.InitStorage();
|
||||
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
auto &batch = iter->Value();
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long i = 0; i < static_cast<long>(batch.Size()); ++i) { // NOLINT(*)
|
||||
int tid = omp_get_thread_num();
|
||||
auto inst = batch[i];
|
||||
for (auto& ins : inst) {
|
||||
builder.Push(ins.index,
|
||||
Entry(static_cast<bst_uint>(batch.base_rowid + i),
|
||||
ins.fvalue),
|
||||
tid);
|
||||
}
|
||||
}
|
||||
float SimpleDMatrix::GetColDensity(size_t cidx) {
|
||||
size_t column_size = 0;
|
||||
// Use whatever version of column batches already exists
|
||||
if (sorted_column_page_) {
|
||||
auto batch = this->GetSortedColumnBatches();
|
||||
column_size = (*batch.begin())[cidx].size();
|
||||
} else {
|
||||
auto batch = this->GetColumnBatches();
|
||||
column_size = (*batch.begin())[cidx].size();
|
||||
}
|
||||
|
||||
CHECK_EQ(pcol->Size(), Info().num_col_);
|
||||
size_t nmiss = this->Info().num_row_ - column_size;
|
||||
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
|
||||
}
|
||||
|
||||
if (sorted) {
|
||||
// sort columns
|
||||
auto ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol_offset_vec[i] < pcol_offset_vec[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol_data_vec) + pcol_offset_vec[i],
|
||||
dmlc::BeginPtr(pcol_data_vec) + pcol_offset_vec[i + 1],
|
||||
Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
class SimpleBatchIteratorImpl : public BatchIteratorImpl {
|
||||
public:
|
||||
explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {}
|
||||
const SparsePage& operator*() const override {
|
||||
CHECK(page_ != nullptr);
|
||||
return *page_;
|
||||
}
|
||||
void operator++() override { page_ = nullptr; }
|
||||
bool AtEnd() const override { return page_ == nullptr; }
|
||||
SimpleBatchIteratorImpl* Clone() override {
|
||||
return new SimpleBatchIteratorImpl(*this);
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage* page_{nullptr};
|
||||
};
|
||||
|
||||
BatchSet SimpleDMatrix::GetRowBatches() {
|
||||
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
|
||||
auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(&(cast->page_)));
|
||||
return BatchSet(begin_iter);
|
||||
}
|
||||
|
||||
bool SimpleDMatrix::SingleColBlock() const {
|
||||
return true;
|
||||
BatchSet SimpleDMatrix::GetColumnBatches() {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
column_page_.reset(
|
||||
new SparsePage(page.GetTranspose(source_->info.num_col_)));
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator(new SimpleBatchIteratorImpl(column_page_.get()));
|
||||
return BatchSet(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet SimpleDMatrix::GetSortedColumnBatches() {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
sorted_column_page_.reset(
|
||||
new SparsePage(page.GetTranspose(source_->info.num_col_)));
|
||||
sorted_column_page_->SortRows();
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator(new SimpleBatchIteratorImpl(sorted_column_page_.get()));
|
||||
return BatchSet(begin_iter);
|
||||
}
|
||||
|
||||
bool SimpleDMatrix::SingleColBlock() const { return true; }
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user