Remove SimpleCSRSource (#5315)

This commit is contained in:
Rory Mitchell 2020-02-18 16:49:17 +13:00 committed by GitHub
parent 9f77c18b0d
commit b2b2c4e231
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 121 additions and 286 deletions

View File

@ -31,7 +31,6 @@
// data // data
#include "../src/data/data.cc" #include "../src/data/data.cc"
#include "../src/data/simple_csr_source.cc"
#include "../src/data/simple_dmatrix.cc" #include "../src/data/simple_dmatrix.cc"
#include "../src/data/sparse_page_raw_format.cc" #include "../src/data/sparse_page_raw_format.cc"
#include "../src/data/ellpack_page.cc" #include "../src/data/ellpack_page.cc"

View File

@ -445,14 +445,6 @@ class DMatrix {
virtual float GetColDensity(size_t cidx) = 0; virtual float GetColDensity(size_t cidx) = 0;
/*! \brief virtual destructor */ /*! \brief virtual destructor */
virtual ~DMatrix() = default; virtual ~DMatrix() = default;
/*!
* \brief Save DMatrix to local file.
* The saved file only works for non-sharded dataset(single machine training).
* This API is deprecated and dis-encouraged to use.
* \param fname The file name to be saved.
* \return The created DMatrix.
*/
virtual void SaveToLocalFile(const std::string& fname);
/*! \brief Whether the matrix is dense. */ /*! \brief Whether the matrix is dense. */
bool IsDense() const { bool IsDense() const {
@ -475,16 +467,6 @@ class DMatrix {
const std::string& file_format = "auto", const std::string& file_format = "auto",
size_t page_size = kPageSize); size_t page_size = kPageSize);
/*!
* \brief create a new DMatrix, by wrapping a row_iterator, and meta info.
* \param source The source iterator of the data, the create function takes ownership of the source.
* \param cache_prefix The path to prefix of temporary cache file of the DMatrix when used in external memory mode.
* This can be nullptr for common cases, and in-memory mode will be used.
* \return a Created DMatrix.
*/
static DMatrix* Create(std::unique_ptr<DataSource<SparsePage>>&& source,
const std::string& cache_prefix = "");
/** /**
* \brief Creates a new DMatrix from an external data adapter. * \brief Creates a new DMatrix from an external data adapter.
* *

View File

@ -20,7 +20,6 @@
#include "xgboost/json.h" #include "xgboost/json.h"
#include "c_api_error.h" #include "c_api_error.h"
#include "../data/simple_csr_source.h"
#include "../common/io.h" #include "../common/io.h"
#include "../data/adapter.h" #include "../data/adapter.h"
#include "../data/simple_dmatrix.h" #include "../data/simple_dmatrix.h"
@ -296,8 +295,6 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle,
xgboost::bst_ulong len, xgboost::bst_ulong len,
DMatrixHandle* out, DMatrixHandle* out,
int allow_groups) { int allow_groups) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource());
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
if (!allow_groups) { if (!allow_groups) {
@ -324,12 +321,16 @@ XGB_DLL int XGDMatrixFree(DMatrixHandle handle) {
API_END(); API_END();
} }
XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname,
const char* fname,
int silent) { int silent) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->SaveToLocalFile(fname); auto dmat = static_cast<std::shared_ptr<DMatrix>*>(handle)->get();
if (data::SimpleDMatrix* derived = dynamic_cast<data::SimpleDMatrix*>(dmat)) {
derived->SaveToLocalFile(fname);
} else {
LOG(FATAL) << "binary saving only supported by SimpleDMatrix";
}
API_END(); API_END();
} }

View File

@ -3,7 +3,6 @@
#include "xgboost/data.h" #include "xgboost/data.h"
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "c_api_error.h" #include "c_api_error.h"
#include "../data/simple_csr_source.h"
#include "../data/device_adapter.cuh" #include "../data/device_adapter.cuh"
namespace xgboost { namespace xgboost {

View File

@ -12,9 +12,9 @@
#include "xgboost/version_config.h" #include "xgboost/version_config.h"
#include "sparse_page_writer.h" #include "sparse_page_writer.h"
#include "simple_dmatrix.h" #include "simple_dmatrix.h"
#include "simple_csr_source.h"
#include "../common/io.h" #include "../common/io.h"
#include "../common/math.h"
#include "../common/version.h" #include "../common/version.h"
#include "../common/group_data.h" #include "../common/group_data.h"
#include "../data/adapter.h" #include "../data/adapter.h"
@ -336,10 +336,8 @@ DMatrix* DMatrix::Load(const std::string& uri,
if (fi != nullptr) { if (fi != nullptr) {
common::PeekableInStream is(fi.get()); common::PeekableInStream is(fi.get());
if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) && if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) &&
magic == data::SimpleCSRSource::kMagic) { magic == data::SimpleDMatrix::kMagic) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource()); DMatrix* dmat = new data::SimpleDMatrix(&is);
source->LoadBinary(&is);
DMatrix* dmat = DMatrix::Create(std::move(source), cache_file);
if (!silent) { if (!silent) {
LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with "
<< dmat->Info().num_nonzero_ << " entries loaded from " << uri; << dmat->Info().num_nonzero_ << " entries loaded from " << uri;
@ -412,13 +410,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
} }
void DMatrix::SaveToLocalFile(const std::string& fname) { /*
data::SimpleCSRSource source;
source.CopyFrom(this);
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
source.SaveBinary(fo.get());
}
DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source, DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
const std::string& cache_prefix) { const std::string& cache_prefix) {
if (cache_prefix.length() == 0) { if (cache_prefix.length() == 0) {
@ -434,6 +426,7 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
#endif // DMLC_ENABLE_STD_THREAD #endif // DMLC_ENABLE_STD_THREAD
} }
} }
*/
template <typename AdapterT> template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,

View File

@ -18,7 +18,7 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
bool Read(EllpackPage* page, dmlc::SeekStream* fi) override { bool Read(EllpackPage* page, dmlc::SeekStream* fi) override {
auto* impl = page->Impl(); auto* impl = page->Impl();
impl->Clear(); impl->Clear();
if (!fi->Read(&impl->matrix.n_rows)) return false; if (!fi->Read(&impl->matrix.n_rows)) return false;
return fi->Read(&impl->idx_buffer); return fi->Read(&impl->idx_buffer);
} }

View File

@ -1,59 +0,0 @@
/*!
* Copyright 2015-2019 by Contributors
* \file simple_csr_source.cc
*/
#include <dmlc/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include "simple_csr_source.h"
namespace xgboost {
namespace data {
void SimpleCSRSource::Clear() {
page_.Clear();
this->info.Clear();
}
/*!
 * \brief Replace the content of this source with a copy of \p src.
 * \param src Matrix to copy from; its meta info is copied and every
 *            SparsePage batch it yields is pushed onto the local page.
 */
void SimpleCSRSource::CopyFrom(DMatrix* src) {
  // Drop whatever was held before so the pushes below start from empty.
  Clear();
  info = src->Info();
  auto row_batches = src->GetBatches<SparsePage>();
  for (const SparsePage& chunk : row_batches) {
    this->page_.Push(chunk);
  }
}
/*!
 * \brief Load the source from a binary stream.
 *
 * The stream is read in this order: the magic number, the meta info, the
 * row offset vector, then the entry vector. Every read is checked so that a
 * truncated or corrupted file is reported instead of being silently loaded
 * as a partial matrix.
 *
 * \param fi The stream to load data from.
 */
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
  int tmagic = 0;
  CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format";
  CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
  info.LoadBinary(fi);
  // The vector overload of Read reports failure through its boolean result.
  CHECK(fi->Read(&page_.offset.HostVector()))
      << "invalid input file format, failed to read row offsets";
  CHECK(fi->Read(&page_.data.HostVector()))
      << "invalid input file format, failed to read entries";
}
/*!
 * \brief Write the source to a binary stream.
 *
 * Output order: magic number, meta info, row offset vector, entry vector.
 *
 * \param fo The output stream.
 */
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
  // Copy the constant into a local so that its address can be passed to Write.
  const int header_magic = kMagic;
  fo->Write(&header_magic, sizeof(header_magic));

  info.SaveBinary(fo);

  const auto& row_offsets = page_.offset.HostVector();
  fo->Write(row_offsets);
  const auto& entries = page_.data.HostVector();
  fo->Write(entries);
}
void SimpleCSRSource::BeforeFirst() {
at_first_ = true;
}
/*!
 * \brief Advance the iterator.
 * \return true on the first call after BeforeFirst() (or construction),
 *         false on every following call until the iterator is rewound.
 */
bool SimpleCSRSource::Next() {
  // Remember whether the page was still pending, then mark it as consumed.
  const bool page_available = at_first_;
  at_first_ = false;
  return page_available;
}
/*! \brief Access the page held by this source. */
const SparsePage& SimpleCSRSource::Value() const {
  return this->page_;
}
} // namespace data
} // namespace xgboost

View File

@ -1,74 +0,0 @@
/*!
* Copyright 2015 by Contributors
* \file simple_csr_source.h
* \brief The simplest form of data source, can be used to create DMatrix.
* This is an in-memory data structure that holds the data in row oriented format.
* \author Tianqi Chen
*/
#ifndef XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_
#define XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <algorithm>
#include <string>
#include <vector>
#include <limits>
namespace xgboost {
class Json;
namespace data {
/*!
* \brief The simplest form of data holder, can be used to create DMatrix.
* This is an in-memory data structure that holds the data in row oriented format.
* \code
* std::unique_ptr<DataSource<SparsePage>> source(new SimpleCSRSource());
* // add data to source
* DMatrix* dmat = DMatrix::Create(std::move(source));
* \endcode
*/
class SimpleCSRSource : public DataSource<SparsePage> {
public:
// MetaInfo info; // inherited from DataSource
/*! \brief the page holding the row oriented data of this source */
SparsePage page_;
/*! \brief default constructor */
SimpleCSRSource() = default;
/*! \brief destructor */
~SimpleCSRSource() override = default;
/*! \brief clear the data structure */
void Clear();
/*!
* \brief copy content of data from src
* \param src source matrix to copy the meta info and row batches from.
*/
void CopyFrom(DMatrix* src);
/*!
* \brief Load data from binary stream.
* \param fi the pointer to load data from.
*/
void LoadBinary(dmlc::Stream* fi);
/*!
* \brief Save data into binary stream
* \param fo The output stream.
*/
void SaveBinary(dmlc::Stream* fo) const;
// implement Next
bool Next() override;
// implement BeforeFirst
void BeforeFirst() override;
// implement Value
const SparsePage &Value() const override;
/*! \brief magic number used to identify SimpleCSRSource */
static const int kMagic = 0xffffab01;
private:
/*! \brief internal variable, used to support iterator interface */
bool at_first_{true};
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_

View File

@ -8,12 +8,13 @@
#include <xgboost/data.h> #include <xgboost/data.h>
#include "./simple_batch_iterator.h" #include "./simple_batch_iterator.h"
#include "../common/random.h" #include "../common/random.h"
#include "../data/adapter.h"
namespace xgboost { namespace xgboost {
namespace data { namespace data {
MetaInfo& SimpleDMatrix::Info() { return source_->info; } MetaInfo& SimpleDMatrix::Info() { return info; }
const MetaInfo& SimpleDMatrix::Info() const { return source_->info; } const MetaInfo& SimpleDMatrix::Info() const { return info; }
float SimpleDMatrix::GetColDensity(size_t cidx) { float SimpleDMatrix::GetColDensity(size_t cidx) {
size_t column_size = 0; size_t column_size = 0;
@ -32,17 +33,15 @@ float SimpleDMatrix::GetColDensity(size_t cidx) {
BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() { BatchSet<SparsePage> SimpleDMatrix::GetRowBatches() {
// since csr is the default data structure so `source_` is always available. // since csr is the default data structure so `source_` is always available.
auto cast = dynamic_cast<SimpleCSRSource*>(source_.get());
auto begin_iter = BatchIterator<SparsePage>( auto begin_iter = BatchIterator<SparsePage>(
new SimpleBatchIteratorImpl<SparsePage>(&(cast->page_))); new SimpleBatchIteratorImpl<SparsePage>(&sparse_page_));
return BatchSet<SparsePage>(begin_iter); return BatchSet<SparsePage>(begin_iter);
} }
BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() { BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
// column page doesn't exist, generate it // column page doesn't exist, generate it
if (!column_page_) { if (!column_page_) {
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_; column_page_.reset(new CSCPage(sparse_page_.GetTranspose(info.num_col_)));
column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_)));
} }
auto begin_iter = auto begin_iter =
BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get())); BatchIterator<CSCPage>(new SimpleBatchIteratorImpl<CSCPage>(column_page_.get()));
@ -52,9 +51,8 @@ BatchSet<CSCPage> SimpleDMatrix::GetColumnBatches() {
BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() { BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
// Sorted column page doesn't exist, generate it // Sorted column page doesn't exist, generate it
if (!sorted_column_page_) { if (!sorted_column_page_) {
auto const& page = dynamic_cast<SimpleCSRSource*>(source_.get())->page_;
sorted_column_page_.reset( sorted_column_page_.reset(
new SortedCSCPage(page.GetTranspose(source_->info.num_col_))); new SortedCSCPage(sparse_page_.GetTranspose(info.num_col_)));
sorted_column_page_->SortRows(); sorted_column_page_->SortRows();
} }
auto begin_iter = BatchIterator<SortedCSCPage>( auto begin_iter = BatchIterator<SortedCSCPage>(
@ -84,35 +82,33 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
int nthread_original = omp_get_max_threads(); int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread); omp_set_num_threads(nthread);
source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
std::vector<uint64_t> qids; std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max(); uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max; uint64_t last_group_id = default_max;
bst_uint group_size = 0; bst_uint group_size = 0;
auto& offset_vec = mat.page_.offset.HostVector(); auto& offset_vec = sparse_page_.offset.HostVector();
auto& data_vec = mat.page_.data.HostVector(); auto& data_vec = sparse_page_.data.HostVector();
uint64_t inferred_num_columns = 0; uint64_t inferred_num_columns = 0;
adapter->BeforeFirst(); adapter->BeforeFirst();
// Iterate over batches of input data // Iterate over batches of input data
while (adapter->Next()) { while (adapter->Next()) {
auto& batch = adapter->Value(); auto& batch = adapter->Value();
auto batch_max_columns = mat.page_.Push(batch, missing, nthread); auto batch_max_columns = sparse_page_.Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
// Append meta information if available // Append meta information if available
if (batch.Labels() != nullptr) { if (batch.Labels() != nullptr) {
auto& labels = mat.info.labels_.HostVector(); auto& labels = info.labels_.HostVector();
labels.insert(labels.end(), batch.Labels(), labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size()); batch.Labels() + batch.Size());
} }
if (batch.Weights() != nullptr) { if (batch.Weights() != nullptr) {
auto& weights = mat.info.weights_.HostVector(); auto& weights = info.weights_.HostVector();
weights.insert(weights.end(), batch.Weights(), weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size()); batch.Weights() + batch.Size());
} }
if (batch.BaseMargin() != nullptr) { if (batch.BaseMargin() != nullptr) {
auto& base_margin = mat.info.base_margin_.HostVector(); auto& base_margin = info.base_margin_.HostVector();
base_margin.insert(base_margin.end(), batch.BaseMargin(), base_margin.insert(base_margin.end(), batch.BaseMargin(),
batch.BaseMargin() + batch.Size()); batch.BaseMargin() + batch.Size());
} }
@ -122,7 +118,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
for (size_t i = 0; i < batch.Size(); ++i) { for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i]; const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) { if (last_group_id == default_max || last_group_id != cur_group_id) {
mat.info.group_ptr_.push_back(group_size); info.group_ptr_.push_back(group_size);
} }
last_group_id = cur_group_id; last_group_id = cur_group_id;
++group_size; ++group_size;
@ -131,22 +127,22 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
} }
if (last_group_id != default_max) { if (last_group_id != default_max) {
if (group_size > mat.info.group_ptr_.back()) { if (group_size > info.group_ptr_.back()) {
mat.info.group_ptr_.push_back(group_size); info.group_ptr_.push_back(group_size);
} }
} }
// Deal with empty rows/columns if necessary // Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) { if (adapter->NumColumns() == kAdapterUnknownSize) {
mat.info.num_col_ = inferred_num_columns; info.num_col_ = inferred_num_columns;
} else { } else {
mat.info.num_col_ = adapter->NumColumns(); info.num_col_ = adapter->NumColumns();
} }
// Synchronise worker columns // Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1); rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
if (adapter->NumRows() == kAdapterUnknownSize) { if (adapter->NumRows() == kAdapterUnknownSize) {
mat.info.num_row_ = offset_vec.size() - 1; info.num_row_ = offset_vec.size() - 1;
} else { } else {
if (offset_vec.empty()) { if (offset_vec.empty()) {
offset_vec.emplace_back(0); offset_vec.emplace_back(0);
@ -155,12 +151,31 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
while (offset_vec.size() - 1 < adapter->NumRows()) { while (offset_vec.size() - 1 < adapter->NumRows()) {
offset_vec.emplace_back(offset_vec.back()); offset_vec.emplace_back(offset_vec.back());
} }
mat.info.num_row_ = adapter->NumRows(); info.num_row_ = adapter->NumRows();
} }
mat.info.num_nonzero_ = data_vec.size(); info.num_nonzero_ = data_vec.size();
omp_set_num_threads(nthread_original); omp_set_num_threads(nthread_original);
} }
/*!
 * \brief Construct a SimpleDMatrix from a binary stream.
 *
 * The stream is read in this order: the magic number, the meta info, the
 * row offset vector, then the entry vector. Every read is checked so that a
 * truncated or corrupted file is reported instead of producing a partially
 * filled matrix.
 *
 * \param in_stream The stream to load the matrix from.
 */
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {
  int tmagic = 0;
  CHECK(in_stream->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic))
      << "invalid input file format";
  CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
  info.LoadBinary(in_stream);
  // The vector overload of Read reports failure through its boolean result.
  CHECK(in_stream->Read(&sparse_page_.offset.HostVector()))
      << "invalid input file format, failed to read row offsets";
  CHECK(in_stream->Read(&sparse_page_.data.HostVector()))
      << "invalid input file format, failed to read entries";
}
void SimpleDMatrix::SaveToLocalFile(const std::string& fname) {
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
info.SaveBinary(fo.get());
fo->Write(sparse_page_.offset.HostVector());
fo->Write(sparse_page_.data.HostVector());
}
template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing,
int nthread); int nthread);
template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing,

View File

@ -8,6 +8,7 @@
#include <xgboost/data.h> #include <xgboost/data.h>
#include "../common/random.h" #include "../common/random.h"
#include "./simple_dmatrix.h" #include "./simple_dmatrix.h"
#include "../common/math.h"
#include "device_adapter.cuh" #include "device_adapter.cuh"
namespace xgboost { namespace xgboost {
@ -112,38 +113,36 @@ void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
// be supported in future. Does not currently support inferring row/column size // be supported in future. Does not currently support inferring row/column size
template <typename AdapterT> template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
source_.reset(new SimpleCSRSource());
SimpleCSRSource& mat = *reinterpret_cast<SimpleCSRSource*>(source_.get());
CHECK(adapter->NumRows() != kAdapterUnknownSize); CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize); CHECK(adapter->NumColumns() != kAdapterUnknownSize);
adapter->BeforeFirst(); adapter->BeforeFirst();
adapter->Next(); adapter->Next();
auto& batch = adapter->Value(); auto& batch = adapter->Value();
mat.page_.offset.SetDevice(adapter->DeviceIdx()); sparse_page_.offset.SetDevice(adapter->DeviceIdx());
mat.page_.data.SetDevice(adapter->DeviceIdx()); sparse_page_.data.SetDevice(adapter->DeviceIdx());
// Enforce single batch // Enforce single batch
CHECK(!adapter->Next()); CHECK(!adapter->Next());
mat.page_.offset.Resize(adapter->NumRows() + 1); sparse_page_.offset.Resize(adapter->NumRows() + 1);
auto s_offset = mat.page_.offset.DeviceSpan(); auto s_offset = sparse_page_.offset.DeviceSpan();
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing); CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
mat.info.num_nonzero_ = mat.page_.offset.HostVector().back(); info.num_nonzero_ = sparse_page_.offset.HostVector().back();
mat.page_.data.Resize(mat.info.num_nonzero_); sparse_page_.data.Resize(info.num_nonzero_);
if (adapter->IsRowMajor()) { if (adapter->IsRowMajor()) {
CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(), CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset); adapter->DeviceIdx(), missing, s_offset);
} else { } else {
CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(), CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
adapter->DeviceIdx(), missing, s_offset); adapter->DeviceIdx(), missing, s_offset);
} }
// Sync // Sync
mat.page_.data.HostVector(); sparse_page_.data.HostVector();
mat.info.num_col_ = adapter->NumColumns(); info.num_col_ = adapter->NumColumns();
mat.info.num_row_ = adapter->NumRows(); info.num_row_ = adapter->NumRows();
// Synchronise worker columns // Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&mat.info.num_col_, 1); rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
} }
template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing, template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing,

