Deterministic data partitioning for external memory (#6317)

* Make external memory data partitioning deterministic.

* Change the meaning of `page_size` from bytes to number of rows.

* Design a data pool.

* Note for external memory.

* Enable unity build on Windows CI.

* Force garbage collect on test.
This commit is contained in:
Jiaming Yuan 2020-11-11 06:11:06 +08:00 committed by GitHub
parent 9564886d9f
commit 43efadea2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 334 additions and 88 deletions

View File

@ -85,7 +85,7 @@ def BuildWin64() {
bat """ bat """
mkdir build mkdir build
cd build cd build
cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag} cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag} -DCMAKE_UNITY_BUILD=ON
""" """
bat """ bat """
cd build cd build

View File

@ -44,6 +44,7 @@
#if DMLC_ENABLE_STD_THREAD #if DMLC_ENABLE_STD_THREAD
#include "../src/data/sparse_page_dmatrix.cc" #include "../src/data/sparse_page_dmatrix.cc"
#include "../src/data/sparse_page_source.cc"
#endif #endif
// trees // trees

View File

@ -549,8 +549,9 @@ class DMatrix {
int max_bin); int max_bin);
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0; virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief page size 32 MB */ /*! \brief Number of rows per page in external memory. Approximately 100MB per page for
static const size_t kPageSize = 32UL << 20UL; * dataset with 100 features. */
static const size_t kPageSize = 32UL << 12UL;
protected: protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0; virtual BatchSet<SparsePage> GetRowBatches() = 0;

View File

