Deterministic data partitioning for external memory (#6317)
* Make external memory data partitioning deterministic. * Change the meaning of `page_size` from bytes to number of rows. * Design a data pool. * Note for external memory. * Enable unity build on Windows CI. * Force garbage collect on test.
This commit is contained in:
parent
9564886d9f
commit
43efadea2e
@ -85,7 +85,7 @@ def BuildWin64() {
|
||||
bat """
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag}
|
||||
cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag} -DCMAKE_UNITY_BUILD=ON
|
||||
"""
|
||||
bat """
|
||||
cd build
|
||||
|
||||
@ -44,6 +44,7 @@
|
||||
|
||||
#if DMLC_ENABLE_STD_THREAD
|
||||
#include "../src/data/sparse_page_dmatrix.cc"
|
||||
#include "../src/data/sparse_page_source.cc"
|
||||
#endif
|
||||
|
||||
// trees
|
||||
|
||||
@ -549,8 +549,9 @@ class DMatrix {
|
||||
int max_bin);
|
||||
|
||||
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
|
||||
/*! \brief page size 32 MB */
|
||||
static const size_t kPageSize = 32UL << 20UL;
|
||||
/*! \brief Number of rows per page in external memory. Approximately 100MB per page for
|
||||
* dataset with 100 features. */
|
||||
static const size_t kPageSize = 32UL << 12UL;
|
||||
|
||||
protected:
|
||||
virtual BatchSet<SparsePage> GetRowBatches() = 0;
|
||||
|
||||
@ -830,9 +830,10 @@ void SparsePage::Push(const SparsePage &batch) {
|
||||
const auto& batch_data_vec = batch.data.HostVector();
|
||||
size_t top = offset_vec.back();
|
||||
data_vec.resize(top + batch.data.Size());
|
||||
std::memcpy(dmlc::BeginPtr(data_vec) + top,
|
||||
dmlc::BeginPtr(batch_data_vec),
|
||||
if (dmlc::BeginPtr(data_vec) && dmlc::BeginPtr(batch_data_vec)) {
|
||||
std::memcpy(dmlc::BeginPtr(data_vec) + top, dmlc::BeginPtr(batch_data_vec),
|
||||
sizeof(Entry) * batch.data.Size());
|
||||
}
|
||||
size_t begin = offset.Size();
|
||||
offset_vec.resize(begin + batch.Size());
|
||||
for (size_t i = 0; i < batch.Size(); ++i) {
|
||||
|
||||
77
src/data/sparse_page_source.cc
Normal file
77
src/data/sparse_page_source.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/*!
|
||||
* Copyright (c) 2020 by XGBoost Contributors
|
||||
*/
|
||||
#include "sparse_page_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
void DataPool::Slice(std::shared_ptr<SparsePage> out, size_t offset,
|
||||
size_t n_rows, size_t entry_offset) const {
|
||||
auto const &in_offset = pool_.offset.HostVector();
|
||||
auto const &in_data = pool_.data.HostVector();
|
||||
auto &h_offset = out->offset.HostVector();
|
||||
CHECK_LE(offset + n_rows + 1, in_offset.size());
|
||||
h_offset.resize(n_rows + 1, 0);
|
||||
std::transform(in_offset.cbegin() + offset,
|
||||
in_offset.cbegin() + offset + n_rows + 1, h_offset.begin(),
|
||||
[=](size_t ptr) { return ptr - entry_offset; });
|
||||
|
||||
auto &h_data = out->data.HostVector();
|
||||
CHECK_GT(h_offset.size(), 0);
|
||||
size_t n_entries = h_offset.back();
|
||||
h_data.resize(n_entries);
|
||||
|
||||
CHECK_EQ(n_entries, in_offset.at(offset + n_rows) - in_offset.at(offset));
|
||||
std::copy_n(in_data.cbegin() + in_offset.at(offset), n_entries,
|
||||
h_data.begin());
|
||||
}
|
||||
|
||||
void DataPool::SplitWritePage() {
|
||||
size_t total = pool_.Size();
|
||||
size_t offset = 0;
|
||||
size_t entry_offset = 0;
|
||||
do {
|
||||
size_t n_rows = std::min(page_size_, total - offset);
|
||||
std::shared_ptr<SparsePage> out;
|
||||
writer_->Alloc(&out);
|
||||
out->Clear();
|
||||
out->SetBaseRowId(inferred_num_rows_);
|
||||
this->Slice(out, offset, n_rows, entry_offset);
|
||||
inferred_num_rows_ += out->Size();
|
||||
offset += n_rows;
|
||||
entry_offset += out->data.Size();
|
||||
CHECK_NE(out->Size(), 0);
|
||||
writer_->PushWrite(std::move(out));
|
||||
} while (total - offset >= page_size_);
|
||||
|
||||
if (total - offset != 0) {
|
||||
auto out = std::make_shared<SparsePage>();
|
||||
this->Slice(out, offset, total - offset, entry_offset);
|
||||
CHECK_NE(out->Size(), 0);
|
||||
pool_.Clear();
|
||||
pool_.Push(*out);
|
||||
} else {
|
||||
pool_.Clear();
|
||||
}
|
||||
}
|
||||
size_t DataPool::Finalize() {
|
||||
inferred_num_rows_ += pool_.Size();
|
||||
if (pool_.Size() != 0) {
|
||||
std::shared_ptr<SparsePage> page;
|
||||
this->writer_->Alloc(&page);
|
||||
page->Clear();
|
||||
page->Push(pool_);
|
||||
this->writer_->PushWrite(std::move(page));
|
||||
}
|
||||
|
||||
if (inferred_num_rows_ == 0) {
|
||||
std::shared_ptr<SparsePage> page;
|
||||
this->writer_->Alloc(&page);
|
||||
page->Clear();
|
||||
this->writer_->PushWrite(std::move(page));
|
||||
}
|
||||
|
||||
return inferred_num_rows_;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
@ -3,6 +3,37 @@
|
||||
* \file page_csr_source.h
|
||||
* External memory data source, saved with sparse_batch_page binary format.
|
||||
* \author Tianqi Chen
|
||||
*
|
||||
* -------------------------------------------------
|
||||
* Random notes on implementation of external memory
|
||||
* -------------------------------------------------
|
||||
*
|
||||
* As of XGBoost 1.3, the general pipeline is:
|
||||
*
|
||||
* dmlc text file parser --> file adapter --> sparse page source -> data pool -->
|
||||
* write to binary cache --> load it back ~~> [ other pages (csc, ellpack, sorted csc) -->
|
||||
* write to binary cache ] --> use it in various algorithms.
|
||||
*
|
||||
* ~~> means optional
|
||||
*
|
||||
* The dmlc text file parser returns number of blocks based on available threads, which
|
||||
* can make the data partitioning non-deterministic, so here we set up an extra data pool
|
||||
* to stage parsed data. As a result, the number of blocks returned by text parser does
|
||||
* not equal to number of blocks in binary cache.
|
||||
*
|
||||
* Binary cache loading is async by the dmlc threaded iterator, which helps performance,
|
||||
* but as this iterator itself is not thread safe, so calling
|
||||
* `dmatrix->GetBatches<SparsePage>` is also not thread safe. Please note that, the
|
||||
* threaded iterator is also used inside dmlc text file parser.
|
||||
*
|
||||
* Memory consumption is difficult to control due to various reasons. Firstly the text
|
||||
* parsing doesn't have a batch size, only a hard coded buffer size is available.
|
||||
* Secondly, everything is loaded/written with async queue, with multiple queues running
|
||||
* the memory consumption is difficult to measure.
|
||||
*
|
||||
* The threaded iterator relies heavily on C++ memory model and threading primitive. The
|
||||
* concurrent writer for binary cache is an old copy of moody queue. We should try to
|
||||
* replace them with something more robust.
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
|
||||
@ -19,6 +50,7 @@
|
||||
#include <vector>
|
||||
#include <fstream>
|
||||
|
||||
#include "rabit/rabit.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/data.h"
|
||||
|
||||
@ -121,9 +153,12 @@ inline void TryDeleteCacheFile(const std::string& file) {
|
||||
inline void CheckCacheFileExists(const std::string& file) {
|
||||
std::ifstream f(file.c_str());
|
||||
if (f.good()) {
|
||||
LOG(FATAL) << "Cache file " << file
|
||||
<< " exists already; Is there another DMatrix with the same "
|
||||
"cache prefix? Otherwise please remove it manually.";
|
||||
LOG(FATAL)
|
||||
<< "Cache file " << file << " exists already; "
|
||||
<< "Is there another DMatrix with the same "
|
||||
"cache prefix? It can be caused by previously used DMatrix that "
|
||||
"hasn't been collected by language environment garbage collector. "
|
||||
"Otherwise please remove it manually.";
|
||||
}
|
||||
}
|
||||
|
||||
@ -231,6 +266,38 @@ class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
|
||||
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
|
||||
};
|
||||
|
||||
|
||||
// A data pool to keep the size of each page balanced and data partitioning to be
|
||||
// deterministic.
|
||||
class DataPool {
|
||||
size_t inferred_num_rows_;
|
||||
MetaInfo* info_;
|
||||
SparsePage pool_;
|
||||
size_t page_size_;
|
||||
SparsePageWriter<SparsePage> *writer_;
|
||||
|
||||
void Slice(std::shared_ptr<SparsePage> out, size_t offset, size_t n_rows,
|
||||
size_t entry_offset) const;
|
||||
void SplitWritePage();
|
||||
|
||||
public:
|
||||
DataPool(MetaInfo *info, size_t page_size,
|
||||
SparsePageWriter<SparsePage> *writer)
|
||||
: inferred_num_rows_{0}, info_{info},
|
||||
page_size_{page_size}, writer_{writer} {}
|
||||
|
||||
void Push(std::shared_ptr<SparsePage> page) {
|
||||
info_->num_nonzero_ += page->data.Size();
|
||||
pool_.Push(*page);
|
||||
if (pool_.Size() > page_size_) {
|
||||
this->SplitWritePage();
|
||||
}
|
||||
page->Clear();
|
||||
}
|
||||
|
||||
size_t Finalize();
|
||||
};
|
||||
|
||||
class SparsePageSource {
|
||||
public:
|
||||
template <typename AdapterT>
|
||||
@ -249,17 +316,12 @@ class SparsePageSource {
|
||||
{
|
||||
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
|
||||
cache_info_.format_shards, 6);
|
||||
std::shared_ptr<SparsePage> page;
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
DataPool pool(&info, page_size, &writer);
|
||||
|
||||
std::shared_ptr<SparsePage> page { new SparsePage };
|
||||
|
||||
uint64_t inferred_num_columns = 0;
|
||||
uint64_t inferred_num_rows = 0;
|
||||
size_t bytes_write = 0;
|
||||
double tstart = dmlc::GetTime();
|
||||
// print every 4 sec.
|
||||
constexpr double kStep = 4.0;
|
||||
size_t tick_expected = static_cast<double>(kStep);
|
||||
|
||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
@ -296,26 +358,13 @@ class SparsePageSource {
|
||||
++group_size;
|
||||
}
|
||||
}
|
||||
CHECK_EQ(page->Size(), 0);
|
||||
auto batch_max_columns = page->Push(batch, missing, nthread);
|
||||
inferred_num_columns =
|
||||
std::max(batch_max_columns, inferred_num_columns);
|
||||
if (page->MemCostBytes() >= page_size) {
|
||||
inferred_num_rows += page->Size();
|
||||
info.num_nonzero_ += page->offset.HostVector().back();
|
||||
bytes_write += page->MemCostBytes();
|
||||
writer.PushWrite(std::move(page));
|
||||
writer.Alloc(&page);
|
||||
page->Clear();
|
||||
pool.Push(page);
|
||||
page->SetBaseRowId(inferred_num_rows);
|
||||
|
||||
double tdiff = dmlc::GetTime() - tstart;
|
||||
if (tdiff >= tick_expected) {
|
||||
LOG(CONSOLE) << "Writing " << page_type << " to " << cache_info
|
||||
<< " in " << ((bytes_write >> 20UL) / tdiff)
|
||||
<< " MB/s, " << (bytes_write >> 20UL) << " written";
|
||||
tick_expected += static_cast<size_t>(kStep);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (last_group_id != default_max) {
|
||||
@ -323,10 +372,6 @@ class SparsePageSource {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
}
|
||||
inferred_num_rows += page->Size();
|
||||
if (!page->offset.HostVector().empty()) {
|
||||
info.num_nonzero_ += page->offset.HostVector().back();
|
||||
}
|
||||
|
||||
// Deal with empty rows/columns if necessary
|
||||
if (adapter->NumColumns() == kAdapterUnknownSize) {
|
||||
@ -352,10 +397,9 @@ class SparsePageSource {
|
||||
info.num_row_ = adapter->NumRows();
|
||||
}
|
||||
|
||||
// Make sure we have at least one page if the dataset is empty
|
||||
if (page->data.Size() > 0 || info.num_row_ == 0) {
|
||||
writer.PushWrite(std::move(page));
|
||||
}
|
||||
pool.Push(page);
|
||||
pool.Finalize();
|
||||
|
||||
std::unique_ptr<dmlc::Stream> fo(
|
||||
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
|
||||
int tmagic = kMagic;
|
||||
|
||||
@ -130,8 +130,10 @@ TEST(DenseColumnWithMissing, Test) {
|
||||
void TestGHistIndexMatrixCreation(size_t nthreads) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
/* This should create multiple sparse pages */
|
||||
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(1024, 1024, filename) };
|
||||
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
|
||||
omp_set_num_threads(nthreads);
|
||||
GHistIndexMatrix gmat;
|
||||
gmat.Init(dmat.get(), 256);
|
||||
|
||||
@ -44,13 +44,13 @@ TEST(SparsePage, PushCSC) {
|
||||
}
|
||||
|
||||
auto inst = page[0];
|
||||
ASSERT_EQ(inst.size(), 2);
|
||||
ASSERT_EQ(inst.size(), 2ul);
|
||||
for (auto entry : inst) {
|
||||
ASSERT_EQ(entry.index, 0);
|
||||
ASSERT_EQ(entry.index, 0u);
|
||||
}
|
||||
|
||||
inst = page[1];
|
||||
ASSERT_EQ(inst.size(), 6);
|
||||
ASSERT_EQ(inst.size(), 6ul);
|
||||
std::vector<size_t> indices_sol {1, 2, 3};
|
||||
for (size_t i = 0; i < inst.size(); ++i) {
|
||||
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
|
||||
@ -58,15 +58,12 @@ TEST(SparsePage, PushCSC) {
|
||||
}
|
||||
|
||||
TEST(SparsePage, PushCSCAfterTranspose) {
|
||||
#if defined(__APPLE__)
|
||||
LOG(WARNING) << "FIXME(trivialfis): Skipping `PushCSCAfterTranspose' for APPLE.";
|
||||
return;
|
||||
#endif
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
const int n_entries = 9;
|
||||
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
std::unique_ptr<DMatrix> dmat =
|
||||
CreateSparsePageDMatrix(n_entries, 64UL, filename);
|
||||
CreateSparsePageDMatrix(kEntries, 64UL, filename);
|
||||
const int ncols = dmat->Info().num_col_;
|
||||
SparsePage page; // Consolidated sparse page
|
||||
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
|
||||
@ -76,7 +73,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
|
||||
}
|
||||
|
||||
// Make sure that the final sparse page has the right number of entries
|
||||
ASSERT_EQ(n_entries, page.data.Size());
|
||||
ASSERT_EQ(kEntries, page.data.Size());
|
||||
|
||||
// The feature value for a feature in each row should be identical, as that is
|
||||
// how the dmatrix has been created
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
#include <dmlc/filesystem.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/data.h>
|
||||
#include "../../../src/common/io.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/sparse_page_dmatrix.h"
|
||||
#include "../helpers.h"
|
||||
@ -11,16 +12,18 @@ using namespace xgboost; // NOLINT
|
||||
TEST(SparsePageDMatrix, MetaInfo) {
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
CreateSimpleTestData(tmp_file);
|
||||
size_t constexpr kEntries = 24;
|
||||
CreateBigTestData(tmp_file, kEntries);
|
||||
|
||||
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
|
||||
tmp_file + "#" + tmp_file + ".cache", false, false);
|
||||
std::cout << tmp_file << std::endl;
|
||||
EXPECT_TRUE(FileExists(tmp_file + ".cache"));
|
||||
|
||||
// Test the metadata that was parsed
|
||||
EXPECT_EQ(dmat->Info().num_row_, 2);
|
||||
EXPECT_EQ(dmat->Info().num_col_, 5);
|
||||
EXPECT_EQ(dmat->Info().num_nonzero_, 6);
|
||||
EXPECT_EQ(dmat->Info().num_row_, 8ul);
|
||||
EXPECT_EQ(dmat->Info().num_col_, 5ul);
|
||||
EXPECT_EQ(dmat->Info().num_nonzero_, kEntries);
|
||||
EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);
|
||||
|
||||
delete dmat;
|
||||
@ -30,13 +33,13 @@ TEST(SparsePageDMatrix, RowAccess) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
xgboost::CreateSparsePageDMatrix(24, 4, filename);
|
||||
|
||||
// Test the data read into the first row
|
||||
auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
|
||||
auto first_row = batch[0];
|
||||
ASSERT_EQ(first_row.size(), 3);
|
||||
EXPECT_EQ(first_row[2].index, 2);
|
||||
ASSERT_EQ(first_row.size(), 3ul);
|
||||
EXPECT_EQ(first_row[2].index, 2u);
|
||||
EXPECT_EQ(first_row[2].fvalue, 20);
|
||||
}
|
||||
|
||||
@ -77,43 +80,56 @@ TEST(SparsePageDMatrix, ColAccess) {
|
||||
TEST(SparsePageDMatrix, ExistingCacheFile) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
|
||||
EXPECT_ANY_THROW({
|
||||
std::unique_ptr<xgboost::DMatrix> dmat2 =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
TEST(SparsePageDMatrix, ThreadSafetyException) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/test";
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(12, 64, filename);
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
|
||||
bool exception = false;
|
||||
std::unique_ptr<xgboost::DMatrix> dmat =
|
||||
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
|
||||
|
||||
std::atomic<bool> exception {false};
|
||||
int threads = 1000;
|
||||
#pragma omp parallel for num_threads(threads)
|
||||
for (auto i = 0; i < threads; i++) {
|
||||
|
||||
std::vector<std::thread> waiting;
|
||||
|
||||
for (int32_t i = 0; i < threads; ++i) {
|
||||
waiting.emplace_back([&]() {
|
||||
try {
|
||||
auto iter = dmat->GetBatches<SparsePage>().begin();
|
||||
++iter;
|
||||
} catch (...) {
|
||||
exception = true;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto& t : waiting) {
|
||||
t.join();
|
||||
}
|
||||
EXPECT_TRUE(exception);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Multi-batches access
|
||||
TEST(SparsePageDMatrix, ColAccessBatches) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
// Create multiple sparse pages
|
||||
std::unique_ptr<xgboost::DMatrix> dmat{
|
||||
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
|
||||
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename)};
|
||||
auto n_threads = omp_get_max_threads();
|
||||
omp_set_num_threads(16);
|
||||
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
|
||||
@ -286,3 +302,70 @@ TEST(SparsePageDMatrix, FromFile) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, Large) {
|
||||
std::string filename = "test.libsvm";
|
||||
CreateBigTestData(filename, 1 << 16);
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
|
||||
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 16)};
|
||||
std::unique_ptr<DMatrix> simple{DMatrix::Load(filename, true, true)};
|
||||
|
||||
std::vector<float> sparse_data;
|
||||
std::vector<size_t> sparse_rptr;
|
||||
std::vector<bst_feature_t> sparse_cids;
|
||||
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
|
||||
|
||||
std::vector<float> simple_data;
|
||||
std::vector<size_t> simple_rptr;
|
||||
std::vector<bst_feature_t> simple_cids;
|
||||
DMatrixToCSR(simple.get(), &simple_data, &simple_rptr, &simple_cids);
|
||||
|
||||
ASSERT_EQ(sparse_rptr.size(), sparse->Info().num_row_ + 1);
|
||||
ASSERT_EQ(sparse_rptr.size(), simple->Info().num_row_ + 1);
|
||||
|
||||
ASSERT_EQ(sparse_data.size(), simple_data.size());
|
||||
ASSERT_EQ(sparse_data, simple_data);
|
||||
ASSERT_EQ(sparse_rptr.size(), simple_rptr.size());
|
||||
ASSERT_EQ(sparse_rptr, simple_rptr);
|
||||
ASSERT_EQ(sparse_cids, simple_cids);
|
||||
}
|
||||
|
||||
auto TestSparsePageDMatrixDeterminism(int32_t threads, std::string const& filename) {
|
||||
omp_set_num_threads(threads);
|
||||
std::vector<float> sparse_data;
|
||||
std::vector<size_t> sparse_rptr;
|
||||
std::vector<bst_feature_t> sparse_cids;
|
||||
|
||||
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
|
||||
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
|
||||
data::FileAdapter adapter(parser.get());
|
||||
dmlc::TemporaryDirectory tempdir;
|
||||
const std::string tmp_file = tempdir.path + "/simple.libsvm";
|
||||
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
|
||||
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1 << 8)};
|
||||
|
||||
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
|
||||
|
||||
std::string cache_name = tmp_file + ".row.page";
|
||||
std::string cache = common::LoadSequentialFile(cache_name);
|
||||
return cache;
|
||||
}
|
||||
|
||||
TEST(SparsePageDMatrix, Determinism) {
|
||||
std::string filename = "test.libsvm";
|
||||
CreateBigTestData(filename, 1 << 16);
|
||||
std::vector<std::string> caches;
|
||||
for (size_t i = 1; i < 18; i += 2) {
|
||||
caches.emplace_back(TestSparsePageDMatrixDeterminism(i, filename));
|
||||
}
|
||||
|
||||
for (size_t i = 1; i < caches.size(); ++i) {
|
||||
ASSERT_EQ(caches[i], caches.front());
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,7 +28,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
|
||||
TEST(SparsePageDMatrix, MultipleEllpackPages) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
|
||||
|
||||
// Loop over the batches and count the records
|
||||
int64_t batch_count = 0;
|
||||
|
||||
@ -373,12 +373,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
|
||||
batch_count++;
|
||||
row_count += batch.Size();
|
||||
}
|
||||
#if defined(_OPENMP)
|
||||
EXPECT_GE(batch_count, 2);
|
||||
EXPECT_EQ(row_count, dmat->Info().num_row_);
|
||||
#else
|
||||
#warning "External memory doesn't work with Non-OpenMP build "
|
||||
#endif // defined(_OPENMP)
|
||||
return dmat;
|
||||
}
|
||||
|
||||
@ -495,6 +491,36 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
|
||||
return gbm;
|
||||
}
|
||||
|
||||
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
|
||||
std::vector<size_t> *p_row_ptr,
|
||||
std::vector<bst_feature_t> *p_cids) {
|
||||
auto &data = *p_data;
|
||||
auto &row_ptr = *p_row_ptr;
|
||||
auto &cids = *p_cids;
|
||||
|
||||
data.resize(dmat->Info().num_nonzero_);
|
||||
cids.resize(data.size());
|
||||
row_ptr.resize(dmat->Info().num_row_ + 1);
|
||||
SparsePage page;
|
||||
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
page.Push(batch);
|
||||
}
|
||||
|
||||
auto const& in_offset = page.offset.HostVector();
|
||||
auto const& in_data = page.data.HostVector();
|
||||
|
||||
CHECK_EQ(in_offset.size(), row_ptr.size());
|
||||
std::copy(in_offset.cbegin(), in_offset.cend(), row_ptr.begin());
|
||||
ASSERT_EQ(in_data.size(), data.size());
|
||||
std::transform(in_data.cbegin(), in_data.cend(), data.begin(), [](Entry const& e) {
|
||||
return e.fvalue;
|
||||
});
|
||||
ASSERT_EQ(in_data.size(), cids.size());
|
||||
std::transform(in_data.cbegin(), in_data.cend(), cids.begin(), [](Entry const& e) {
|
||||
return e.index;
|
||||
});
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
|
||||
|
||||
using CUDAMemoryResource = rmm::mr::cuda_memory_resource;
|
||||
|
||||
@ -365,6 +365,10 @@ class CudaArrayIterForTest {
|
||||
auto Proxy() -> decltype(proxy_) { return proxy_; }
|
||||
};
|
||||
|
||||
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
|
||||
std::vector<size_t> *p_row_ptr,
|
||||
std::vector<bst_feature_t> *p_cids);
|
||||
|
||||
typedef void *DataIterHandle; // NOLINT(*)
|
||||
|
||||
inline void Reset(DataIterHandle self) {
|
||||
|
||||
@ -16,8 +16,8 @@ TEST(CpuPredictor, Basic) {
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
|
||||
|
||||
int kRows = 5;
|
||||
int kCols = 5;
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = kCols;
|
||||
@ -85,7 +85,11 @@ TEST(CpuPredictor, Basic) {
|
||||
TEST(CpuPredictor, ExternalMemory) {
|
||||
dmlc::TemporaryDirectory tmpdir;
|
||||
std::string filename = tmpdir.path + "/big.libsvm";
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
|
||||
|
||||
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
|
||||
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
|
||||
|
||||
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
|
||||
@ -105,9 +105,9 @@ TEST(GPUPredictor, ExternalMemoryTest) {
|
||||
std::string file0 = tmpdir.path + "/big_0.libsvm";
|
||||
std::string file1 = tmpdir.path + "/big_1.libsvm";
|
||||
std::string file2 = tmpdir.path + "/big_2.libsvm";
|
||||
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
|
||||
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
|
||||
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
|
||||
dmats.push_back(CreateSparsePageDMatrix(400, 64UL, file0));
|
||||
dmats.push_back(CreateSparsePageDMatrix(800, 128UL, file1));
|
||||
dmats.push_back(CreateSparsePageDMatrix(8000, 1024UL, file2));
|
||||
|
||||
for (const auto& dmat: dmats) {
|
||||
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import sys
|
||||
import gc
|
||||
import pytest
|
||||
import xgboost as xgb
|
||||
from hypothesis import given, strategies, assume, settings, note
|
||||
@ -118,7 +119,10 @@ class TestGPUUpdaters:
|
||||
assume(len(dataset.y) > 0)
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
param = dataset.set_params(param)
|
||||
external_result = train_result(param, dataset.get_external_dmat(), num_rounds)
|
||||
m = dataset.get_external_dmat()
|
||||
external_result = train_result(param, m, num_rounds)
|
||||
del m
|
||||
gc.collect()
|
||||
assert tm.non_increasing(external_result['train'][dataset.metric])
|
||||
|
||||
def test_empty_dmatrix_prediction(self):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user