View File

@ -10,28 +10,22 @@
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <algorithm>
#include <memory> #include <memory>
#include <limits> #include <string>
#include <utility>
#include <vector>
#include "simple_csr_source.h"
#include "../common/group_data.h"
#include "../common/math.h"
#include "adapter.h"
namespace xgboost { namespace xgboost {
namespace data { namespace data {
// Used for single batch data. // Used for single batch data.
class SimpleDMatrix : public DMatrix { class SimpleDMatrix : public DMatrix {
public: public:
explicit SimpleDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source)
: source_(std::move(source)) {}
template <typename AdapterT> template <typename AdapterT>
explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread); explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread);
explicit SimpleDMatrix(dmlc::Stream* in_stream);
void SaveToLocalFile(const std::string& fname);
MetaInfo& Info() override; MetaInfo& Info() override;
const MetaInfo& Info() const override; const MetaInfo& Info() const override;
@ -40,15 +34,17 @@ class SimpleDMatrix : public DMatrix {
bool SingleColBlock() const override; bool SingleColBlock() const override;
/*! \brief magic number used to identify SimpleDMatrix binary files */
static const int kMagic = 0xffffab01;
private: private:
BatchSet<SparsePage> GetRowBatches() override; BatchSet<SparsePage> GetRowBatches() override;
BatchSet<CSCPage> GetColumnBatches() override; BatchSet<CSCPage> GetColumnBatches() override;
BatchSet<SortedCSCPage> GetSortedColumnBatches() override; BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override; BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
// source data pointer. MetaInfo info;
std::unique_ptr<DataSource<SparsePage>> source_; SparsePage sparse_page_; // Primary storage type
std::unique_ptr<CSCPage> column_page_; std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_; std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_; std::unique_ptr<EllpackPage> ellpack_page_;

View File

@ -240,7 +240,7 @@ TEST(hist_util, DenseCutsCategorical) {
auto dmat = GetDMatrixFromData(x, n, 1); auto dmat = GetDMatrixFromData(x, n, 1);
HistogramCuts cuts; HistogramCuts cuts;
DenseCuts dense(&cuts); DenseCuts dense(&cuts);
dense.Build(&dmat, num_bins); dense.Build(dmat.get(), num_bins);
auto cuts_from_sketch = cuts.Values(); auto cuts_from_sketch = cuts.Values();
EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());
@ -260,7 +260,7 @@ TEST(hist_util, DenseCutsAccuracyTest) {
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
HistogramCuts cuts; HistogramCuts cuts;
DenseCuts dense(&cuts); DenseCuts dense(&cuts);
dense.Build(&dmat, num_bins); dense.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins); ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
} }
} }
@ -294,7 +294,7 @@ TEST(hist_util, SparseCutsAccuracyTest) {
for (auto num_bins : bin_sizes) { for (auto num_bins : bin_sizes) {
HistogramCuts cuts; HistogramCuts cuts;
SparseCuts sparse(&cuts); SparseCuts sparse(&cuts);
sparse.Build(&dmat, num_bins); sparse.Build(dmat.get(), num_bins);
ValidateCuts(cuts, x, num_rows, num_columns, num_bins); ValidateCuts(cuts, x, num_rows, num_columns, num_bins);
} }
} }
@ -312,7 +312,7 @@ TEST(hist_util, SparseCutsCategorical) {
auto dmat = GetDMatrixFromData(x, n, 1); auto dmat = GetDMatrixFromData(x, n, 1);
HistogramCuts cuts; HistogramCuts cuts;
SparseCuts sparse(&cuts); SparseCuts sparse(&cuts);
sparse.Build(&dmat, num_bins); sparse.Build(dmat.get(), num_bins);
auto cuts_from_sketch = cuts.Values(); auto cuts_from_sketch = cuts.Values();
EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_LT(cuts.MinValues()[0], x_sorted.front());
EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front());