@ -830,9 +830,10 @@ void SparsePage::Push(const SparsePage &batch) {
const auto& batch_data_vec = batch.data.HostVector(); const auto& batch_data_vec = batch.data.HostVector();
size_t top = offset_vec.back(); size_t top = offset_vec.back();
data_vec.resize(top + batch.data.Size()); data_vec.resize(top + batch.data.Size());
std::memcpy(dmlc::BeginPtr(data_vec) + top, if (dmlc::BeginPtr(data_vec) && dmlc::BeginPtr(batch_data_vec)) {
dmlc::BeginPtr(batch_data_vec), std::memcpy(dmlc::BeginPtr(data_vec) + top, dmlc::BeginPtr(batch_data_vec),
sizeof(Entry) * batch.data.Size()); sizeof(Entry) * batch.data.Size());
}
size_t begin = offset.Size(); size_t begin = offset.Size();
offset_vec.resize(begin + batch.Size()); offset_vec.resize(begin + batch.Size());
for (size_t i = 0; i < batch.Size(); ++i) { for (size_t i = 0; i < batch.Size(); ++i) {

View File

@ -0,0 +1,77 @@
/*!
* Copyright (c) 2020 by XGBoost Contributors
*/
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
/*!
 * \brief Copy a contiguous range of rows from the pool into `out`.
 *
 * \param out          Destination page; its offset and data vectors are overwritten.
 * \param offset       Index (within the pool) of the first row to copy.
 * \param n_rows       Number of rows to copy.
 * \param entry_offset Number of pool entries that precede row `offset`.  It is
 *                     subtracted from the pool's row pointers so that the row
 *                     pointers of `out` start at 0.
 */
void DataPool::Slice(std::shared_ptr<SparsePage> out, size_t offset,
                     size_t n_rows, size_t entry_offset) const {
  auto const &in_offset = pool_.offset.HostVector();
  auto const &in_data = pool_.data.HostVector();
  CHECK_LE(offset + n_rows + 1, in_offset.size());

  // Validate the requested range *before* allocating anything.  With an
  // inconsistent `entry_offset` the subtraction below would wrap around and the
  // resize of the output buffer would try to allocate an enormous amount of
  // memory before any check had a chance to report the real problem.
  size_t const entry_begin = in_offset[offset];
  size_t const entry_end = in_offset[offset + n_rows];
  CHECK_LE(entry_offset, entry_end);
  size_t const n_entries = entry_end - entry_offset;
  CHECK_EQ(n_entries, entry_end - entry_begin);
  CHECK_LE(entry_end, in_data.size());

  // Row pointers, rebased so that the first one is 0.
  auto &h_offset = out->offset.HostVector();
  h_offset.resize(n_rows + 1, 0);
  std::transform(in_offset.cbegin() + offset,
                 in_offset.cbegin() + offset + n_rows + 1, h_offset.begin(),
                 [entry_offset](size_t ptr) { return ptr - entry_offset; });

  // Entries belonging to the selected rows.
  auto &h_data = out->data.HostVector();
  h_data.resize(n_entries);
  std::copy_n(in_data.cbegin() + entry_begin, n_entries, h_data.begin());
}
void DataPool::SplitWritePage() {
size_t total = pool_.Size();
size_t offset = 0;
size_t entry_offset = 0;
do {
size_t n_rows = std::min(page_size_, total - offset);
std::shared_ptr<SparsePage> out;
writer_->Alloc(&out);
out->Clear();
out->SetBaseRowId(inferred_num_rows_);
this->Slice(out, offset, n_rows, entry_offset);
inferred_num_rows_ += out->Size();
offset += n_rows;
entry_offset += out->data.Size();
CHECK_NE(out->Size(), 0);
writer_->PushWrite(std::move(out));
} while (total - offset >= page_size_);
if (total - offset != 0) {
auto out = std::make_shared<SparsePage>();
this->Slice(out, offset, total - offset, entry_offset);
CHECK_NE(out->Size(), 0);
pool_.Clear();
pool_.Push(*out);
} else {
pool_.Clear();
}
}
/*!
 * \brief Flush whatever is left in the pool to the writer.
 *
 * If no row has been seen at all, a single empty page is still written.
 *
 * \return Total number of rows that went through the pool.
 */
size_t DataPool::Finalize() {
  if (pool_.Size() != 0) {
    std::shared_ptr<SparsePage> page;
    this->writer_->Alloc(&page);
    page->Clear();
    // Same convention as SplitWritePage: the base row id of a page is the number
    // of rows written before it.  It has to be taken before the pooled rows are
    // added to the running total.
    page->SetBaseRowId(inferred_num_rows_);
    page->Push(pool_);
    this->writer_->PushWrite(std::move(page));
  }
  inferred_num_rows_ += pool_.Size();

  if (inferred_num_rows_ == 0) {
    // Empty dataset: still emit one (empty) page.
    std::shared_ptr<SparsePage> page;
    this->writer_->Alloc(&page);
    page->Clear();
    this->writer_->PushWrite(std::move(page));
  }
  return inferred_num_rows_;
}
} // namespace data
} // namespace xgboost

View File

@ -3,6 +3,37 @@
* \file page_csr_source.h * \file page_csr_source.h
* External memory data source, saved with sparse_batch_page binary format. * External memory data source, saved with sparse_batch_page binary format.
* \author Tianqi Chen * \author Tianqi Chen
*
* -------------------------------------------------
* Random notes on implementation of external memory
* -------------------------------------------------
*
* As of XGBoost 1.3, the general pipeline is:
*
* dmlc text file parser --> file adapter --> sparse page source -> data pool -->
* write to binary cache --> load it back ~~> [ other pages (csc, ellpack, sorted csc) -->
* write to binary cache ] --> use it in various algorithms.
*
* ~~> means optional
*
* The dmlc text file parser returns a number of blocks that depends on the available
* threads, which can make the data partitioning non-deterministic, so here we set up an
* extra data pool to stage parsed data. As a result, the number of blocks returned by the
* text parser does not equal the number of blocks in the binary cache.
*
* Binary cache loading is performed asynchronously by the dmlc threaded iterator, which
* helps performance. However, this iterator itself is not thread safe, so calling
* `dmatrix->GetBatches<SparsePage>` is also not thread safe. Please note that the
* threaded iterator is also used inside the dmlc text file parser.
*
* Memory consumption is difficult to control for various reasons. Firstly, the text
* parsing doesn't have a batch size; only a hard-coded buffer size is available.
* Secondly, everything is loaded/written with async queues, and with multiple queues
* running, the memory consumption is difficult to measure.
*
* The threaded iterator relies heavily on the C++ memory model and threading primitives.
* The concurrent writer for the binary cache is an old copy of moody queue. We should try
* to replace them with something more robust.
*/ */
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_ #define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@ -19,6 +50,7 @@
#include <vector> #include <vector>
#include <fstream> #include <fstream>
#include "rabit/rabit.h"
#include "xgboost/base.h" #include "xgboost/base.h"
#include "xgboost/data.h" #include "xgboost/data.h"
@ -121,9 +153,12 @@ inline void TryDeleteCacheFile(const std::string& file) {
inline void CheckCacheFileExists(const std::string& file) { inline void CheckCacheFileExists(const std::string& file) {
std::ifstream f(file.c_str()); std::ifstream f(file.c_str());
if (f.good()) { if (f.good()) {
LOG(FATAL) << "Cache file " << file LOG(FATAL)
<< " exists already; Is there another DMatrix with the same " << "Cache file " << file << " exists already; "
"cache prefix? Otherwise please remove it manually."; << "Is there another DMatrix with the same "
"cache prefix? It can be caused by previously used DMatrix that "
"hasn't been collected by language environment garbage collector. "
"Otherwise please remove it manually.";
} }
} }
@ -231,6 +266,38 @@ class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_; std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
}; };
// Staging area between the parsed input pages and the binary cache writer.  Incoming
// pages are accumulated here and cut into pages of `page_size` rows, which keeps the
// written pages balanced and the data partitioning deterministic.
class DataPool {
  size_t inferred_num_rows_;              // rows handed to the writer so far
  MetaInfo* info_;                        // `num_nonzero_` is accumulated here
  SparsePage pool_;                       // rows waiting to be written
  size_t page_size_;                      // number of rows per written page
  SparsePageWriter<SparsePage> *writer_;  // destination of the split pages

