Move ellpack page construction into DMatrix (#4833)

This commit is contained in:
Rong Ou
2019-09-16 20:50:55 -07:00
committed by Jiaming Yuan
parent 512f037e55
commit 125bcec62e
17 changed files with 761 additions and 513 deletions

View File

@@ -6,6 +6,7 @@
*/
#include "./simple_dmatrix.h"
#include <xgboost/data.h>
#include "./simple_batch_iterator.h"
#include "../common/random.h"
namespace xgboost {
@@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
}
template<typename T>
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
public:
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
T& operator*() override {
CHECK(page_ != nullptr);
return *page_;
}
const T& operator*() const override {
CHECK(page_ != nullptr);
return *page_;
}
void operator++() override { page_ = nullptr; }
bool AtEnd() const override { return page_ == nullptr; }
private:
T* page_{nullptr};
};
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
@@ -80,6 +62,16 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
return BatchSet<SortedCSCPage>(begin_iter);
}
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
ellpack_page_.reset(new EllpackPage(this));
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
return BatchSet<EllpackPage>(begin_iter);
}
bool SimpleDMatrix::SingleColBlock() const { return true; }
} // namespace data
} // namespace xgboost