Remove SimpleCSRSource (#5315)

This commit is contained in:
Rory Mitchell
2020-02-18 16:49:17 +13:00
committed by GitHub
parent 9f77c18b0d
commit b2b2c4e231
18 changed files with 121 additions and 286 deletions

View File

@@ -8,12 +8,13 @@
#include <xgboost/data.h>
#include "./simple_batch_iterator.h"
#include "../common/random.h"
#include "../data/adapter.h"
namespace xgboost {
namespace data {
MetaInfo& SimpleDMatrix::Info() { return source_->info; }
MetaInfo& SimpleDMatrix::Info() { return info; }
const MetaInfo& SimpleDMatrix::Info() const { return source_->info; }
const MetaInfo& SimpleDMatrix::Info() const { return info; }
float SimpleDMatrix::GetColDensity(size_t cidx) {
size_t column_size = 0;
@@ -32,17 +33,15 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// CSR is the default data structure, so the row page is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
auto begin_iter = BatchIterator<SparsePage>(
new SimpleBatchIteratorImpl<SparsePage>(&(cast->page_)));
new SimpleBatchIteratorImpl<SparsePage>(&sparse_page_));
return BatchSet<SparsePage>(begin_iter);
}
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it
if (!column_page_) {
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
column_page_.reset(new CSCPage(sparse_page_.GetTranspose(info.num_col_)));
}
auto begin_iter =
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get()));
@@ -52,9 +51,8 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it
if (!sorted_column_page_) {
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
sorted_column_page_.reset(
new SortedCSCPage(page.GetTranspose(source_->info.num_col_)));
new SortedCSCPage(sparse_page_.GetTranspose(info.num_col_)));
sorted_column_page_->SortRows();
}
auto begin_iter = BatchIterator<SortedCSCPage>(
@@ -84,35 +82,33 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);
source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
auto& offset_vec = mat.page_.offset.HostVector();
auto& data_vec = mat.page_.data.HostVector();
auto& offset_vec = sparse_page_.offset.HostVector();
auto& data_vec = sparse_page_.data.HostVector();
uint64_t inferred_num_columns = 0;
adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto& batch = adapter->Value();
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
auto batch_max_columns = sparse_page_.Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = mat.info.labels_.HostVector();
auto& labels = info.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size());
}
if (batch.Weights() != nullptr) {
auto& weights = mat.info.weights_.HostVector();
auto& weights = info.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
}
if (batch.BaseMargin() != nullptr) {
auto& base_margin = mat.info.base_margin_.HostVector();
auto& base_margin = info.base_margin_.HostVector();
base_margin.insert(base_margin.end(), batch.BaseMargin(),
batch.BaseMargin() + batch.Size());
}
@@ -122,7 +118,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
mat.info.group_ptr_.push_back(group_size);
info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
++group_size;
@@ -131,22 +127,22 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
}
if (last_group_id != default_max) {
if (group_size > mat.info.group_ptr_.back()) {
mat.info.group_ptr_.push_back(group_size);
if (group_size > info.group_ptr_.back()) {
info.group_ptr_.push_back(group_size);
}
}
// Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) {
mat.info.num_col_ = inferred_num_columns;
info.num_col_ = inferred_num_columns;
} else {
mat.info.num_col_ = adapter->NumColumns();
info.num_col_ = adapter->NumColumns();
}
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1);
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
if (adapter->NumRows() == kAdapterUnknownSize) {
mat.info.num_row_ = offset_vec.size() - 1;
info.num_row_ = offset_vec.size() - 1;
} else {
if (offset_vec.empty()) {
offset_vec.emplace_back(0);
@@ -155,12 +151,31 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
while (offset_vec.size() - 1 < adapter->NumRows()) {
offset_vec.emplace_back(offset_vec.back());
}
mat.info.num_row_ = adapter->NumRows();
info.num_row_ = adapter->NumRows();
}
mat.info.num_nonzero_ = data_vec.size();
info.num_nonzero_ = data_vec.size();
omp_set_num_threads(nthread_original);
}
// Construct a SimpleDMatrix by deserialising it from a binary stream.
//
// The stream is consumed in this order: an `int` magic number, the meta info
// (via `info.LoadBinary`), the sparse page's offset vector, and finally the
// sparse page's data vector. This is the same order SaveToLocalFile writes.
//
// in_stream: stream positioned at the start of a serialised matrix.
//
// Any short read or magic-number mismatch aborts through CHECK rather than
// leaving the matrix partially populated.
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
  int tmagic;
  CHECK(in_stream->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic))
      << "invalid input file format";
  CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
  info.LoadBinary(in_stream);
  // The container overload of Read reports success as a bool; a truncated or
  // corrupt file must not silently produce an empty or half-filled page.
  CHECK(in_stream->Read(&sparse_page_.offset.HostVector()))
      << "invalid input file format, failed to read sparse page offsets";
  CHECK(in_stream->Read(&sparse_page_.data.HostVector()))
      << "invalid input file format, failed to read sparse page data";
}
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
info.SaveBinary(fo.get());
fo->Write(sparse_page_.offset.HostVector());
fo->Write(sparse_page_.data.HostVector());
}
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,