  // Copy `n_rows` rows starting at row `offset` of the pool into `out`.
  void Slice(std::shared_ptr<SparsePage> out, size_t offset, size_t n_rows,
             size_t entry_offset) const;
  // Write out full pages and keep the remaining rows in the pool.
  void SplitWritePage();

 public:
  DataPool(MetaInfo *info, size_t page_size,
           SparsePageWriter<SparsePage> *writer)
      : inferred_num_rows_{0},
        info_{info},
        page_size_{page_size},
        writer_{writer} {}

  // Append `page` to the pool, flushing full pages once the pool holds more than
  // `page_size` rows.  `page` is cleared before returning.
  void Push(std::shared_ptr<SparsePage> page) {
    auto const n_new_entries = page->data.Size();
    info_->num_nonzero_ += n_new_entries;
    pool_.Push(*page);
    bool const over_capacity = pool_.Size() > page_size_;
    if (over_capacity) {
      this->SplitWritePage();
    }
    page->Clear();
  }

  // Write the remaining rows and return the total number of rows seen.
  size_t Finalize();
};
class SparsePageSource { class SparsePageSource {
public: public:
template <typename AdapterT> template <typename AdapterT>
@ -249,17 +316,12 @@ class SparsePageSource {
{ {
SparsePageWriter<SparsePage> writer(cache_info_.name_shards, SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
cache_info_.format_shards, 6); cache_info_.format_shards, 6);
std::shared_ptr<SparsePage> page; DataPool pool(&info, page_size, &writer);
writer.Alloc(&page);
page->Clear(); std::shared_ptr<SparsePage> page { new SparsePage };
uint64_t inferred_num_columns = 0; uint64_t inferred_num_columns = 0;
uint64_t inferred_num_rows = 0; uint64_t inferred_num_rows = 0;
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
// print every 4 sec.
constexpr double kStep = 4.0;
size_t tick_expected = static_cast<double>(kStep);
const uint64_t default_max = std::numeric_limits<uint64_t>::max(); const uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max; uint64_t last_group_id = default_max;
@ -296,26 +358,13 @@ class SparsePageSource {
++group_size; ++group_size;
} }
} }
CHECK_EQ(page->Size(), 0);
auto batch_max_columns = page->Push(batch, missing, nthread); auto batch_max_columns = page->Push(batch, missing, nthread);
inferred_num_columns = inferred_num_columns =
std::max(batch_max_columns, inferred_num_columns); std::max(batch_max_columns, inferred_num_columns);
if (page->MemCostBytes() >= page_size) {
inferred_num_rows += page->Size(); inferred_num_rows += page->Size();
info.num_nonzero_ += page->offset.HostVector().back(); pool.Push(page);
bytes_write += page->MemCostBytes();
writer.PushWrite(std::move(page));
writer.Alloc(&page);
page->Clear();
page->SetBaseRowId(inferred_num_rows); page->SetBaseRowId(inferred_num_rows);
double tdiff = dmlc::GetTime() - tstart;
if (tdiff >= tick_expected) {
LOG(CONSOLE) << "Writing " << page_type << " to " << cache_info
<< " in " << ((bytes_write >> 20UL) / tdiff)
<< " MB/s, " << (bytes_write >> 20UL) << " written";
tick_expected += static_cast<size_t>(kStep);
}
}
} }
if (last_group_id != default_max) { if (last_group_id != default_max) {
@ -323,10 +372,6 @@ class SparsePageSource {
info.group_ptr_.push_back(group_size); info.group_ptr_.push_back(group_size);
} }
} }
inferred_num_rows += page->Size();
if (!page->offset.HostVector().empty()) {
info.num_nonzero_ += page->offset.HostVector().back();
}
// Deal with empty rows/columns if necessary // Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) { if (adapter->NumColumns() == kAdapterUnknownSize) {
@ -352,10 +397,9 @@ class SparsePageSource {
info.num_row_ = adapter->NumRows(); info.num_row_ = adapter->NumRows();
} }
// Make sure we have at least one page if the dataset is empty pool.Push(page);
if (page->data.Size() > 0 || info.num_row_ == 0) { pool.Finalize();
writer.PushWrite(std::move(page));
}
std::unique_ptr<dmlc::Stream> fo( std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w")); dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
int tmagic = kMagic; int tmagic = kMagic;

View File

@ -130,8 +130,10 @@ TEST(DenseColumnWithMissing, Test) {
void TestGHistIndexMatrixCreation(size_t nthreads) { void TestGHistIndexMatrixCreation(size_t nthreads) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
/* This should create multiple sparse pages */ /* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(1024, 1024, filename) }; std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
omp_set_num_threads(nthreads); omp_set_num_threads(nthreads);
GHistIndexMatrix gmat; GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256); gmat.Init(dmat.get(), 256);

View File

@ -44,13 +44,13 @@ TEST(SparsePage, PushCSC) {
} }
auto inst = page[0]; auto inst = page[0];
ASSERT_EQ(inst.size(), 2); ASSERT_EQ(inst.size(), 2ul);
for (auto entry : inst) { for (auto entry : inst) {
ASSERT_EQ(entry.index, 0); ASSERT_EQ(entry.index, 0u);
} }
inst = page[1]; inst = page[1];
ASSERT_EQ(inst.size(), 6); ASSERT_EQ(inst.size(), 6ul);
std::vector<size_t> indices_sol {1, 2, 3}; std::vector<size_t> indices_sol {1, 2, 3};
for (size_t i = 0; i < inst.size(); ++i) { for (size_t i = 0; i < inst.size(); ++i) {
ASSERT_EQ(inst[i].index, indices_sol[i % 3]); ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
@ -58,15 +58,12 @@ TEST(SparsePage, PushCSC) {
} }
TEST(SparsePage, PushCSCAfterTranspose) { TEST(SparsePage, PushCSCAfterTranspose) {
#if defined(__APPLE__)
LOG(WARNING) << "FIXME(trivialfis): Skipping `PushCSCAfterTranspose' for APPLE.";
return;
#endif
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
const int n_entries = 9; size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = std::unique_ptr<DMatrix> dmat =
CreateSparsePageDMatrix(n_entries, 64UL, filename); CreateSparsePageDMatrix(kEntries, 64UL, filename);
const int ncols = dmat->Info().num_col_; const int ncols = dmat->Info().num_col_;
SparsePage page; // Consolidated sparse page SparsePage page; // Consolidated sparse page
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) { for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
@ -76,7 +73,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
} }
// Make sure that the final sparse page has the right number of entries // Make sure that the final sparse page has the right number of entries
ASSERT_EQ(n_entries, page.data.Size()); ASSERT_EQ(kEntries, page.data.Size());
// The feature value for a feature in each row should be identical, as that is // The feature value for a feature in each row should be identical, as that is
// how the dmatrix has been created // how the dmatrix has been created

View File

@ -2,6 +2,7 @@
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include "../../../src/common/io.h"
#include "../../../src/data/adapter.h" #include "../../../src/data/adapter.h"
#include "../../../src/data/sparse_page_dmatrix.h" #include "../../../src/data/sparse_page_dmatrix.h"
#include "../helpers.h" #include "../helpers.h"
@ -11,16 +12,18 @@ using namespace xgboost; // NOLINT
TEST(SparsePageDMatrix, MetaInfo) { TEST(SparsePageDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file); size_t constexpr kEntries = 24;
CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load( xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", false, false); tmp_file + "#" + tmp_file + ".cache", false, false);
std::cout << tmp_file << std::endl; std::cout << tmp_file << std::endl;
EXPECT_TRUE(FileExists(tmp_file + ".cache")); EXPECT_TRUE(FileExists(tmp_file + ".cache"));
// Test the metadata that was parsed // Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2); EXPECT_EQ(dmat->Info().num_row_, 8ul);
EXPECT_EQ(dmat->Info().num_col_, 5); EXPECT_EQ(dmat->Info().num_col_, 5ul);
EXPECT_EQ(dmat->Info().num_nonzero_, 6); EXPECT_EQ(dmat->Info().num_nonzero_, kEntries);
EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_); EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);
delete dmat; delete dmat;
@ -30,13 +33,13 @@ TEST(SparsePageDMatrix, RowAccess) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<xgboost::DMatrix> dmat = std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(12, 64, filename); xgboost::CreateSparsePageDMatrix(24, 4, filename);
// Test the data read into the first row // Test the data read into the first row
auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin(); auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
auto first_row = batch[0]; auto first_row = batch[0];
ASSERT_EQ(first_row.size(), 3); ASSERT_EQ(first_row.size(), 3ul);
EXPECT_EQ(first_row[2].index, 2); EXPECT_EQ(first_row[2].index, 2u);
EXPECT_EQ(first_row[2].fvalue, 20); EXPECT_EQ(first_row[2].fvalue, 20);
} }
@ -77,43 +80,56 @@ TEST(SparsePageDMatrix, ColAccess) {
TEST(SparsePageDMatrix, ExistingCacheFile) { TEST(SparsePageDMatrix, ExistingCacheFile) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<xgboost::DMatrix> dmat = std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(12, 64, filename); xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
EXPECT_ANY_THROW({ EXPECT_ANY_THROW({
std::unique_ptr<xgboost::DMatrix> dmat2 = std::unique_ptr<xgboost::DMatrix> dmat2 =
xgboost::CreateSparsePageDMatrix(12, 64, filename); xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
}); });
} }
#if defined(_OPENMP)
TEST(SparsePageDMatrix, ThreadSafetyException) { TEST(SparsePageDMatrix, ThreadSafetyException) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/test"; std::string filename = tmpdir.path + "/test";
std::unique_ptr<xgboost::DMatrix> dmat = size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
xgboost::CreateSparsePageDMatrix(12, 64, filename); size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
bool exception = false; std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
std::atomic<bool> exception {false};
int threads = 1000; int threads = 1000;
#pragma omp parallel for num_threads(threads)
for (auto i = 0; i < threads; i++) { std::vector<std::thread> waiting;
for (int32_t i = 0; i < threads; ++i) {
waiting.emplace_back([&]() {
try { try {
auto iter = dmat->GetBatches<SparsePage>().begin(); auto iter = dmat->GetBatches<SparsePage>().begin();
++iter; ++iter;
} catch (...) { } catch (...) {
exception = true; exception = true;
} }
});
}
for (auto& t : waiting) {
t.join();
} }
EXPECT_TRUE(exception); EXPECT_TRUE(exception);
} }
#endif
// Multi-batches access // Multi-batches access
TEST(SparsePageDMatrix, ColAccessBatches) { TEST(SparsePageDMatrix, ColAccessBatches) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
// Create multiple sparse pages // Create multiple sparse pages
std::unique_ptr<xgboost::DMatrix> dmat{ std::unique_ptr<xgboost::DMatrix> dmat{
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)}; xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename)};
auto n_threads = omp_get_max_threads(); auto n_threads = omp_get_max_threads();
omp_set_num_threads(16); omp_set_num_threads(16);
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) { for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
@ -286,3 +302,70 @@ TEST(SparsePageDMatrix, FromFile) {
} }
} }
} }
// An external memory DMatrix built with a tiny page size must contain exactly the
// same CSR data as the in-memory DMatrix loaded from the same file.
TEST(SparsePageDMatrix, Large) {
  // Declared first so that it is destroyed last, after the parser and both matrices
  // have released their files.  The input file lives here as well, instead of being
  // left behind in the current working directory.
  dmlc::TemporaryDirectory tempdir;
  std::string filename = tempdir.path + "/test.libsvm";
  CreateBigTestData(filename, 1 << 16);

  std::unique_ptr<dmlc::Parser<uint32_t>> parser(
      dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
  data::FileAdapter adapter(parser.get());
  const std::string tmp_file = tempdir.path + "/simple.libsvm";

  // 16 rows per page.
  std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
      &adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 16)};
  std::unique_ptr<DMatrix> simple{DMatrix::Load(filename, true, true)};

  std::vector<float> sparse_data;
  std::vector<size_t> sparse_rptr;
  std::vector<bst_feature_t> sparse_cids;
  DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);

  std::vector<float> simple_data;
  std::vector<size_t> simple_rptr;
  std::vector<bst_feature_t> simple_cids;
  DMatrixToCSR(simple.get(), &simple_data, &simple_rptr, &simple_cids);

  ASSERT_EQ(sparse_rptr.size(), sparse->Info().num_row_ + 1);
  ASSERT_EQ(sparse_rptr.size(), simple->Info().num_row_ + 1);
  ASSERT_EQ(sparse_data.size(), simple_data.size());
  ASSERT_EQ(sparse_data, simple_data);
  ASSERT_EQ(sparse_rptr.size(), simple_rptr.size());
  ASSERT_EQ(sparse_rptr, simple_rptr);
  ASSERT_EQ(sparse_cids, simple_cids);
}
auto TestSparsePageDMatrixDeterminism(int32_t threads, std::string const& filename) {
omp_set_num_threads(threads);
std::vector<float> sparse_data;
std::vector<size_t> sparse_rptr;
std::vector<bst_feature_t> sparse_cids;
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
data::FileAdapter adapter(parser.get());
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1 << 8)};
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
std::string cache_name = tmp_file + ".row.page";
std::string cache = common::LoadSequentialFile(cache_name);
return cache;
}
// The binary row page cache must be byte-identical regardless of the number of
// threads used while building it.
TEST(SparsePageDMatrix, Determinism) {
  // Keep the generated input inside a temporary directory so that nothing is left
  // behind in the current working directory.
  dmlc::TemporaryDirectory tempdir;
  std::string filename = tempdir.path + "/test.libsvm";
  CreateBigTestData(filename, 1 << 16);

  std::vector<std::string> caches;
  // 1, 3, 5, ..., 17 threads.
  for (size_t i = 1; i < 18; i += 2) {
    caches.emplace_back(TestSparsePageDMatrixDeterminism(i, filename));
  }

  for (size_t i = 1; i < caches.size(); ++i) {
    ASSERT_EQ(caches[i], caches.front());
  }
}

View File

@ -28,7 +28,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
TEST(SparsePageDMatrix, MultipleEllpackPages) { TEST(SparsePageDMatrix, MultipleEllpackPages) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename); size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
// Loop over the batches and count the records // Loop over the batches and count the records
int64_t batch_count = 0; int64_t batch_count = 0;

View File

@ -373,12 +373,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
batch_count++; batch_count++;
row_count += batch.Size(); row_count += batch.Size();
} }
#if defined(_OPENMP)
EXPECT_GE(batch_count, 2); EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_); EXPECT_EQ(row_count, dmat->Info().num_row_);
#else
#warning "External memory doesn't work with Non-OpenMP build "
#endif // defined(_OPENMP)
return dmat; return dmat;
} }
@ -495,6 +491,36 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
return gbm; return gbm;
} }
// Concatenate every SparsePage batch of `dmat` and export the result as three CSR
// arrays: entry values, row pointers and column (feature) indices.  The outputs are
// resized according to the `num_nonzero_` / `num_row_` fields of the matrix info.
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
                  std::vector<size_t> *p_row_ptr,
                  std::vector<bst_feature_t> *p_cids) {
  p_data->resize(dmat->Info().num_nonzero_);
  p_cids->resize(p_data->size());
  p_row_ptr->resize(dmat->Info().num_row_ + 1);

  // Merge all batches into one page.
  SparsePage merged;
  for (auto const &batch : dmat->GetBatches<SparsePage>()) {
    merged.Push(batch);
  }
  auto const &h_offset = merged.offset.HostVector();
  auto const &h_entries = merged.data.HostVector();

  // Row pointers.
  CHECK_EQ(h_offset.size(), p_row_ptr->size());
  std::copy(h_offset.cbegin(), h_offset.cend(), p_row_ptr->begin());

  // Entry values.
  ASSERT_EQ(h_entries.size(), p_data->size());
  for (size_t i = 0; i < h_entries.size(); ++i) {
    (*p_data)[i] = h_entries[i].fvalue;
  }

  // Column indices.
  ASSERT_EQ(h_entries.size(), p_cids->size());
  for (size_t i = 0; i < h_entries.size(); ++i) {
    (*p_cids)[i] = h_entries[i].index;
  }
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
using CUDAMemoryResource = rmm::mr::cuda_memory_resource; using CUDAMemoryResource = rmm::mr::cuda_memory_resource;

View File

@ -365,6 +365,10 @@ class CudaArrayIterForTest {
auto Proxy() -> decltype(proxy_) { return proxy_; } auto Proxy() -> decltype(proxy_) { return proxy_; }
}; };
/*!
 * \brief Export the content of a DMatrix as CSR arrays.
 *
 * \param dmat      Matrix to read; all of its SparsePage batches are visited.
 * \param p_data    Output: value of every entry.
 * \param p_row_ptr Output: row pointers (number of rows + 1 elements).
 * \param p_cids    Output: feature index of every entry.
 */
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
std::vector<size_t> *p_row_ptr,
std::vector<bst_feature_t> *p_cids);
typedef void *DataIterHandle; // NOLINT(*) typedef void *DataIterHandle; // NOLINT(*)
inline void Reset(DataIterHandle self) { inline void Reset(DataIterHandle self) {

View File

@ -16,8 +16,8 @@ TEST(CpuPredictor, Basic) {
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam)); std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
int kRows = 5; size_t constexpr kRows = 5;
int kCols = 5; size_t constexpr kCols = 5;
LearnerModelParam param; LearnerModelParam param;
param.num_feature = kCols; param.num_feature = kCols;
@ -85,7 +85,11 @@ TEST(CpuPredictor, Basic) {
TEST(CpuPredictor, ExternalMemory) { TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir; dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm"; std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
auto lparam = CreateEmptyGenericParam(GPUIDX); auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor = std::unique_ptr<Predictor> cpu_predictor =

View File

@ -105,9 +105,9 @@ TEST(GPUPredictor, ExternalMemoryTest) {
std::string file0 = tmpdir.path + "/big_0.libsvm"; std::string file0 = tmpdir.path + "/big_0.libsvm";
std::string file1 = tmpdir.path + "/big_1.libsvm"; std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm"; std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0)); dmats.push_back(CreateSparsePageDMatrix(400, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1)); dmats.push_back(CreateSparsePageDMatrix(800, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2)); dmats.push_back(CreateSparsePageDMatrix(8000, 1024UL, file2));
for (const auto& dmat: dmats) { for (const auto& dmat: dmats) {
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5); dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);

View File

@ -1,5 +1,6 @@
import numpy as np import numpy as np
import sys import sys
import gc
import pytest import pytest
import xgboost as xgb import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note from hypothesis import given, strategies, assume, settings, note
@ -118,7 +119,10 @@ class TestGPUUpdaters:
assume(len(dataset.y) > 0) assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist' param['tree_method'] = 'gpu_hist'
param = dataset.set_params(param) param = dataset.set_params(param)
external_result = train_result(param, dataset.get_external_dmat(), num_rounds) m = dataset.get_external_dmat()
external_result = train_result(param, m, num_rounds)
del m
gc.collect()
assert tm.non_increasing(external_result['train'][dataset.metric]) assert tm.non_increasing(external_result['train'][dataset.metric])
def test_empty_dmatrix_prediction(self): def test_empty_dmatrix_prediction(self):