diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 81565d194..baf600092 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -433,12 +433,14 @@ class DMatrix { * \param load_row_split Flag to read in part of rows, divided among the workers in distributed mode. * \param file_format The format type of the file, used for dmlc::Parser::Create. * By default "auto" will be able to load in both local binary file. + * \param page_size Page size for external memory. * \return The created DMatrix. */ static DMatrix* Load(const std::string& uri, bool silent, bool load_row_split, - const std::string& file_format = "auto"); + const std::string& file_format = "auto", + const size_t page_size = kPageSize); /*! * \brief create a new DMatrix, by wrapping a row_iterator, and meta info. * \param source The source iterator of the data, the create function takes ownership of the source. @@ -454,6 +456,7 @@ class DMatrix { * \param parser The input data parser * \param cache_prefix The path to prefix of temporary cache file of the DMatrix when used in external memory mode. * This can be nullptr for common cases, and in-memory mode will be used. + * \param page_size Page size for external memory. * \sa dmlc::Parser * \note dmlc-core provides efficient distributed data parser for libsvm format. * User can create and register customized parser to load their own format using DMLC_REGISTER_DATA_PARSER. @@ -461,7 +464,11 @@ class DMatrix { * \return A created DMatrix. */ static DMatrix* Create(dmlc::Parser* parser, - const std::string& cache_prefix = ""); + const std::string& cache_prefix = "", + const size_t page_size = kPageSize); + + /*! \brief page size 32 MB */ + static const size_t kPageSize = 32UL << 20UL; }; // implementation of inline functions diff --git a/src/data/data.cc b/src/data/data.cc index 99d249b4b..a3b0c250e 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -150,7 +150,8 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t DMatrix* DMatrix::Load(const std::string& uri, bool silent, bool load_row_split, - const std::string& file_format) { + const std::string& file_format, + const size_t page_size) { std::string fname, cache_file; size_t dlm_pos = uri.find('#'); if (dlm_pos != std::string::npos) { @@ -217,7 +218,7 @@ DMatrix* DMatrix::Load(const std::string& uri, std::unique_ptr > parser( dmlc::Parser::Create(fname.c_str(), partid, npart, file_format.c_str())); - DMatrix* dmat = DMatrix::Create(parser.get(), cache_file); + DMatrix* dmat = DMatrix::Create(parser.get(), cache_file, page_size); if (!silent) { LOG(CONSOLE) << dmat->Info().num_row_ << 'x' << dmat->Info().num_col_ << " matrix with " << dmat->Info().num_nonzero_ << " entries loaded from " << uri; @@ -248,7 +249,8 @@ DMatrix* DMatrix::Load(const std::string& uri, } DMatrix* DMatrix::Create(dmlc::Parser* parser, - const std::string& cache_prefix) { + const std::string& cache_prefix, + const size_t page_size) { if (cache_prefix.length() == 0) { std::unique_ptr source(new data::SimpleCSRSource()); source->CopyFrom(parser); @@ -256,7 +258,7 @@ DMatrix* DMatrix::Create(dmlc::Parser* parser, } else { #if DMLC_ENABLE_STD_THREAD if (!data::SparsePageSource::CacheExist(cache_prefix, ".row.page")) { - data::SparsePageSource::CreateRowPage(parser, cache_prefix); + data::SparsePageSource::CreateRowPage(parser, cache_prefix, page_size); } std::unique_ptr source( new data::SparsePageSource(cache_prefix, ".row.page")); diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h index 8b4352b09..1e0879609 100644 --- a/src/data/sparse_page_dmatrix.h +++ b/src/data/sparse_page_dmatrix.h @@ -40,9 +40,6 @@ class SparsePageDMatrix : public DMatrix { bool SingleColBlock() const override; private: - /*! \brief page size 256 MB */ - static const size_t kPageSize = 256UL << 20UL; - // source data pointers. std::unique_ptr row_source_; std::unique_ptr column_source_; diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc index 0bec51997..ee7b13a4f 100644 --- a/src/data/sparse_page_source.cc +++ b/src/data/sparse_page_source.cc @@ -126,7 +126,8 @@ bool SparsePageSource::CacheExist(const std::string& cache_info, } void SparsePageSource::CreateRowPage(dmlc::Parser* src, - const std::string& cache_info) { + const std::string& cache_info, + const size_t page_size) { const std::string page_type = ".row.page"; std::vector cache_shards = GetCacheShards(cache_info); CHECK_NE(cache_shards.size(), 0U); @@ -183,7 +184,7 @@ void SparsePageSource::CreateRowPage(dmlc::Parser* src, static_cast(index + 1)); } page->Push(batch); - if (page->MemCostBytes() >= kPageSize) { + if (page->MemCostBytes() >= page_size) { bytes_write += page->MemCostBytes(); writer.PushWrite(std::move(page)); writer.Alloc(&page); @@ -222,7 +223,8 @@ void SparsePageSource::CreateRowPage(dmlc::Parser* src, void SparsePageSource::CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info, - const std::string& page_type) { + const std::string& page_type, + const size_t page_size) { std::vector cache_shards = GetCacheShards(cache_info); CHECK_NE(cache_shards.size(), 0U); // read in the info files. @@ -254,7 +256,7 @@ void SparsePageSource::CreatePageFromDMatrix(DMatrix* src, LOG(FATAL) << "Unknown page type: " << page_type; } - if (page->MemCostBytes() >= kPageSize) { + if (page->MemCostBytes() >= page_size) { bytes_write += page->MemCostBytes(); writer.PushWrite(std::move(page)); writer.Alloc(&page); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 0be2bd3fc..fbec44498 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -48,9 +48,11 @@ class SparsePageSource : public DataSource { * \brief Create source by taking data from parser. * \param src source parser. * \param cache_info The cache_info of cache file location. + * \param page_size Page size for external memory. */ static void CreateRowPage(dmlc::Parser* src, - const std::string& cache_info); + const std::string& cache_info, + const size_t page_size = DMatrix::kPageSize); /*! * \brief Create source cache by copy content from DMatrix. * \param cache_info The cache_info of cache file location. @@ -73,14 +75,13 @@ class SparsePageSource : public DataSource { */ static bool CacheExist(const std::string& cache_info, const std::string& page_type); - /*! \brief page size 32 MB */ - static const size_t kPageSize = 32UL << 20UL; /*! \brief magic number used to identify Page */ static const int kMagic = 0xffffab02; private: static void CreatePageFromDMatrix(DMatrix* src, const std::string& cache_info, - const std::string& page_type); + const std::string& page_type, + const size_t page_size = DMatrix::kPageSize); /*! \brief number of rows */ size_t base_rowid_; /*! \brief page currently on hold. */ diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 8d478e487..a99be4c63 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -29,16 +29,19 @@ TEST(SparsePageDMatrix, RowAccess) { // Create sufficiently large data to make two row pages dmlc::TemporaryDirectory tempdir; const std::string tmp_file = tempdir.path + "/big.libsvm"; - CreateBigTestData(tmp_file, 5000000); + CreateBigTestData(tmp_file, 12); xgboost::DMatrix * dmat = xgboost::DMatrix::Load( - tmp_file + "#" + tmp_file + ".cache", true, false); + tmp_file + "#" + tmp_file + ".cache", true, false, "auto", 64UL); EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); // Loop over the batches and count the records + int64_t batch_count = 0; int64_t row_count = 0; - for (auto &batch : dmat->GetRowBatches()) { + for (const auto &batch : dmat->GetRowBatches()) { + batch_count++; row_count += batch.Size(); } + EXPECT_EQ(batch_count, 2); EXPECT_EQ(row_count, dmat->Info().num_row_); // Test the data read into the first row diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index 6fc844113..8f2712fdc 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -1,4 +1,5 @@ // Copyright by Contributors +#include #include #include #include "../helpers.h" @@ -59,4 +60,66 @@ TEST(cpu_predictor, Test) { delete dmat; } + +TEST(cpu_predictor, ExternalMemoryTest) { + // Create sufficiently large data to make two row pages + dmlc::TemporaryDirectory tempdir; + const std::string tmp_file = tempdir.path + "/big.libsvm"; + CreateBigTestData(tmp_file, 12); + xgboost::DMatrix *dmat = xgboost::DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", true, false, "auto", 64UL); + EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); + int64_t batche_count = 0; + for (const auto &batch : dmat->GetRowBatches()) { + batche_count++; + } + EXPECT_EQ(batche_count, 2); + + std::unique_ptr cpu_predictor = + std::unique_ptr(Predictor::Create("cpu_predictor")); + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + (*trees.back())[0].SetLeaf(1.5f); + (*trees.back()).Stat(0).sum_hess = 1.0f; + gbm::GBTreeModel model(0.5); + model.CommitModel(std::move(trees), 0); + model.param.num_output_group = 1; + model.base_margin = 0; + + // Test predict batch + HostDeviceVector out_predictions; + cpu_predictor->PredictBatch(dmat, &out_predictions, model, 0); + std::vector &out_predictions_h = out_predictions.HostVector(); + EXPECT_EQ(out_predictions.Size(), dmat->Info().num_row_); + for (const auto& v : out_predictions_h) { + ASSERT_EQ(v, 1.5); + } + + // Test predict leaf + std::vector leaf_out_predictions; + cpu_predictor->PredictLeaf(dmat, &leaf_out_predictions, model); + EXPECT_EQ(leaf_out_predictions.size(), dmat->Info().num_row_); + for (const auto& v : leaf_out_predictions) { + ASSERT_EQ(v, 0); + } + + // Test predict contribution + std::vector out_contribution; + cpu_predictor->PredictContribution(dmat, &out_contribution, model); + EXPECT_EQ(out_contribution.size(), dmat->Info().num_row_); + for (const auto& v : out_contribution) { + ASSERT_EQ(v, 1.5); + } + + // Test predict contribution (approximate method) + std::vector out_contribution_approximate; + cpu_predictor->PredictContribution(dmat, &out_contribution_approximate, model, true); + EXPECT_EQ(out_contribution_approximate.size(), dmat->Info().num_row_); + for (const auto& v : out_contribution_approximate) { + ASSERT_EQ(v, 1.5); + } + + delete dmat; +} } // namespace xgboost