Refactor DMatrix to return batches of different page types (#4686)
* Use explicit template parameter for specifying page type.
This commit is contained in:
@@ -18,10 +18,10 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
|
||||
size_t column_size = 0;
|
||||
// Use whatever version of column batches already exists
|
||||
if (sorted_column_page_) {
|
||||
auto batch = this->GetSortedColumnBatches();
|
||||
auto batch = this->GetBatches<SortedCSCPage>();
|
||||
column_size = (*batch.begin())[cidx].size();
|
||||
} else {
|
||||
auto batch = this->GetColumnBatches();
|
||||
auto batch = this->GetBatches<CSCPage>();
|
||||
column_size = (*batch.begin())[cidx].size();
|
||||
}
|
||||
|
||||
@@ -29,14 +29,15 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
|
||||
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
|
||||
}
|
||||
|
||||
class SimpleBatchIteratorImpl : public BatchIteratorImpl {
|
||||
template<typename T>
|
||||
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
|
||||
public:
|
||||
explicit SimpleBatchIteratorImpl(SparsePage* page) : page_(page) {}
|
||||
SparsePage& operator*() override {
|
||||
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
|
||||
T& operator*() override {
|
||||
CHECK(page_ != nullptr);
|
||||
return *page_;
|
||||
}
|
||||
const SparsePage& operator*() const override {
|
||||
const T& operator*() const override {
|
||||
CHECK(page_ != nullptr);
|
||||
return *page_;
|
||||
}
|
||||
@@ -47,38 +48,38 @@ class SimpleBatchIteratorImpl : public BatchIteratorImpl {
|
||||
}
|
||||
|
||||
private:
|
||||
SparsePage* page_{nullptr};
|
||||
T* page_{nullptr};
|
||||
};
|
||||
|
||||
BatchSet SimpleDMatrix::GetRowBatches() {
|
||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
|
||||
auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(&(cast->page_)));
|
||||
return BatchSet(begin_iter);
|
||||
auto begin_iter = BatchIterator<SparsePage>(
|
||||
new SimpleBatchIteratorImpl<SparsePage>(&(cast->page_)));
|
||||
return BatchSet<SparsePage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet SimpleDMatrix::GetColumnBatches() {
|
||||
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
|
||||
// column page doesn't exist, generate it
|
||||
if (!column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
column_page_.reset(
|
||||
new SparsePage(page.GetTranspose(source_->info.num_col_)));
|
||||
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator(new SimpleBatchIteratorImpl(column_page_.get()));
|
||||
return BatchSet(begin_iter);
|
||||
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get()));
|
||||
return BatchSet<CSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet SimpleDMatrix::GetSortedColumnBatches() {
|
||||
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
// Sorted column page doesn't exist, generate it
|
||||
if (!sorted_column_page_) {
|
||||
auto page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
|
||||
sorted_column_page_.reset(
|
||||
new SparsePage(page.GetTranspose(source_->info.num_col_)));
|
||||
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
|
||||
sorted_column_page_->SortRows();
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator(new SimpleBatchIteratorImpl(sorted_column_page_.get()));
|
||||
return BatchSet(begin_iter);
|
||||
auto begin_iter = BatchIterator<SortedCSCPage>(
|
||||
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_.get()));
|
||||
return BatchSet<SortedCSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
bool SimpleDMatrix::SingleColBlock() const { return true; }
|
||||
|
||||
Reference in New Issue
Block a user