Refactor DMatrix to return batches of different page types (#4686)

* Use explicit template parameter for specifying page type.
This commit is contained in:
Rong Ou
2019-08-03 12:10:34 -07:00
committed by Jiaming Yuan
parent e930a8e54f
commit 6edddd7966
41 changed files with 477 additions and 470 deletions

View File

@@ -18,10 +18,10 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
size_t column_size = 0;
// Use whatever version of column batches already exists
if (sorted_column_page_) {
auto batch = this->GetSortedColumnBatches();
auto batch = this->GetBatches<SortedCSCPage>();
column_size = (*batch.begin())[cidx].size();
} else {
auto batch = this->GetColumnBatches();
auto batch = this->GetBatches<CSCPage>();
column_size = (*batch.begin())[cidx].size();
}
@@ -29,14 +29,15 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
}
class SimpleBatchIteratorImpl : public BatchIteratorImpl {
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {}
SparsePage& operator*() override {
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const SparsePage& operator*() const override {
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
@@ -47,38 +48,38 @@ class SimpleBatchIteratorImpl : public BatchIteratorImpl {
}
private:
SparsePage* page_{nullptr};
T* page_{nullptr};
};
BatchSet SimpleDMatrix::GetRowBatches() {
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(&(cast->page_)));
return BatchSet(begin_iter);
auto begin_iter = BatchIterator<SparsePage>(
new SimpleBatchIteratorImpl<SparsePage>(&(cast->page_)));
return BatchSet<SparsePage>(begin_iter);
}
BatchSet SimpleDMatrix::GetColumnBatches() {
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it
if (!column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
column_page_.reset(
new SparsePage(page.GetTranspose(source_->info.num_col_)));
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
}
auto begin_iter =
BatchIterator(new SimpleBatchIteratorImpl(column_page_.get()));
return BatchSet(begin_iter);
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get()));
return BatchSet<CSCPage>(begin_iter);
}
BatchSet SimpleDMatrix::GetSortedColumnBatches() {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
sorted_column_page_.reset(
new SparsePage(page.GetTranspose(source_->info.num_col_)));
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
sorted_column_page_->SortRows();
}
auto begin_iter =
BatchIterator(new SimpleBatchIteratorImpl(sorted_column_page_.get()));
return BatchSet(begin_iter);
auto begin_iter = BatchIterator<SortedCSCPage>(
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_.get()));
return BatchSet<SortedCSCPage>(begin_iter);
}
bool SimpleDMatrix::SingleColBlock() const { return true; }