Rewrite sparse dmatrix using callbacks. (#7092)

- Reduce dependency on dmlc parsers and provide an interface for users to load data by themselves.
- Remove use of threaded iterator and IO queue.
- Remove `page_size`.
- Make sure the number of pages in memory is bounded.
- Make sure the cache can not be violated.
- Provide an interface for internal algorithms to process data asynchronously.
This commit is contained in:
Jiaming Yuan
2021-07-16 12:33:31 +08:00
committed by GitHub
parent 2f524e9f41
commit bd1f3a38f0
51 changed files with 1445 additions and 1391 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014~2020 by Contributors
* Copyright 2014~2021 by Contributors
* \file simple_dmatrix.cc
* \brief the input data structure for gradient boosting
* \author Tianqi Chen
@@ -27,7 +27,7 @@ const MetaInfo& SimpleDMatrix::Info() const { return info_; }
DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
auto out = new SimpleDMatrix;
SparsePage& out_page = out->sparse_page_;
SparsePage& out_page = *out->sparse_page_;
for (auto const &page : this->GetBatches<SparsePage>()) {
auto batch = page.GetView();
auto& h_data = out_page.data.HostVector();
@@ -48,17 +48,17 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available.
auto begin_iter = BatchIterator<SparsePage>(
new SimpleBatchIteratorImpl<SparsePage>(&sparse_page_));
new SimpleBatchIteratorImpl<SparsePage>(sparse_page_));
return BatchSet<SparsePage>(begin_iter);
}
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it
if (!column_page_) {
column_page_.reset(new CSCPage(sparse_page_.GetTranspose(info_.num_col_)));
column_page_.reset(new CSCPage(sparse_page_->GetTranspose(info_.num_col_)));
}
auto begin_iter =
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get()));
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_));
return BatchSet<CSCPage>(begin_iter);
}
@@ -66,11 +66,11 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
sorted_column_page_.reset(
new SortedCSCPage(sparse_page_.GetTranspose(info_.num_col_)));
new SortedCSCPage(sparse_page_->GetTranspose(info_.num_col_)));
sorted_column_page_->SortRows();
}
auto begin_iter = BatchIterator<SortedCSCPage>(
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_.get()));
new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
return BatchSet<SortedCSCPage>(begin_iter);
}
@@ -86,7 +86,7 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
batch_param_ = param;
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_));
return BatchSet<EllpackPage>(begin_iter);
}
@@ -100,7 +100,7 @@ BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& par
batch_param_ = param;
}
auto begin_iter = BatchIterator<GHistIndexMatrix>(
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_.get()));
new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
@@ -110,8 +110,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
auto& offset_vec = sparse_page_.offset.HostVector();
auto& data_vec = sparse_page_.data.HostVector();
auto& offset_vec = sparse_page_->offset.HostVector();
auto& data_vec = sparse_page_->data.HostVector();
uint64_t inferred_num_columns = 0;
uint64_t total_batch_size = 0;
// batch_size is either number of rows or cols, depending on data layout
@@ -120,7 +120,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Iterate over batches of input data
while (adapter->Next()) {
auto& batch = adapter->Value();
auto batch_max_columns = sparse_page_.Push(batch, missing, nthread);
auto batch_max_columns = sparse_page_->Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
total_batch_size += batch.Size();
// Append meta information if available
@@ -203,8 +203,8 @@ SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
CHECK(in_stream->Read(&tmagic)) << "invalid input file format";
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
info_.LoadBinary(in_stream);
in_stream->Read(&sparse_page_.offset.HostVector());
in_stream->Read(&sparse_page_.data.HostVector());
in_stream->Read(&sparse_page_->offset.HostVector());
in_stream->Read(&sparse_page_->data.HostVector());
}
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
@@ -212,8 +212,8 @@ void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
int tmagic = kMagic;
fo->Write(tmagic);
info_.SaveBinary(fo.get());
fo->Write(sparse_page_.offset.HostVector());
fo->Write(sparse_page_.data.HostVector());
fo->Write(sparse_page_->offset.HostVector());
fo->Write(sparse_page_->data.HostVector());
}
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,