Move ellpack page construction into DMatrix (#4833)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
*/
|
||||
#include "./simple_dmatrix.h"
|
||||
#include <xgboost/data.h>
|
||||
#include "./simple_batch_iterator.h"
|
||||
#include "../common/random.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -29,25 +30,6 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
|
||||
return 1.0f - (static_cast<float>(nmiss)) / this->Info().num_row_;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
class SimpleBatchIteratorImpl : public BatchIteratorImpl<T> {
|
||||
public:
|
||||
explicit SimpleBatchIteratorImpl(T* page) : page_(page) {}
|
||||
T& operator*() override {
|
||||
CHECK(page_ != nullptr);
|
||||
return *page_;
|
||||
}
|
||||
const T& operator*() const override {
|
||||
CHECK(page_ != nullptr);
|
||||
return *page_;
|
||||
}
|
||||
void operator++() override { page_ = nullptr; }
|
||||
bool AtEnd() const override { return page_ == nullptr; }
|
||||
|
||||
private:
|
||||
T* page_{nullptr};
|
||||
};
|
||||
|
||||
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
|
||||
// since csr is the default data structure so `source_` is always available.
|
||||
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
|
||||
@@ -80,6 +62,16 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
|
||||
return BatchSet<SortedCSCPage>(begin_iter);
|
||||
}
|
||||
|
||||
BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches() {
|
||||
// ELLPACK page doesn't exist, generate it
|
||||
if (!ellpack_page_) {
|
||||
ellpack_page_.reset(new EllpackPage(this));
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
|
||||
return BatchSet<EllpackPage>(begin_iter);
|
||||
}
|
||||
|
||||
bool SimpleDMatrix::SingleColBlock() const { return true; }
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user