View File

@ -7,6 +7,7 @@
#include <fstream> #include <fstream>
#include "../../../src/common/hist_util.h" #include "../../../src/common/hist_util.h"
#include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/simple_dmatrix.h"
#include "../../../src/data/adapter.h"
// Some helper functions used to test both GPU and CPU algorithms // Some helper functions used to test both GPU and CPU algorithms
// //
@ -40,10 +41,11 @@ inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n,
return x; return x;
} }
inline data::SimpleDMatrix GetDMatrixFromData(const std::vector<float>& x, int num_rows, int num_columns) { inline std::shared_ptr<data::SimpleDMatrix> GetDMatrixFromData(const std::vector<float>& x, int num_rows, int num_columns) {
data::DenseAdapter adapter(x.data(), num_rows, num_columns); data::DenseAdapter adapter(x.data(), num_rows, num_columns);
return data::SimpleDMatrix(&adapter, std::numeric_limits<float>::quiet_NaN(), return std::shared_ptr<data::SimpleDMatrix>(new data::SimpleDMatrix(
1); &adapter, std::numeric_limits<float>::quiet_NaN(),
1));
} }
inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData( inline std::shared_ptr<DMatrix> GetExternalMemoryDMatrixFromData(

View File

@ -7,7 +7,6 @@
#include <memory> #include <memory>
#include "../../../src/common/bitfield.h" #include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
#include "../../../src/data/simple_csr_source.h"
namespace xgboost { namespace xgboost {

View File

@ -4,7 +4,6 @@
#include <xgboost/data.h> #include <xgboost/data.h>
#include <string> #include <string>
#include <memory> #include <memory>
#include "../../../src/data/simple_csr_source.h"
#include "../../../src/common/version.h" #include "../../../src/common/version.h"
#include "../helpers.h" #include "../helpers.h"

View File

@ -1,41 +0,0 @@
// Copyright by Contributors
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include "../../../src/data/simple_csr_source.h"
#include "../helpers.h"
namespace xgboost {
// Round trip: load a libsvm text file, save it in binary form, load the
// binary file back and compare the two matrices.
TEST(SimpleCSRSource, SaveLoadBinary) {
  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/simple.libsvm";
  CreateSimpleTestData(tmp_file);
  // Owning pointers so that a fatal assertion below cannot leak the matrices.
  std::unique_ptr<xgboost::DMatrix> dmat(xgboost::DMatrix::Load(tmp_file, true, false));
  const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
  dmat->SaveToLocalFile(tmp_binfile);
  std::unique_ptr<xgboost::DMatrix> dmat_read(
      xgboost::DMatrix::Load(tmp_binfile, true, false));

  EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
  EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);
  // The original asserted num_row_ twice; the entry count was never compared.
  EXPECT_EQ(dmat->Info().num_nonzero_, dmat_read->Info().num_nonzero_);

  // Both matrices must expose a batch before the iterators are dereferenced.
  ASSERT_FALSE(dmat->GetBatches<xgboost::SparsePage>().begin().AtEnd());
  ASSERT_FALSE(dmat_read->GetBatches<xgboost::SparsePage>().begin().AtEnd());

  auto row_iter = dmat->GetBatches<xgboost::SparsePage>().begin();
  auto row_iter_read = dmat_read->GetBatches<xgboost::SparsePage>().begin();
  // Test the data read into the first row
  auto first_row = (*row_iter)[0];
  auto first_row_read = (*row_iter_read)[0];
  // Fatal checks: element [2] is indexed in both rows right below.
  ASSERT_EQ(first_row.size(), first_row_read.size());
  ASSERT_GT(static_cast<size_t>(first_row.size()), 2u);
  EXPECT_EQ(first_row[2].index, first_row_read[2].index);
  EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue);
}
} // namespace xgboost

View File

@ -254,3 +254,33 @@ TEST(SimpleDMatrix, Slice) {
delete pp_dmat; delete pp_dmat;
}; };
// Round trip: load a libsvm text file, save it in binary form through
// SimpleDMatrix::SaveToLocalFile, load the binary file back and compare the
// two matrices.
TEST(SimpleDMatrix, SaveLoadBinary) {
  dmlc::TemporaryDirectory tempdir;
  const std::string tmp_file = tempdir.path + "/simple.libsvm";
  CreateSimpleTestData(tmp_file);
  // Owning pointers so that a fatal assertion below cannot leak the matrices.
  std::unique_ptr<xgboost::DMatrix> dmat(xgboost::DMatrix::Load(tmp_file, true, false));
  auto* simple_dmat = dynamic_cast<data::SimpleDMatrix*>(dmat.get());
  // Saving is only available on SimpleDMatrix; stop here if the cast failed
  // instead of dereferencing a null pointer.
  ASSERT_TRUE(simple_dmat != nullptr);

  const std::string tmp_binfile = tempdir.path + "/csr_source.binary";
  simple_dmat->SaveToLocalFile(tmp_binfile);
  std::unique_ptr<xgboost::DMatrix> dmat_read(
      xgboost::DMatrix::Load(tmp_binfile, true, false));

  EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_);
  EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_);
  // The original asserted num_row_ twice; the entry count was never compared.
  EXPECT_EQ(dmat->Info().num_nonzero_, dmat_read->Info().num_nonzero_);

  // Both matrices must expose a batch before the iterators are dereferenced.
  ASSERT_FALSE(dmat->GetBatches<xgboost::SparsePage>().begin().AtEnd());
  ASSERT_FALSE(dmat_read->GetBatches<xgboost::SparsePage>().begin().AtEnd());

  auto row_iter = dmat->GetBatches<xgboost::SparsePage>().begin();
  auto row_iter_read = dmat_read->GetBatches<xgboost::SparsePage>().begin();
  // Test the data read into the first row
  auto first_row = (*row_iter)[0];
  auto first_row_read = (*row_iter_read)[0];
  // Fatal checks: element [2] is indexed in both rows right below.
  ASSERT_EQ(first_row.size(), first_row_read.size());
  ASSERT_GT(static_cast<size_t>(first_row.size()), 2u);
  EXPECT_EQ(first_row[2].index, first_row_read[2].index);
  EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue);
}

View File

@ -17,7 +17,6 @@
#include "helpers.h" #include "helpers.h"
#include "xgboost/c_api.h" #include "xgboost/c_api.h"
#include "../../src/data/simple_csr_source.h"
#include "../../src/gbm/gbtree_model.h" #include "../../src/gbm/gbtree_model.h"
#include "xgboost/predictor.h" #include "xgboost/predictor.h"
@ -256,17 +255,13 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrixWithRC(
} }
fo.close(); fo.close();
std::unique_ptr<DMatrix> dmat(DMatrix::Load( std::string uri = tmp_file;
tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size)); if (page_size > 0) {
EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); uri += "#" + tmp_file + ".cache";
if (!page_size) {
std::unique_ptr<data::SimpleCSRSource> source(new data::SimpleCSRSource);
source->CopyFrom(dmat.get());
return std::unique_ptr<DMatrix>(DMatrix::Create(std::move(source)));
} else {
return dmat;
} }
std::unique_ptr<DMatrix> dmat(
DMatrix::Load(uri, true, false, "auto", page_size));
return dmat;
} }
gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) { gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) {