diff --git a/Jenkinsfile-win64 b/Jenkinsfile-win64
index 573a900a7..b15683342 100644
--- a/Jenkinsfile-win64
+++ b/Jenkinsfile-win64
@@ -85,7 +85,7 @@ def BuildWin64() {
   bat """
   mkdir build
   cd build
-  cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag}
+  cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag} -DCMAKE_UNITY_BUILD=ON
   """
   bat """
   cd build
diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index 37e4168d1..921e12bfe 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -44,6 +44,7 @@
 #if DMLC_ENABLE_STD_THREAD
 #include "../src/data/sparse_page_dmatrix.cc"
+#include "../src/data/sparse_page_source.cc"
 #endif

 // trees
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 449c785b4..7226c7b82 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -549,8 +549,9 @@ class DMatrix {
                    int max_bin);
   virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
-  /*! \brief page size 32 MB */
-  static const size_t kPageSize = 32UL << 20UL;
+  /*! \brief Number of rows per page in external memory.  Approximately 100MB per page
+   *  for a dataset with 100 features. */
+  static const size_t kPageSize = 32UL << 12UL;

 protected:
  virtual BatchSet<SparsePage> GetRowBatches() = 0;
diff --git a/src/data/data.cc b/src/data/data.cc
index cd64f10d8..2d56a6f29 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -830,9 +830,10 @@ void SparsePage::Push(const SparsePage &batch) {
   const auto& batch_data_vec = batch.data.HostVector();
   size_t top = offset_vec.back();
   data_vec.resize(top + batch.data.Size());
-  std::memcpy(dmlc::BeginPtr(data_vec) + top,
-              dmlc::BeginPtr(batch_data_vec),
-              sizeof(Entry) * batch.data.Size());
+  if (dmlc::BeginPtr(data_vec) && dmlc::BeginPtr(batch_data_vec)) {
+    std::memcpy(dmlc::BeginPtr(data_vec) + top, dmlc::BeginPtr(batch_data_vec),
+                sizeof(Entry) * batch.data.Size());
+  }
   size_t begin = offset.Size();
   offset_vec.resize(begin + batch.Size());
   for (size_t i = 0; i < batch.Size(); ++i) {
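A note on the `SparsePage::Push` hunk above: `std::memcpy` requires both pointers to be valid even when the byte count is zero, so copying from or into an empty vector (where `dmlc::BeginPtr` returns null) is undefined behavior. With the new row-based page size, empty batches become more common. A minimal standalone sketch of the guarded pattern, using plain `std::vector<int>` rather than XGBoost's types:

```cpp
#include <cstring>
#include <vector>

// memcpy(dst, src, 0) with dst == nullptr or src == nullptr is undefined
// behavior: the standard requires valid pointers regardless of the size.
// UBSan flags exactly this when an empty batch is pushed.
void AppendEntries(std::vector<int> *dst, const std::vector<int> &src) {
  size_t top = dst->size();
  dst->resize(top + src.size());
  // Guard: copy only when both ranges are non-empty, so data() cannot
  // return nullptr for an empty vector.
  if (!dst->empty() && !src.empty()) {
    std::memcpy(dst->data() + top, src.data(), sizeof(int) * src.size());
  }
}
```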
diff --git a/src/data/sparse_page_source.cc b/src/data/sparse_page_source.cc
new file mode 100644
index 000000000..18376a18e
--- /dev/null
+++ b/src/data/sparse_page_source.cc
@@ -0,0 +1,77 @@
+/*!
+ * Copyright (c) 2020 by XGBoost Contributors
+ */
+#include "sparse_page_source.h"
+
+namespace xgboost {
+namespace data {
+void DataPool::Slice(std::shared_ptr<SparsePage> out, size_t offset,
+                     size_t n_rows, size_t entry_offset) const {
+  auto const &in_offset = pool_.offset.HostVector();
+  auto const &in_data = pool_.data.HostVector();
+  auto &h_offset = out->offset.HostVector();
+  CHECK_LE(offset + n_rows + 1, in_offset.size());
+  h_offset.resize(n_rows + 1, 0);
+  std::transform(in_offset.cbegin() + offset,
+                 in_offset.cbegin() + offset + n_rows + 1, h_offset.begin(),
+                 [=](size_t ptr) { return ptr - entry_offset; });
+
+  auto &h_data = out->data.HostVector();
+  CHECK_GT(h_offset.size(), 0);
+  size_t n_entries = h_offset.back();
+  h_data.resize(n_entries);
+
+  CHECK_EQ(n_entries, in_offset.at(offset + n_rows) - in_offset.at(offset));
+  std::copy_n(in_data.cbegin() + in_offset.at(offset), n_entries,
+              h_data.begin());
+}
+
+void DataPool::SplitWritePage() {
+  size_t total = pool_.Size();
+  size_t offset = 0;
+  size_t entry_offset = 0;
+  do {
+    size_t n_rows = std::min(page_size_, total - offset);
+    std::shared_ptr<SparsePage> out;
+    writer_->Alloc(&out);
+    out->Clear();
+    out->SetBaseRowId(inferred_num_rows_);
+    this->Slice(out, offset, n_rows, entry_offset);
+    inferred_num_rows_ += out->Size();
+    offset += n_rows;
+    entry_offset += out->data.Size();
+    CHECK_NE(out->Size(), 0);
+    writer_->PushWrite(std::move(out));
+  } while (total - offset >= page_size_);
+
+  if (total - offset != 0) {
+    auto out = std::make_shared<SparsePage>();
+    this->Slice(out, offset, total - offset, entry_offset);
+    CHECK_NE(out->Size(), 0);
+    pool_.Clear();
+    pool_.Push(*out);
+  } else {
+    pool_.Clear();
+  }
+}
+
+size_t DataPool::Finalize() {
+  inferred_num_rows_ += pool_.Size();
+  if (pool_.Size() != 0) {
+    std::shared_ptr<SparsePage> page;
+    this->writer_->Alloc(&page);
+    page->Clear();
+    page->Push(pool_);
+    this->writer_->PushWrite(std::move(page));
+  }
+
+  if (inferred_num_rows_ == 0) {
+    std::shared_ptr<SparsePage> page;
+    this->writer_->Alloc(&page);
+    page->Clear();
+    this->writer_->PushWrite(std::move(page));
+  }
+
+  return inferred_num_rows_;
+}
+}  // namespace data
+}  // namespace xgboost
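The offset rebasing in `DataPool::Slice` is the subtle part: CSR row pointers in the pool are absolute, so each written page must subtract the number of entries already emitted. A standalone sketch of the same idea on plain vectors (names here are illustrative, not XGBoost's):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// A CSR slice: n_rows + 1 rebased row pointers plus the entries they index.
struct CsrSlice {
  std::vector<std::size_t> offset;
  std::vector<float> data;
};

// Slice rows [row_begin, row_begin + n_rows) out of a CSR matrix, rebasing
// the row pointers so the slice starts at entry 0.
CsrSlice SliceCsr(const std::vector<std::size_t> &in_offset,
                  const std::vector<float> &in_data,
                  std::size_t row_begin, std::size_t n_rows) {
  CsrSlice out;
  std::size_t entry_offset = in_offset[row_begin];
  out.offset.resize(n_rows + 1);
  // Rebase: shift every pointer by the number of entries before this slice.
  std::transform(in_offset.begin() + row_begin,
                 in_offset.begin() + row_begin + n_rows + 1,
                 out.offset.begin(),
                 [=](std::size_t ptr) { return ptr - entry_offset; });
  out.data.assign(in_data.begin() + entry_offset,
                  in_data.begin() + in_offset[row_begin + n_rows]);
  assert(out.data.size() == out.offset.back());
  return out;
}
```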
diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h
index 6db6de9fa..d36c6b07e 100644
--- a/src/data/sparse_page_source.h
+++ b/src/data/sparse_page_source.h
@@ -3,6 +3,37 @@
  * \file page_csr_source.h
  * External memory data source, saved with sparse_batch_page binary format.
  * \author Tianqi Chen
+ *
+ * -------------------------------------------------
+ * Random notes on implementation of external memory
+ * -------------------------------------------------
+ *
+ * As of XGBoost 1.3, the general pipeline is:
+ *
+ * dmlc text file parser --> file adapter --> sparse page source --> data pool -->
+ * write to binary cache --> load it back ~~> [ other pages (csc, ellpack, sorted csc) -->
+ * write to binary cache ] --> use it in various algorithms.
+ *
+ * ~~> means optional
+ *
+ * The dmlc text file parser returns a number of blocks that depends on the available
+ * threads, which can make the data partitioning non-deterministic, so here we set up an
+ * extra data pool to stage parsed data.  As a result, the number of blocks returned by
+ * the text parser does not equal the number of blocks in the binary cache.
+ *
+ * Loading the binary cache is made asynchronous by the dmlc threaded iterator, which
+ * helps performance, but the iterator itself is not thread safe, so calling
+ * `dmatrix->GetBatches<SparsePage>()` is not thread safe either.  Note that the
+ * threaded iterator is also used inside the dmlc text file parser.
+ *
+ * Memory consumption is difficult to control for several reasons.  First, the text
+ * parser doesn't expose a batch size; only a hard-coded buffer size is available.
+ * Second, everything is loaded and written through asynchronous queues, and with
+ * multiple queues running, memory consumption is difficult to measure.
+ *
+ * The threaded iterator relies heavily on the C++ memory model and threading
+ * primitives.  The concurrent writer for the binary cache is an old copy of the
+ * moodycamel queue.  We should try to replace them with something more robust.
  */
 #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
 #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@@ -19,6 +50,7 @@
 #include
 #include

+#include "rabit/rabit.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
@@ -121,9 +153,12 @@ inline void TryDeleteCacheFile(const std::string& file) {
 inline void CheckCacheFileExists(const std::string& file) {
   std::ifstream f(file.c_str());
   if (f.good()) {
-    LOG(FATAL) << "Cache file " << file
-               << " exists already; Is there another DMatrix with the same "
-                  "cache prefix? Otherwise please remove it manually.";
+    LOG(FATAL)
+        << "Cache file " << file << " exists already; "
+        << "Is there another DMatrix with the same "
+           "cache prefix?  This can be caused by a previously used DMatrix that "
+           "hasn't yet been collected by the language environment's garbage "
+           "collector.  Otherwise please remove it manually.";
   }
 }
@@ -231,6 +266,38 @@ class ExternalMemoryPrefetcher : dmlc::DataIter {
   std::vector>> prefetchers_;
 };
+
+// A data pool that keeps the size of each page balanced and makes the data
+// partitioning deterministic.
+class DataPool {
+  size_t inferred_num_rows_;
+  MetaInfo* info_;
+  SparsePage pool_;
+  size_t page_size_;
+  SparsePageWriter<SparsePage> *writer_;
+
+  void Slice(std::shared_ptr<SparsePage> out, size_t offset, size_t n_rows,
+             size_t entry_offset) const;
+  void SplitWritePage();
+
+ public:
+  DataPool(MetaInfo *info, size_t page_size,
+           SparsePageWriter<SparsePage> *writer)
+      : inferred_num_rows_{0}, info_{info},
+        page_size_{page_size}, writer_{writer} {}
+
+  void Push(std::shared_ptr<SparsePage> page) {
+    info_->num_nonzero_ += page->data.Size();
+    pool_.Push(*page);
+    if (pool_.Size() > page_size_) {
+      this->SplitWritePage();
+    }
+    page->Clear();
+  }
+
+  size_t Finalize();
+};
+
 class SparsePageSource {
  public:
   template <typename AdapterT>
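To make the role of `DataPool` concrete, here is a rough sketch of how a producer drives it; the control flow mirrors the row-page creation below, but the types are simplified stand-ins rather than XGBoost's real classes:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Simplified stand-in: a "page" is just a vector of row payloads here.
struct Page {
  std::vector<std::size_t> rows;
  std::size_t Size() const { return rows.size(); }
};

// Accumulates parser output and emits fixed-size pages, so the on-disk
// partitioning no longer depends on how many blocks the parser produced
// (which varies with thread count).
class Pool {
  Page staging_;
  std::size_t page_size_;
  std::vector<Page> written_;  // stands in for the async page writer

 public:
  explicit Pool(std::size_t page_size) : page_size_{page_size} {}

  void Push(Page batch) {
    staging_.rows.insert(staging_.rows.end(), batch.rows.begin(),
                         batch.rows.end());
    // Emit full pages; keep the remainder staged for the next batch.
    while (staging_.Size() > page_size_) {
      Page out;
      out.rows.assign(staging_.rows.begin(),
                      staging_.rows.begin() + page_size_);
      staging_.rows.erase(staging_.rows.begin(),
                          staging_.rows.begin() + page_size_);
      written_.push_back(std::move(out));
    }
  }

  std::size_t Finalize() {
    if (staging_.Size() != 0) { written_.push_back(std::move(staging_)); }
    std::size_t n = 0;
    for (auto const &p : written_) { n += p.Size(); }
    return n;  // the inferred number of rows
  }
};
```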
@@ -249,17 +316,12 @@
     {
       SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
                                           cache_info_.format_shards, 6);
-      std::shared_ptr<SparsePage> page;
-      writer.Alloc(&page);
-      page->Clear();
+      DataPool pool(&info, page_size, &writer);
+
+      std::shared_ptr<SparsePage> page { new SparsePage };

       uint64_t inferred_num_columns = 0;
       uint64_t inferred_num_rows = 0;
-      size_t bytes_write = 0;
-      double tstart = dmlc::GetTime();
-      // print every 4 sec.
-      constexpr double kStep = 4.0;
-      size_t tick_expected = static_cast<size_t>(kStep);

       const uint64_t default_max = std::numeric_limits<uint64_t>::max();
       uint64_t last_group_id = default_max;
@@ -296,26 +358,13 @@
             ++group_size;
           }
         }
+        CHECK_EQ(page->Size(), 0);
         auto batch_max_columns = page->Push(batch, missing, nthread);
         inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
-        if (page->MemCostBytes() >= page_size) {
-          inferred_num_rows += page->Size();
-          info.num_nonzero_ += page->offset.HostVector().back();
-          bytes_write += page->MemCostBytes();
-          writer.PushWrite(std::move(page));
-          writer.Alloc(&page);
-          page->Clear();
-          page->SetBaseRowId(inferred_num_rows);
-
-          double tdiff = dmlc::GetTime() - tstart;
-          if (tdiff >= tick_expected) {
-            LOG(CONSOLE) << "Writing " << page_type << " to " << cache_info
-                         << " in " << ((bytes_write >> 20UL) / tdiff)
-                         << " MB/s, " << (bytes_write >> 20UL) << " written";
-            tick_expected += static_cast<size_t>(kStep);
-          }
-        }
+        inferred_num_rows += page->Size();
+        pool.Push(page);
+        page->SetBaseRowId(inferred_num_rows);
       }

       if (last_group_id != default_max) {
@@ -323,10 +372,6 @@
           info.group_ptr_.push_back(group_size);
         }
       }
-      inferred_num_rows += page->Size();
-      if (!page->offset.HostVector().empty()) {
-        info.num_nonzero_ += page->offset.HostVector().back();
-      }

       // Deal with empty rows/columns if necessary
       if (adapter->NumColumns() == kAdapterUnknownSize) {
@@ -352,10 +397,9 @@
         info.num_row_ = adapter->NumRows();
       }

-      // Make sure we have at least one page if the dataset is empty
-      if (page->data.Size() > 0 || info.num_row_ == 0) {
-        writer.PushWrite(std::move(page));
-      }
+      pool.Push(page);
+      pool.Finalize();
+
       std::unique_ptr<dmlc::Stream> fo(
           dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
       int tmagic = kMagic;
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 93ec8becd..75530fc53 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -130,8 +130,10 @@ TEST(DenseColumnWithMissing, Test) {
 void TestGHistIndexMatrixCreation(size_t nthreads) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
+  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   /* This should create multiple sparse pages */
-  std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(1024, 1024, filename) };
+  std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
   omp_set_num_threads(nthreads);
   GHistIndexMatrix gmat;
   gmat.Init(dmat.get(), 256);
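Several tests now derive the entry count from the page size so that at least two pages are guaranteed. With `kPageSize` rows per page and three entries per generated row (the tests' `kEntriesPerCol`), `kEntries = kPageSize * kEntriesPerCol * 2` yields exactly two full pages. A compile-time check of that arithmetic (standalone sketch, not part of the patch):

```cpp
#include <cstddef>

// kEntries entries at kEntriesPerCol entries per row give
// kEntries / kEntriesPerCol rows; dividing by the page size (in rows)
// counts the pages the DMatrix should produce.
constexpr std::size_t kPageSize = 1024, kEntriesPerCol = 3;
constexpr std::size_t kEntries = kPageSize * kEntriesPerCol * 2;
constexpr std::size_t kRows = kEntries / kEntriesPerCol;
static_assert(kRows / kPageSize == 2, "expect exactly two full pages");
```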
diff --git a/tests/cpp/data/test_data.cc b/tests/cpp/data/test_data.cc
index d01da568c..c63c4b1d7 100644
--- a/tests/cpp/data/test_data.cc
+++ b/tests/cpp/data/test_data.cc
@@ -44,13 +44,13 @@ TEST(SparsePage, PushCSC) {
   }

   auto inst = page[0];
-  ASSERT_EQ(inst.size(), 2);
+  ASSERT_EQ(inst.size(), 2ul);
   for (auto entry : inst) {
-    ASSERT_EQ(entry.index, 0);
+    ASSERT_EQ(entry.index, 0u);
   }

   inst = page[1];
-  ASSERT_EQ(inst.size(), 6);
+  ASSERT_EQ(inst.size(), 6ul);
   std::vector<size_t> indices_sol {1, 2, 3};
   for (size_t i = 0; i < inst.size(); ++i) {
     ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
   }
 }

 TEST(SparsePage, PushCSCAfterTranspose) {
@@ -58,15 +58,12 @@
 }

 TEST(SparsePage, PushCSCAfterTranspose) {
-#if defined(__APPLE__)
-  LOG(WARNING) << "FIXME(trivialfis): Skipping `PushCSCAfterTranspose' for APPLE.";
-  return;
-#endif
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
-  const int n_entries = 9;
+  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   std::unique_ptr<DMatrix> dmat =
-      CreateSparsePageDMatrix(n_entries, 64UL, filename);
+      CreateSparsePageDMatrix(kEntries, 64UL, filename);
   const int ncols = dmat->Info().num_col_;
   SparsePage page;  // Consolidated sparse page
   for (const auto &batch : dmat->GetBatches<SparsePage>()) {
@@ -76,7 +73,7 @@
   }

   // Make sure that the final sparse page has the right number of entries
-  ASSERT_EQ(n_entries, page.data.Size());
+  ASSERT_EQ(kEntries, page.data.Size());

   // The feature value for a feature in each row should be identical, as that is
   // how the dmatrix has been created
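The `2ul`/`0u`/`6ul` suffixes above exist presumably to avoid signed/unsigned comparison warnings: `ASSERT_EQ` deduces both argument types, and `inst.size()` is unsigned. A minimal illustration of the pattern (assumes gtest is available and linked with `gtest_main`):

```cpp
#include <gtest/gtest.h>
#include <vector>

TEST(Example, UnsignedLiterals) {
  std::vector<int> v{1, 2};
  // v.size() is std::size_t; comparing it against a plain int literal can
  // trigger -Wsign-compare, so the expected value is written as unsigned.
  ASSERT_EQ(v.size(), 2ul);
}
```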
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc
index 5c719f78a..d75393438 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cc
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cc
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include "../../../src/common/io.h"
 #include "../../../src/data/adapter.h"
 #include "../../../src/data/sparse_page_dmatrix.h"
 #include "../helpers.h"
@@ -11,16 +12,18 @@ using namespace xgboost;  // NOLINT
 TEST(SparsePageDMatrix, MetaInfo) {
   dmlc::TemporaryDirectory tempdir;
   const std::string tmp_file = tempdir.path + "/simple.libsvm";
-  CreateSimpleTestData(tmp_file);
+  size_t constexpr kEntries = 24;
+  CreateBigTestData(tmp_file, kEntries);
+
   xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
       tmp_file + "#" + tmp_file + ".cache", false, false);
   std::cout << tmp_file << std::endl;
   EXPECT_TRUE(FileExists(tmp_file + ".cache"));

   // Test the metadata that was parsed
-  EXPECT_EQ(dmat->Info().num_row_, 2);
-  EXPECT_EQ(dmat->Info().num_col_, 5);
-  EXPECT_EQ(dmat->Info().num_nonzero_, 6);
+  EXPECT_EQ(dmat->Info().num_row_, 8ul);
+  EXPECT_EQ(dmat->Info().num_col_, 5ul);
+  EXPECT_EQ(dmat->Info().num_nonzero_, kEntries);
   EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);

   delete dmat;
@@ -30,13 +33,13 @@ TEST(SparsePageDMatrix, RowAccess) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
   std::unique_ptr<xgboost::DMatrix> dmat =
-      xgboost::CreateSparsePageDMatrix(12, 64, filename);
+      xgboost::CreateSparsePageDMatrix(24, 4, filename);

   // Test the data read into the first row
   auto &batch = *dmat->GetBatches<SparsePage>().begin();
   auto first_row = batch[0];
-  ASSERT_EQ(first_row.size(), 3);
-  EXPECT_EQ(first_row[2].index, 2);
+  ASSERT_EQ(first_row.size(), 3ul);
+  EXPECT_EQ(first_row[2].index, 2u);
   EXPECT_EQ(first_row[2].fvalue, 20);
 }
@@ -77,43 +80,56 @@ TEST(SparsePageDMatrix, ColAccess) {
 TEST(SparsePageDMatrix, ExistingCacheFile) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
+  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   std::unique_ptr<DMatrix> dmat =
-      xgboost::CreateSparsePageDMatrix(12, 64, filename);
+      xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
   EXPECT_ANY_THROW({
     std::unique_ptr<DMatrix> dmat2 =
-        xgboost::CreateSparsePageDMatrix(12, 64, filename);
+        xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
   });
 }

-#if defined(_OPENMP)
 TEST(SparsePageDMatrix, ThreadSafetyException) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/test";
-  std::unique_ptr<DMatrix> dmat =
-      xgboost::CreateSparsePageDMatrix(12, 64, filename);
+  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;

-  bool exception = false;
+  std::unique_ptr<DMatrix> dmat =
+      xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
+
+  std::atomic<bool> exception {false};
   int threads = 1000;
-#pragma omp parallel for num_threads(threads)
-  for (auto i = 0; i < threads; i++) {
-    try {
-      auto iter = dmat->GetBatches<SparsePage>().begin();
-      ++iter;
-    } catch (...) {
-      exception = true;
-    }
+
+  std::vector<std::thread> waiting;
+
+  for (int32_t i = 0; i < threads; ++i) {
+    waiting.emplace_back([&]() {
+      try {
+        auto iter = dmat->GetBatches<SparsePage>().begin();
+        ++iter;
+      } catch (...) {
+        exception = true;
+      }
+    });
+  }
+
+  for (auto& t : waiting) {
+    t.join();
   }
   EXPECT_TRUE(exception);
 }
-#endif

 // Multi-batches access
 TEST(SparsePageDMatrix, ColAccessBatches) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
+  size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
   // Create multiple sparse pages
   std::unique_ptr<xgboost::DMatrix> dmat{
-      xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
+      xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename)};
   auto n_threads = omp_get_max_threads();
   omp_set_num_threads(16);
   for (auto const &page : dmat->GetBatches<CSCPage>()) {
@@ -286,3 +302,70 @@ TEST(SparsePageDMatrix, FromFile) {
     }
   }
 }
+
+TEST(SparsePageDMatrix, Large) {
+  std::string filename = "test.libsvm";
+  CreateBigTestData(filename, 1 << 16);
+  std::unique_ptr<dmlc::Parser<uint32_t>> parser(
+      dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
+  data::FileAdapter adapter(parser.get());
+  dmlc::TemporaryDirectory tempdir;
+  const std::string tmp_file = tempdir.path + "/simple.libsvm";
+
+  std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
+      &adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 16)};
+  std::unique_ptr<DMatrix> simple{DMatrix::Load(filename, true, true)};
+
+  std::vector<float> sparse_data;
+  std::vector<size_t> sparse_rptr;
+  std::vector<bst_feature_t> sparse_cids;
+  DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
+
+  std::vector<float> simple_data;
+  std::vector<size_t> simple_rptr;
+  std::vector<bst_feature_t> simple_cids;
+  DMatrixToCSR(simple.get(), &simple_data, &simple_rptr, &simple_cids);
+
+  ASSERT_EQ(sparse_rptr.size(), sparse->Info().num_row_ + 1);
+  ASSERT_EQ(sparse_rptr.size(), simple->Info().num_row_ + 1);
+
+  ASSERT_EQ(sparse_data.size(), simple_data.size());
+  ASSERT_EQ(sparse_data, simple_data);
+  ASSERT_EQ(sparse_rptr.size(), simple_rptr.size());
+  ASSERT_EQ(sparse_rptr, simple_rptr);
+  ASSERT_EQ(sparse_cids, simple_cids);
+}
+
+auto TestSparsePageDMatrixDeterminism(int32_t threads, std::string const& filename) {
+  omp_set_num_threads(threads);
+  std::vector<float> sparse_data;
+  std::vector<size_t> sparse_rptr;
+  std::vector<bst_feature_t> sparse_cids;
+
+  std::unique_ptr<dmlc::Parser<uint32_t>> parser(
+      dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
+  data::FileAdapter adapter(parser.get());
+  dmlc::TemporaryDirectory tempdir;
+  const std::string tmp_file = tempdir.path + "/simple.libsvm";
+  std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
+      &adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1 << 8)};
+
+  DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
+
+  std::string cache_name = tmp_file + ".row.page";
+  std::string cache = common::LoadSequentialFile(cache_name);
+  return cache;
+}
+
+TEST(SparsePageDMatrix, Determinism) {
+  std::string filename = "test.libsvm";
+  CreateBigTestData(filename, 1 << 16);
+  std::vector<std::string> caches;
+  for (size_t i = 1; i < 18; i += 2) {
+    caches.emplace_back(TestSparsePageDMatrixDeterminism(i, filename));
+  }
+
+  for (size_t i = 1; i < caches.size(); ++i) {
+    ASSERT_EQ(caches[i], caches.front());
+  }
+}
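In `ThreadSafetyException` above, the OpenMP loop is replaced with `std::thread`s and the shared flag becomes `std::atomic<bool>`: concurrent unsynchronized writes to a plain `bool` from the catch blocks are a data race, even when every writer stores the same value. A minimal sketch of the pattern outside any test framework:

```cpp
#include <atomic>
#include <stdexcept>
#include <thread>
#include <vector>

int main() {
  std::atomic<bool> failed{false};
  std::vector<std::thread> workers;
  for (int i = 0; i < 8; ++i) {
    workers.emplace_back([&] {
      try {
        throw std::runtime_error("not thread safe");
      } catch (...) {
        // Atomic store: safe with concurrent writers, unlike a plain bool.
        failed = true;
      }
    });
  }
  for (auto &t : workers) { t.join(); }
  return failed ? 0 : 1;
}
```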
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
index cf4059d40..29ac89040 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cu
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -28,7 +28,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
 TEST(SparsePageDMatrix, MultipleEllpackPages) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
+  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
+  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);

   // Loop over the batches and count the records
   int64_t batch_count = 0;
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 1b319a887..d0cca1c74 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -373,12 +373,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
     batch_count++;
     row_count += batch.Size();
   }
-#if defined(_OPENMP)
   EXPECT_GE(batch_count, 2);
   EXPECT_EQ(row_count, dmat->Info().num_row_);
-#else
-#warning "External memory doesn't work with Non-OpenMP build "
-#endif  // defined(_OPENMP)

   return dmat;
 }
@@ -495,6 +491,36 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
   return gbm;
 }

+void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
+                  std::vector<size_t> *p_row_ptr,
+                  std::vector<bst_feature_t> *p_cids) {
+  auto &data = *p_data;
+  auto &row_ptr = *p_row_ptr;
+  auto &cids = *p_cids;
+
+  data.resize(dmat->Info().num_nonzero_);
+  cids.resize(data.size());
+  row_ptr.resize(dmat->Info().num_row_ + 1);
+  SparsePage page;
+  for (const auto &batch : dmat->GetBatches<SparsePage>()) {
+    page.Push(batch);
+  }
+
+  auto const& in_offset = page.offset.HostVector();
+  auto const& in_data = page.data.HostVector();
+
+  CHECK_EQ(in_offset.size(), row_ptr.size());
+  std::copy(in_offset.cbegin(), in_offset.cend(), row_ptr.begin());
+  ASSERT_EQ(in_data.size(), data.size());
+  std::transform(in_data.cbegin(), in_data.cend(), data.begin(), [](Entry const& e) {
+    return e.fvalue;
+  });
+  ASSERT_EQ(in_data.size(), cids.size());
+  std::transform(in_data.cbegin(), in_data.cend(), cids.begin(), [](Entry const& e) {
+    return e.index;
+  });
+}
+
 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
 using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 02bde0598..df7a26851 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -365,6 +365,10 @@ class CudaArrayIterForTest {
   auto Proxy() -> decltype(proxy_) { return proxy_; }
 };

+void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
+                  std::vector<size_t> *p_row_ptr,
+                  std::vector<bst_feature_t> *p_cids);
+
 typedef void *DataIterHandle;  // NOLINT(*)

 inline void Reset(DataIterHandle self) {
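The new `DMatrixToCSR` helper flattens every `SparsePage` batch into one CSR triplet so two matrices can be compared element-wise with `ASSERT_EQ`. The core of it is an array-of-structs to struct-of-arrays conversion; a standalone sketch with plain types instead of XGBoost's (the `Entry` here is a simplified analogue):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// A CSR entry analogous to xgboost::Entry: column index plus feature value.
struct Entry {
  uint32_t index;
  float fvalue;
};

// Split an array-of-structs CSR payload into parallel value/index arrays,
// which makes whole-vector equality comparisons possible.
void EntriesToColumns(const std::vector<Entry> &in,
                      std::vector<float> *values,
                      std::vector<uint32_t> *indices) {
  values->resize(in.size());
  indices->resize(in.size());
  std::transform(in.cbegin(), in.cend(), values->begin(),
                 [](Entry const &e) { return e.fvalue; });
  std::transform(in.cbegin(), in.cend(), indices->begin(),
                 [](Entry const &e) { return e.index; });
}
```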
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index 4191682eb..6ec07b353 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -16,8 +16,8 @@ TEST(CpuPredictor, Basic) {
   std::unique_ptr<Predictor> cpu_predictor =
       std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));

-  int kRows = 5;
-  int kCols = 5;
+  size_t constexpr kRows = 5;
+  size_t constexpr kCols = 5;

   LearnerModelParam param;
   param.num_feature = kCols;
@@ -85,7 +85,11 @@ TEST(CpuPredictor, Basic) {
 TEST(CpuPredictor, ExternalMemory) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";
-  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
+
+  size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
+  size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
+
+  std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);

   auto lparam = CreateEmptyGenericParam(GPUIDX);
   std::unique_ptr<Predictor> cpu_predictor =
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index b48e49086..f7e52c644 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -105,9 +105,9 @@ TEST(GPUPredictor, ExternalMemoryTest) {
   std::string file0 = tmpdir.path + "/big_0.libsvm";
   std::string file1 = tmpdir.path + "/big_1.libsvm";
   std::string file2 = tmpdir.path + "/big_2.libsvm";
-  dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
-  dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
-  dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
+  dmats.push_back(CreateSparsePageDMatrix(400, 64UL, file0));
+  dmats.push_back(CreateSparsePageDMatrix(800, 128UL, file1));
+  dmats.push_back(CreateSparsePageDMatrix(8000, 1024UL, file2));

   for (const auto& dmat: dmats) {
     dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py
index c1a076144..ce00dbfaa 100644
--- a/tests/python-gpu/test_gpu_updaters.py
+++ b/tests/python-gpu/test_gpu_updaters.py
@@ -1,5 +1,6 @@
 import numpy as np
 import sys
+import gc
 import pytest
 import xgboost as xgb
 from hypothesis import given, strategies, assume, settings, note
@@ -118,7 +119,10 @@ class TestGPUUpdaters:
         assume(len(dataset.y) > 0)
         param['tree_method'] = 'gpu_hist'
         param = dataset.set_params(param)
-        external_result = train_result(param, dataset.get_external_dmat(), num_rounds)
+        m = dataset.get_external_dmat()
+        external_result = train_result(param, m, num_rounds)
+        del m
+        gc.collect()
         assert tm.non_increasing(external_result['train'][dataset.metric])

     def test_empty_dmatrix_prediction(self):
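The Python change (`del m; gc.collect()`) likely exists because the external-memory cache file is only removed when the DMatrix is destroyed; under hypothesis, the next generated case may reuse the same cache prefix before the garbage collector runs, tripping the `CheckCacheFileExists` error above. The C++ side gets the same cleanup from deterministic destruction; a sketch of the RAII idea with a hypothetical `CacheFile` type, not XGBoost's API:

```cpp
#include <cstdio>
#include <stdexcept>
#include <string>

// Hypothetical RAII wrapper: the cache file lives exactly as long as the
// object, so a new object with the same prefix can be created immediately
// after this one leaves scope -- behavior Python only gets after GC runs.
class CacheFile {
  std::string path_;

 public:
  explicit CacheFile(std::string path) : path_{std::move(path)} {
    if (FILE *f = std::fopen(path_.c_str(), "wx")) {  // "x": fail if it exists
      std::fclose(f);
    } else {
      throw std::runtime_error("cache file exists already: " + path_);
    }
  }
  ~CacheFile() { std::remove(path_.c_str()); }
};

int main() {
  { CacheFile first{"demo.cache"}; }  // removed at end of scope
  CacheFile second{"demo.cache"};     // safe: the previous cache is gone
}
```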