From b2b2c4e231e4d7d8299006ba4328b4deb45b8263 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 18 Feb 2020 16:49:17 +1300 Subject: [PATCH] Remove SimpleCSRSource (#5315) --- amalgamation/xgboost-all0.cc | 1 - include/xgboost/data.h | 18 ------ src/c_api/c_api.cc | 13 +++-- src/c_api/c_api.cu | 1 - src/data/data.cc | 17 ++---- src/data/ellpack_page_raw_format.cu | 2 +- src/data/simple_csr_source.cc | 59 ------------------- src/data/simple_csr_source.h | 74 ------------------------ src/data/simple_dmatrix.cc | 65 +++++++++++++-------- src/data/simple_dmatrix.cu | 27 +++++---- src/data/simple_dmatrix.h | 24 ++++---- tests/cpp/common/test_hist_util.cc | 8 +-- tests/cpp/common/test_hist_util.h | 8 ++- tests/cpp/data/test_array_interface.h | 1 - tests/cpp/data/test_metainfo.cc | 1 - tests/cpp/data/test_simple_csr_source.cc | 41 ------------- tests/cpp/data/test_simple_dmatrix.cc | 30 ++++++++++ tests/cpp/helpers.cc | 17 ++---- 18 files changed, 121 insertions(+), 286 deletions(-) delete mode 100644 src/data/simple_csr_source.cc delete mode 100644 src/data/simple_csr_source.h delete mode 100644 tests/cpp/data/test_simple_csr_source.cc diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index d987161ac..f3885ea9f 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -31,7 +31,6 @@ // data #include "../src/data/data.cc" -#include "../src/data/simple_csr_source.cc" #include "../src/data/simple_dmatrix.cc" #include "../src/data/sparse_page_raw_format.cc" #include "../src/data/ellpack_page.cc" diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 8a2dfefb9..2894ebafa 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -445,14 +445,6 @@ class DMatrix { virtual float GetColDensity(size_t cidx) = 0; /*! \brief virtual destructor */ virtual ~DMatrix() = default; - /*! - * \brief Save DMatrix to local file. - * The saved file only works for non-sharded dataset(single machine training). - * This API is deprecated and dis-encouraged to use. - * \param fname The file name to be saved. - * \return The created DMatrix. - */ - virtual void SaveToLocalFile(const std::string& fname); /*! \brief Whether the matrix is dense. */ bool IsDense() const { @@ -475,16 +467,6 @@ class DMatrix { const std::string& file_format = "auto", size_t page_size = kPageSize); - /*! - * \brief create a new DMatrix, by wrapping a row_iterator, and meta info. - * \param source The source iterator of the data, the create function takes ownership of the source. - * \param cache_prefix The path to prefix of temporary cache file of the DMatrix when used in external memory mode. - * This can be nullptr for common cases, and in-memory mode will be used. - * \return a Created DMatrix. - */ - static DMatrix* Create(std::unique_ptr>&& source, - const std::string& cache_prefix = ""); - /** * \brief Creates a new DMatrix from an external data adapter. * diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ccc49d91a..7283c70d4 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -20,7 +20,6 @@ #include "xgboost/json.h" #include "c_api_error.h" -#include "../data/simple_csr_source.h" #include "../common/io.h" #include "../data/adapter.h" #include "../data/simple_dmatrix.h" @@ -296,8 +295,6 @@ XGB_DLL int XGDMatrixSliceDMatrixEx(DMatrixHandle handle, xgboost::bst_ulong len, DMatrixHandle* out, int allow_groups) { - std::unique_ptr source(new data::SimpleCSRSource()); - API_BEGIN(); CHECK_HANDLE(); if (!allow_groups) { @@ -324,12 +321,16 @@ XGB_DLL int XGDMatrixFree(DMatrixHandle handle) { API_END(); } -XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, - const char* fname, +XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle, const char* fname, int silent) { API_BEGIN(); CHECK_HANDLE(); - static_cast*>(handle)->get()->SaveToLocalFile(fname); + auto dmat = static_cast*>(handle)->get(); + if (data::SimpleDMatrix* derived = dynamic_cast(dmat)) { + derived->SaveToLocalFile(fname); + } else { + LOG(FATAL) << "binary saving only supported by SimpleDMatrix"; + } API_END(); } diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 4652d5218..b76f30ea2 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -3,7 +3,6 @@ #include "xgboost/data.h" #include "xgboost/c_api.h" #include "c_api_error.h" -#include "../data/simple_csr_source.h" #include "../data/device_adapter.cuh" namespace xgboost { diff --git a/src/data/data.cc b/src/data/data.cc index f030686c4..9fca3c305 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -12,9 +12,9 @@ #include "xgboost/version_config.h" #include "sparse_page_writer.h" #include "simple_dmatrix.h" -#include "simple_csr_source.h" #include "../common/io.h" +#include "../common/math.h" #include "../common/version.h" #include "../common/group_data.h" #include "../data/adapter.h" @@ -336,10 +336,8 @@ DMatrix* DMatrix::Load(const std::string& uri, if (fi != nullptr) { common::PeekableInStream is(fi.get()); if (is.PeekRead(&magic, sizeof(magic)) == sizeof(magic) && - magic == data::SimpleCSRSource::kMagic) { - std::unique_ptr source(new data::SimpleCSRSource()); - source->LoadBinary(&is); - DMatrix* dmat = DMatrix::Create(std::move(source), cache_file); + magic == data::SimpleDMatrix::kMagic) { + DMatrix* dmat = new data::SimpleDMatrix(&is); if (!silent) { LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " << dmat->Info().num_nonzero_ << " entries loaded from " << uri; @@ -412,13 +410,7 @@ DMatrix* DMatrix::Load(const std::string& uri, } -void DMatrix::SaveToLocalFile(const std::string& fname) { - data::SimpleCSRSource source; - source.CopyFrom(this); - std::unique_ptr fo(dmlc::Stream::Create(fname.c_str(), "w")); - source.SaveBinary(fo.get()); -} - +/* DMatrix* DMatrix::Create(std::unique_ptr>&& source, const std::string& cache_prefix) { if (cache_prefix.length() == 0) { @@ -434,6 +426,7 @@ DMatrix* DMatrix::Create(std::unique_ptr>&& source, #endif // DMLC_ENABLE_STD_THREAD } } +*/ template DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread, diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 7760b13dc..b46e35c96 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -18,7 +18,7 @@ class EllpackPageRawFormat : public SparsePageFormat { bool Read(EllpackPage* page, dmlc::SeekStream* fi) override { auto* impl = page->Impl(); impl->Clear(); - if (!fi->Read(&impl->matrix.n_rows)) return false; + if (!fi->Read(&impl->matrix.n_rows)) return false; return fi->Read(&impl->idx_buffer); } diff --git a/src/data/simple_csr_source.cc b/src/data/simple_csr_source.cc deleted file mode 100644 index 2b7d4e848..000000000 --- a/src/data/simple_csr_source.cc +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright 2015-2019 by Contributors - * \file simple_csr_source.cc - */ -#include -#include -#include - -#include "simple_csr_source.h" - -namespace xgboost { -namespace data { - -void SimpleCSRSource::Clear() { - page_.Clear(); - this->info.Clear(); -} - -void SimpleCSRSource::CopyFrom(DMatrix* src) { - this->Clear(); - this->info = src->Info(); - for (const auto &batch : src->GetBatches()) { - page_.Push(batch); - } -} - -void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) { - int tmagic; - CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format"; - CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch"; - info.LoadBinary(fi); - fi->Read(&page_.offset.HostVector()); - fi->Read(&page_.data.HostVector()); -} - -void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const { - int tmagic = kMagic; - fo->Write(&tmagic, sizeof(tmagic)); - info.SaveBinary(fo); - fo->Write(page_.offset.HostVector()); - fo->Write(page_.data.HostVector()); -} - -void SimpleCSRSource::BeforeFirst() { - at_first_ = true; -} - -bool SimpleCSRSource::Next() { - if (!at_first_) return false; - at_first_ = false; - return true; -} - -const SparsePage& SimpleCSRSource::Value() const { - return page_; -} - -} // namespace data -} // namespace xgboost diff --git a/src/data/simple_csr_source.h b/src/data/simple_csr_source.h deleted file mode 100644 index a70871acb..000000000 --- a/src/data/simple_csr_source.h +++ /dev/null @@ -1,74 +0,0 @@ -/*! - * Copyright 2015 by Contributors - * \file simple_csr_source.h - * \brief The simplest form of data source, can be used to create DMatrix. - * This is an in-memory data structure that holds the data in row oriented format. - * \author Tianqi Chen - */ -#ifndef XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_ -#define XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_ - -#include -#include - -#include -#include -#include -#include - -namespace xgboost { - -class Json; - -namespace data { -/*! - * \brief The simplest form of data holder, can be used to create DMatrix. - * This is an in-memory data structure that holds the data in row oriented format. - * \code - * std::unique_ptr source(new SimpleCSRSource()); - * // add data to source - * DMatrix* dmat = DMatrix::Create(std::move(source)); - * \encode - */ -class SimpleCSRSource : public DataSource { - public: - // MetaInfo info; // inheritated from DataSource - SparsePage page_; - /*! \brief default constructor */ - SimpleCSRSource() = default; - /*! \brief destructor */ - ~SimpleCSRSource() override = default; - /*! \brief clear the data structure */ - void Clear(); - /*! - * \brief copy content of data from src - * \param src source data iter. - */ - void CopyFrom(DMatrix* src); - - /*! - * \brief Load data from binary stream. - * \param fi the pointer to load data from. - */ - void LoadBinary(dmlc::Stream* fi); - /*! - * \brief Save data into binary stream - * \param fo The output stream. - */ - void SaveBinary(dmlc::Stream* fo) const; - // implement Next - bool Next() override; - // implement BeforeFirst - void BeforeFirst() override; - // implement Value - const SparsePage &Value() const override; - /*! \brief magic number used to identify SimpleCSRSource */ - static const int kMagic = 0xffffab01; - - private: - /*! \brief internal variable, used to support iterator interface */ - bool at_first_{true}; -}; -} // namespace data -} // namespace xgboost -#endif // XGBOOST_DATA_SIMPLE_CSR_SOURCE_H_ diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 229f4ab58..8b3d241d6 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -8,12 +8,13 @@ #include #include "./simple_batch_iterator.h" #include "../common/random.h" +#include "../data/adapter.h" namespace xgboost { namespace data { -MetaInfo& SimpleDMatrix::Info() { return source_->info; } +MetaInfo& SimpleDMatrix::Info() { return info; } -const MetaInfo& SimpleDMatrix::Info() const { return source_->info; } +const MetaInfo& SimpleDMatrix::Info() const { return info; } float SimpleDMatrix::GetColDensity(size_t cidx) { size_t column_size = 0; @@ -32,17 +33,15 @@ float SimpleDMatrix::GetColDensity(size_t cidx) { BatchSet SimpleDMatrix::GetRowBatches() { // since csr is the default data structure so `source_` is always available. - auto cast = dynamic_cast(source_.get()); auto begin_iter = BatchIterator( - new SimpleBatchIteratorImpl(&(cast->page_))); + new SimpleBatchIteratorImpl(&sparse_page_)); return BatchSet(begin_iter); } BatchSet SimpleDMatrix::GetColumnBatches() { // column page doesn't exist, generate it if (!column_page_) { - auto const& page = dynamic_cast(source_.get())->page_; - column_page_.reset(new CSCPage(page.GetTranspose(source_->info.num_col_))); + column_page_.reset(new CSCPage(sparse_page_.GetTranspose(info.num_col_))); } auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(column_page_.get())); @@ -52,9 +51,8 @@ BatchSet SimpleDMatrix::GetColumnBatches() { BatchSet SimpleDMatrix::GetSortedColumnBatches() { // Sorted column page doesn't exist, generate it if (!sorted_column_page_) { - auto const& page = dynamic_cast(source_.get())->page_; sorted_column_page_.reset( - new SortedCSCPage(page.GetTranspose(source_->info.num_col_))); + new SortedCSCPage(sparse_page_.GetTranspose(info.num_col_))); sorted_column_page_->SortRows(); } auto begin_iter = BatchIterator( @@ -84,35 +82,33 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { int nthread_original = omp_get_max_threads(); omp_set_num_threads(nthread); - source_.reset(new SimpleCSRSource()); - SimpleCSRSource& mat = *reinterpret_cast(source_.get()); std::vector qids; uint64_t default_max = std::numeric_limits::max(); uint64_t last_group_id = default_max; bst_uint group_size = 0; - auto& offset_vec = mat.page_.offset.HostVector(); - auto& data_vec = mat.page_.data.HostVector(); + auto& offset_vec = sparse_page_.offset.HostVector(); + auto& data_vec = sparse_page_.data.HostVector(); uint64_t inferred_num_columns = 0; adapter->BeforeFirst(); // Iterate over batches of input data while (adapter->Next()) { auto& batch = adapter->Value(); - auto batch_max_columns = mat.page_.Push(batch, missing, nthread); + auto batch_max_columns = sparse_page_.Push(batch, missing, nthread); inferred_num_columns = std::max(batch_max_columns, inferred_num_columns); // Append meta information if available if (batch.Labels() != nullptr) { - auto& labels = mat.info.labels_.HostVector(); + auto& labels = info.labels_.HostVector(); labels.insert(labels.end(), batch.Labels(), batch.Labels() + batch.Size()); } if (batch.Weights() != nullptr) { - auto& weights = mat.info.weights_.HostVector(); + auto& weights = info.weights_.HostVector(); weights.insert(weights.end(), batch.Weights(), batch.Weights() + batch.Size()); } if (batch.BaseMargin() != nullptr) { - auto& base_margin = mat.info.base_margin_.HostVector(); + auto& base_margin = info.base_margin_.HostVector(); base_margin.insert(base_margin.end(), batch.BaseMargin(), batch.BaseMargin() + batch.Size()); } @@ -122,7 +118,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { for (size_t i = 0; i < batch.Size(); ++i) { const uint64_t cur_group_id = batch.Qid()[i]; if (last_group_id == default_max || last_group_id != cur_group_id) { - mat.info.group_ptr_.push_back(group_size); + info.group_ptr_.push_back(group_size); } last_group_id = cur_group_id; ++group_size; @@ -131,22 +127,22 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { } if (last_group_id != default_max) { - if (group_size > mat.info.group_ptr_.back()) { - mat.info.group_ptr_.push_back(group_size); + if (group_size > info.group_ptr_.back()) { + info.group_ptr_.push_back(group_size); } } // Deal with empty rows/columns if necessary if (adapter->NumColumns() == kAdapterUnknownSize) { - mat.info.num_col_ = inferred_num_columns; + info.num_col_ = inferred_num_columns; } else { - mat.info.num_col_ = adapter->NumColumns(); + info.num_col_ = adapter->NumColumns(); } // Synchronise worker columns - rabit::Allreduce(&mat.info.num_col_, 1); + rabit::Allreduce(&info.num_col_, 1); if (adapter->NumRows() == kAdapterUnknownSize) { - mat.info.num_row_ = offset_vec.size() - 1; + info.num_row_ = offset_vec.size() - 1; } else { if (offset_vec.empty()) { offset_vec.emplace_back(0); @@ -155,12 +151,31 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { while (offset_vec.size() - 1 < adapter->NumRows()) { offset_vec.emplace_back(offset_vec.back()); } - mat.info.num_row_ = adapter->NumRows(); + info.num_row_ = adapter->NumRows(); } - mat.info.num_nonzero_ = data_vec.size(); + info.num_nonzero_ = data_vec.size(); omp_set_num_threads(nthread_original); } +SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) { + int tmagic; + CHECK(in_stream->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) + << "invalid input file format"; + CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch"; + info.LoadBinary(in_stream); + in_stream->Read(&sparse_page_.offset.HostVector()); + in_stream->Read(&sparse_page_.data.HostVector()); +} + +void SimpleDMatrix::SaveToLocalFile(const std::string& fname) { + std::unique_ptr fo(dmlc::Stream::Create(fname.c_str(), "w")); + int tmagic = kMagic; + fo->Write(&tmagic, sizeof(tmagic)); + info.SaveBinary(fo.get()); + fo->Write(sparse_page_.offset.HostVector()); + fo->Write(sparse_page_.data.HostVector()); +} + template SimpleDMatrix::SimpleDMatrix(DenseAdapter* adapter, float missing, int nthread); template SimpleDMatrix::SimpleDMatrix(CSRAdapter* adapter, float missing, diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 1d980b45d..5771d2c83 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -8,6 +8,7 @@ #include #include "../common/random.h" #include "./simple_dmatrix.h" +#include "../common/math.h" #include "device_adapter.cuh" namespace xgboost { @@ -112,38 +113,36 @@ void CopyDataRowMajor(AdapterT* adapter, common::Span data, // be supported in future. Does not currently support inferring row/column size template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { - source_.reset(new SimpleCSRSource()); - SimpleCSRSource& mat = *reinterpret_cast(source_.get()); CHECK(adapter->NumRows() != kAdapterUnknownSize); CHECK(adapter->NumColumns() != kAdapterUnknownSize); adapter->BeforeFirst(); adapter->Next(); auto& batch = adapter->Value(); - mat.page_.offset.SetDevice(adapter->DeviceIdx()); - mat.page_.data.SetDevice(adapter->DeviceIdx()); + sparse_page_.offset.SetDevice(adapter->DeviceIdx()); + sparse_page_.data.SetDevice(adapter->DeviceIdx()); // Enforce single batch CHECK(!adapter->Next()); - mat.page_.offset.Resize(adapter->NumRows() + 1); - auto s_offset = mat.page_.offset.DeviceSpan(); + sparse_page_.offset.Resize(adapter->NumRows() + 1); + auto s_offset = sparse_page_.offset.DeviceSpan(); CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing); - mat.info.num_nonzero_ = mat.page_.offset.HostVector().back(); - mat.page_.data.Resize(mat.info.num_nonzero_); + info.num_nonzero_ = sparse_page_.offset.HostVector().back(); + sparse_page_.data.Resize(info.num_nonzero_); if (adapter->IsRowMajor()) { - CopyDataRowMajor(adapter, mat.page_.data.DeviceSpan(), + CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(), adapter->DeviceIdx(), missing, s_offset); } else { - CopyDataColumnMajor(adapter, mat.page_.data.DeviceSpan(), + CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(), adapter->DeviceIdx(), missing, s_offset); } // Sync - mat.page_.data.HostVector(); + sparse_page_.data.HostVector(); - mat.info.num_col_ = adapter->NumColumns(); - mat.info.num_row_ = adapter->NumRows(); + info.num_col_ = adapter->NumColumns(); + info.num_row_ = adapter->NumRows(); // Synchronise worker columns - rabit::Allreduce(&mat.info.num_col_, 1); + rabit::Allreduce(&info.num_col_, 1); } template SimpleDMatrix::SimpleDMatrix(CudfAdapter* adapter, float missing, diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h index 65f525f33..89ca77856 100644 --- a/src/data/simple_dmatrix.h +++ b/src/data/simple_dmatrix.h @@ -10,28 +10,22 @@ #include #include -#include #include -#include -#include -#include +#include -#include "simple_csr_source.h" -#include "../common/group_data.h" -#include "../common/math.h" -#include "adapter.h" namespace xgboost { namespace data { // Used for single batch data. class SimpleDMatrix : public DMatrix { public: - explicit SimpleDMatrix(std::unique_ptr>&& source) - : source_(std::move(source)) {} - template explicit SimpleDMatrix(AdapterT* adapter, float missing, int nthread); + explicit SimpleDMatrix(dmlc::Stream* in_stream); + + void SaveToLocalFile(const std::string& fname); + MetaInfo& Info() override; const MetaInfo& Info() const override; @@ -40,15 +34,17 @@ class SimpleDMatrix : public DMatrix { bool SingleColBlock() const override; + /*! \brief magic number used to identify SimpleDMatrix binary files */ + static const int kMagic = 0xffffab01; + private: BatchSet GetRowBatches() override; BatchSet GetColumnBatches() override; BatchSet GetSortedColumnBatches() override; BatchSet GetEllpackBatches(const BatchParam& param) override; - // source data pointer. - std::unique_ptr> source_; - + MetaInfo info; + SparsePage sparse_page_; // Primary storage type std::unique_ptr column_page_; std::unique_ptr sorted_column_page_; std::unique_ptr ellpack_page_; diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 2d8306d53..33729e32a 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -240,7 +240,7 @@ TEST(hist_util, DenseCutsCategorical) { auto dmat = GetDMatrixFromData(x, n, 1); HistogramCuts cuts; DenseCuts dense(&cuts); - dense.Build(&dmat, num_bins); + dense.Build(dmat.get(), num_bins); auto cuts_from_sketch = cuts.Values(); EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); @@ -260,7 +260,7 @@ TEST(hist_util, DenseCutsAccuracyTest) { for (auto num_bins : bin_sizes) { HistogramCuts cuts; DenseCuts dense(&cuts); - dense.Build(&dmat, num_bins); + dense.Build(dmat.get(), num_bins); ValidateCuts(cuts, x, num_rows, num_columns, num_bins); } } @@ -294,7 +294,7 @@ TEST(hist_util, SparseCutsAccuracyTest) { for (auto num_bins : bin_sizes) { HistogramCuts cuts; SparseCuts sparse(&cuts); - sparse.Build(&dmat, num_bins); + sparse.Build(dmat.get(), num_bins); ValidateCuts(cuts, x, num_rows, num_columns, num_bins); } } @@ -312,7 +312,7 @@ TEST(hist_util, SparseCutsCategorical) { auto dmat = GetDMatrixFromData(x, n, 1); HistogramCuts cuts; SparseCuts sparse(&cuts); - sparse.Build(&dmat, num_bins); + sparse.Build(dmat.get(), num_bins); auto cuts_from_sketch = cuts.Values(); EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 416b4d823..663fbc4cf 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -7,6 +7,7 @@ #include #include "../../../src/common/hist_util.h" #include "../../../src/data/simple_dmatrix.h" +#include "../../../src/data/adapter.h" // Some helper functions used to test both GPU and CPU algorithms // @@ -40,10 +41,11 @@ inline std::vector GenerateRandomCategoricalSingleColumn(int n, return x; } -inline data::SimpleDMatrix GetDMatrixFromData(const std::vector& x, int num_rows, int num_columns) { +inline std::shared_ptr GetDMatrixFromData(const std::vector& x, int num_rows, int num_columns) { data::DenseAdapter adapter(x.data(), num_rows, num_columns); - return data::SimpleDMatrix(&adapter, std::numeric_limits::quiet_NaN(), - 1); + return std::shared_ptr(new data::SimpleDMatrix( + &adapter, std::numeric_limits::quiet_NaN(), + 1)); } inline std::shared_ptr GetExternalMemoryDMatrixFromData( diff --git a/tests/cpp/data/test_array_interface.h b/tests/cpp/data/test_array_interface.h index 823cd0e2a..687a437d8 100644 --- a/tests/cpp/data/test_array_interface.h +++ b/tests/cpp/data/test_array_interface.h @@ -7,7 +7,6 @@ #include #include "../../../src/common/bitfield.h" #include "../../../src/common/device_helpers.cuh" -#include "../../../src/data/simple_csr_source.h" namespace xgboost { diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 380c8b142..ec4f0ca33 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -4,7 +4,6 @@ #include #include #include -#include "../../../src/data/simple_csr_source.h" #include "../../../src/common/version.h" #include "../helpers.h" diff --git a/tests/cpp/data/test_simple_csr_source.cc b/tests/cpp/data/test_simple_csr_source.cc deleted file mode 100644 index c3afd9e52..000000000 --- a/tests/cpp/data/test_simple_csr_source.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright by Contributors -#include -#include - -#include -#include -#include "../../../src/data/simple_csr_source.h" - -#include "../helpers.h" - -namespace xgboost { - -TEST(SimpleCSRSource, SaveLoadBinary) { - dmlc::TemporaryDirectory tempdir; - const std::string tmp_file = tempdir.path + "/simple.libsvm"; - CreateSimpleTestData(tmp_file); - xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false); - - const std::string tmp_binfile = tempdir.path + "/csr_source.binary"; - dmat->SaveToLocalFile(tmp_binfile); - xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, false); - - EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_); - EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); - EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); - - // Test we have non-empty batch - EXPECT_EQ(dmat->GetBatches().begin().AtEnd(), false); - - auto row_iter = dmat->GetBatches().begin(); - auto row_iter_read = dmat_read->GetBatches().begin(); - // Test the data read into the first row - auto first_row = (*row_iter)[0]; - auto first_row_read = (*row_iter_read)[0]; - EXPECT_EQ(first_row.size(), first_row_read.size()); - EXPECT_EQ(first_row[2].index, first_row_read[2].index); - EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue); - delete dmat; - delete dmat_read; -} -} // namespace xgboost diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 0a66af5a0..79f505684 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -254,3 +254,33 @@ TEST(SimpleDMatrix, Slice) { delete pp_dmat; }; + +TEST(SimpleDMatrix, SaveLoadBinary) { + dmlc::TemporaryDirectory tempdir; + const std::string tmp_file = tempdir.path + "/simple.libsvm"; + CreateSimpleTestData(tmp_file); + xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false); + data::SimpleDMatrix *simple_dmat = dynamic_cast(dmat); + + const std::string tmp_binfile = tempdir.path + "/csr_source.binary"; + simple_dmat->SaveToLocalFile(tmp_binfile); + xgboost::DMatrix * dmat_read = xgboost::DMatrix::Load(tmp_binfile, true, false); + + EXPECT_EQ(dmat->Info().num_col_, dmat_read->Info().num_col_); + EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); + EXPECT_EQ(dmat->Info().num_row_, dmat_read->Info().num_row_); + + // Test we have non-empty batch + EXPECT_EQ(dmat->GetBatches().begin().AtEnd(), false); + + auto row_iter = dmat->GetBatches().begin(); + auto row_iter_read = dmat_read->GetBatches().begin(); + // Test the data read into the first row + auto first_row = (*row_iter)[0]; + auto first_row_read = (*row_iter_read)[0]; + EXPECT_EQ(first_row.size(), first_row_read.size()); + EXPECT_EQ(first_row[2].index, first_row_read[2].index); + EXPECT_EQ(first_row[2].fvalue, first_row_read[2].fvalue); + delete dmat; + delete dmat_read; +} diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index f071d1eab..bca61c7b3 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -17,7 +17,6 @@ #include "helpers.h" #include "xgboost/c_api.h" -#include "../../src/data/simple_csr_source.h" #include "../../src/gbm/gbtree_model.h" #include "xgboost/predictor.h" @@ -256,17 +255,13 @@ std::unique_ptr CreateSparsePageDMatrixWithRC( } fo.close(); - std::unique_ptr dmat(DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false, "auto", page_size)); - EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); - - if (!page_size) { - std::unique_ptr source(new data::SimpleCSRSource); - source->CopyFrom(dmat.get()); - return std::unique_ptr(DMatrix::Create(std::move(source))); - } else { - return dmat; + std::string uri = tmp_file; + if (page_size > 0) { + uri += "#" + tmp_file + ".cache"; } + std::unique_ptr dmat( + DMatrix::Load(uri, true, false, "auto", page_size)); + return dmat; } gbm::GBTreeModel CreateTestModel(LearnerModelParam const* param, size_t n_classes) {