Deterministic data partitioning for external memory (#6317)

* Make external memory data partitioning deterministic.

* Change the meaning of `page_size` from bytes to number of rows.

* Design a data pool.

* Note for external memory.

* Enable unity build on Windows CI.

* Force garbage collect on test.
This commit is contained in:
Jiaming Yuan 2020-11-11 06:11:06 +08:00 committed by GitHub
parent 9564886d9f
commit 43efadea2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 334 additions and 88 deletions

View File

@ -85,7 +85,7 @@ def BuildWin64() {
bat """
mkdir build
cd build
cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag}
cmake .. -G"Visual Studio 15 2017 Win64" -DUSE_CUDA=ON -DCMAKE_VERBOSE_MAKEFILE=ON -DGOOGLE_TEST=ON -DUSE_DMLC_GTEST=ON ${arch_flag} -DCMAKE_UNITY_BUILD=ON
"""
bat """
cd build

View File

@ -44,6 +44,7 @@
#if DMLC_ENABLE_STD_THREAD
#include "../src/data/sparse_page_dmatrix.cc"
#include "../src/data/sparse_page_source.cc"
#endif
// trees

View File

@ -549,8 +549,9 @@ class DMatrix {
int max_bin);
virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;
/*! \brief Number of rows per page in external memory. Approximately 100MB per page for
* dataset with 100 features. */
static const size_t kPageSize = 32UL << 12UL;
protected:
virtual BatchSet<SparsePage> GetRowBatches() = 0;

View File

@ -830,9 +830,10 @@ void SparsePage::Push(const SparsePage &batch) {
const auto& batch_data_vec = batch.data.HostVector();
size_t top = offset_vec.back();
data_vec.resize(top + batch.data.Size());
std::memcpy(dmlc::BeginPtr(data_vec) + top,
dmlc::BeginPtr(batch_data_vec),
sizeof(Entry) * batch.data.Size());
if (dmlc::BeginPtr(data_vec) && dmlc::BeginPtr(batch_data_vec)) {
std::memcpy(dmlc::BeginPtr(data_vec) + top, dmlc::BeginPtr(batch_data_vec),
sizeof(Entry) * batch.data.Size());
}
size_t begin = offset.Size();
offset_vec.resize(begin + batch.Size());
for (size_t i = 0; i < batch.Size(); ++i) {

View File

@ -0,0 +1,77 @@
/*!
* Copyright (c) 2020 by XGBoost Contributors
*/
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
void DataPool::Slice(std::shared_ptr<SparsePage> out, size_t offset,
size_t n_rows, size_t entry_offset) const {
auto const &in_offset = pool_.offset.HostVector();
auto const &in_data = pool_.data.HostVector();
auto &h_offset = out->offset.HostVector();
CHECK_LE(offset + n_rows + 1, in_offset.size());
h_offset.resize(n_rows + 1, 0);
std::transform(in_offset.cbegin() + offset,
in_offset.cbegin() + offset + n_rows + 1, h_offset.begin(),
[=](size_t ptr) { return ptr - entry_offset; });
auto &h_data = out->data.HostVector();
CHECK_GT(h_offset.size(), 0);
size_t n_entries = h_offset.back();
h_data.resize(n_entries);
CHECK_EQ(n_entries, in_offset.at(offset + n_rows) - in_offset.at(offset));
std::copy_n(in_data.cbegin() + in_offset.at(offset), n_entries,
h_data.begin());
}
void DataPool::SplitWritePage() {
size_t total = pool_.Size();
size_t offset = 0;
size_t entry_offset = 0;
do {
size_t n_rows = std::min(page_size_, total - offset);
std::shared_ptr<SparsePage> out;
writer_->Alloc(&out);
out->Clear();
out->SetBaseRowId(inferred_num_rows_);
this->Slice(out, offset, n_rows, entry_offset);
inferred_num_rows_ += out->Size();
offset += n_rows;
entry_offset += out->data.Size();
CHECK_NE(out->Size(), 0);
writer_->PushWrite(std::move(out));
} while (total - offset >= page_size_);
if (total - offset != 0) {
auto out = std::make_shared<SparsePage>();
this->Slice(out, offset, total - offset, entry_offset);
CHECK_NE(out->Size(), 0);
pool_.Clear();
pool_.Push(*out);
} else {
pool_.Clear();
}
}
size_t DataPool::Finalize() {
inferred_num_rows_ += pool_.Size();
if (pool_.Size() != 0) {
std::shared_ptr<SparsePage> page;
this->writer_->Alloc(&page);
page->Clear();
page->Push(pool_);
this->writer_->PushWrite(std::move(page));
}
if (inferred_num_rows_ == 0) {
std::shared_ptr<SparsePage> page;
this->writer_->Alloc(&page);
page->Clear();
this->writer_->PushWrite(std::move(page));
}
return inferred_num_rows_;
}
} // namespace data
} // namespace xgboost

View File

@ -3,6 +3,37 @@
* \file page_csr_source.h
* External memory data source, saved with sparse_batch_page binary format.
* \author Tianqi Chen
*
* -------------------------------------------------
* Random notes on implementation of external memory
* -------------------------------------------------
*
* As of XGBoost 1.3, the general pipeline is:
*
* dmlc text file parser --> file adapter --> sparse page source -> data pool -->
* write to binary cache --> load it back ~~> [ other pages (csc, ellpack, sorted csc) -->
* write to binary cache ] --> use it in various algorithms.
*
* ~~> means optional
*
* The dmlc text file parser returns number of blocks based on available threads, which
* can make the data partitioning non-deterministic, so here we set up an extra data pool
* to stage parsed data. As a result, the number of blocks returned by text parser does
* not equal to number of blocks in binary cache.
*
* Binary cache loading is async by the dmlc threaded iterator, which helps performance,
* but as this iterator itself is not thread safe, so calling
* `dmatrix->GetBatches<SparsePage>` is also not thread safe. Please note that, the
* threaded iterator is also used inside dmlc text file parser.
*
* Memory consumption is difficult to control due to various reasons. Firstly the text
* parsing doesn't have a batch size, only a hard coded buffer size is available.
* Secondly, everything is loaded/written with async queue, with multiple queues running
* the memory consumption is difficult to measure.
*
* The threaded iterator relies heavily on C++ memory model and threading primitive. The
* concurrent writer for binary cache is an old copy of moody queue. We should try to
* replace them with something more robust.
*/
#ifndef XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
#define XGBOOST_DATA_SPARSE_PAGE_SOURCE_H_
@ -19,6 +50,7 @@
#include <vector>
#include <fstream>
#include "rabit/rabit.h"
#include "xgboost/base.h"
#include "xgboost/data.h"
@ -121,9 +153,12 @@ inline void TryDeleteCacheFile(const std::string& file) {
inline void CheckCacheFileExists(const std::string& file) {
std::ifstream f(file.c_str());
if (f.good()) {
LOG(FATAL) << "Cache file " << file
<< " exists already; Is there another DMatrix with the same "
"cache prefix? Otherwise please remove it manually.";
LOG(FATAL)
<< "Cache file " << file << " exists already; "
<< "Is there another DMatrix with the same "
"cache prefix? It can be caused by previously used DMatrix that "
"hasn't been collected by language environment garbage collector. "
"Otherwise please remove it manually.";
}
}
@ -231,6 +266,38 @@ class ExternalMemoryPrefetcher : dmlc::DataIter<PageT> {
std::vector<std::unique_ptr<dmlc::ThreadedIter<PageT>>> prefetchers_;
};
// A data pool to keep the size of each page balanced and data partitioning to be
// deterministic.
class DataPool {
size_t inferred_num_rows_;
MetaInfo* info_;
SparsePage pool_;
size_t page_size_;
SparsePageWriter<SparsePage> *writer_;
void Slice(std::shared_ptr<SparsePage> out, size_t offset, size_t n_rows,
size_t entry_offset) const;
void SplitWritePage();
public:
DataPool(MetaInfo *info, size_t page_size,
SparsePageWriter<SparsePage> *writer)
: inferred_num_rows_{0}, info_{info},
page_size_{page_size}, writer_{writer} {}
void Push(std::shared_ptr<SparsePage> page) {
info_->num_nonzero_ += page->data.Size();
pool_.Push(*page);
if (pool_.Size() > page_size_) {
this->SplitWritePage();
}
page->Clear();
}
size_t Finalize();
};
class SparsePageSource {
public:
template <typename AdapterT>
@ -249,17 +316,12 @@ class SparsePageSource {
{
SparsePageWriter<SparsePage> writer(cache_info_.name_shards,
cache_info_.format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page);
page->Clear();
DataPool pool(&info, page_size, &writer);
std::shared_ptr<SparsePage> page { new SparsePage };
uint64_t inferred_num_columns = 0;
uint64_t inferred_num_rows = 0;
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
// print every 4 sec.
constexpr double kStep = 4.0;
size_t tick_expected = static_cast<double>(kStep);
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
@ -296,26 +358,13 @@ class SparsePageSource {
++group_size;
}
}
CHECK_EQ(page->Size(), 0);
auto batch_max_columns = page->Push(batch, missing, nthread);
inferred_num_columns =
std::max(batch_max_columns, inferred_num_columns);
if (page->MemCostBytes() >= page_size) {
inferred_num_rows += page->Size();
info.num_nonzero_ += page->offset.HostVector().back();
bytes_write += page->MemCostBytes();
writer.PushWrite(std::move(page));
writer.Alloc(&page);
page->Clear();
page->SetBaseRowId(inferred_num_rows);
double tdiff = dmlc::GetTime() - tstart;
if (tdiff >= tick_expected) {
LOG(CONSOLE) << "Writing " << page_type << " to " << cache_info
<< " in " << ((bytes_write >> 20UL) / tdiff)
<< " MB/s, " << (bytes_write >> 20UL) << " written";
tick_expected += static_cast<size_t>(kStep);
}
}
inferred_num_rows += page->Size();
pool.Push(page);
page->SetBaseRowId(inferred_num_rows);
}
if (last_group_id != default_max) {
@ -323,10 +372,6 @@ class SparsePageSource {
info.group_ptr_.push_back(group_size);
}
}
inferred_num_rows += page->Size();
if (!page->offset.HostVector().empty()) {
info.num_nonzero_ += page->offset.HostVector().back();
}
// Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) {
@ -352,10 +397,9 @@ class SparsePageSource {
info.num_row_ = adapter->NumRows();
}
// Make sure we have at least one page if the dataset is empty
if (page->data.Size() > 0 || info.num_row_ == 0) {
writer.PushWrite(std::move(page));
}
pool.Push(page);
pool.Finalize();
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(cache_info_.name_info.c_str(), "w"));
int tmagic = kMagic;

View File

@ -130,8 +130,10 @@ TEST(DenseColumnWithMissing, Test) {
void TestGHistIndexMatrixCreation(size_t nthreads) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
/* This should create multiple sparse pages */
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(1024, 1024, filename) };
std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
omp_set_num_threads(nthreads);
GHistIndexMatrix gmat;
gmat.Init(dmat.get(), 256);

View File

@ -44,13 +44,13 @@ TEST(SparsePage, PushCSC) {
}
auto inst = page[0];
ASSERT_EQ(inst.size(), 2);
ASSERT_EQ(inst.size(), 2ul);
for (auto entry : inst) {
ASSERT_EQ(entry.index, 0);
ASSERT_EQ(entry.index, 0u);
}
inst = page[1];
ASSERT_EQ(inst.size(), 6);
ASSERT_EQ(inst.size(), 6ul);
std::vector<size_t> indices_sol {1, 2, 3};
for (size_t i = 0; i < inst.size(); ++i) {
ASSERT_EQ(inst[i].index, indices_sol[i % 3]);
@ -58,15 +58,12 @@ TEST(SparsePage, PushCSC) {
}
TEST(SparsePage, PushCSCAfterTranspose) {
#if defined(__APPLE__)
LOG(WARNING) << "FIXME(trivialfis): Skipping `PushCSCAfterTranspose' for APPLE.";
return;
#endif
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
const int n_entries = 9;
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat =
CreateSparsePageDMatrix(n_entries, 64UL, filename);
CreateSparsePageDMatrix(kEntries, 64UL, filename);
const int ncols = dmat->Info().num_col_;
SparsePage page; // Consolidated sparse page
for (const auto &batch : dmat->GetBatches<xgboost::SparsePage>()) {
@ -76,7 +73,7 @@ TEST(SparsePage, PushCSCAfterTranspose) {
}
// Make sure that the final sparse page has the right number of entries
ASSERT_EQ(n_entries, page.data.Size());
ASSERT_EQ(kEntries, page.data.Size());
// The feature value for a feature in each row should be identical, as that is
// how the dmatrix has been created

View File

@ -2,6 +2,7 @@
#include <dmlc/filesystem.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include "../../../src/common/io.h"
#include "../../../src/data/adapter.h"
#include "../../../src/data/sparse_page_dmatrix.h"
#include "../helpers.h"
@ -11,16 +12,18 @@ using namespace xgboost; // NOLINT
TEST(SparsePageDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
size_t constexpr kEntries = 24;
CreateBigTestData(tmp_file, kEntries);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(
tmp_file + "#" + tmp_file + ".cache", false, false);
std::cout << tmp_file << std::endl;
EXPECT_TRUE(FileExists(tmp_file + ".cache"));
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2);
EXPECT_EQ(dmat->Info().num_col_, 5);
EXPECT_EQ(dmat->Info().num_nonzero_, 6);
EXPECT_EQ(dmat->Info().num_row_, 8ul);
EXPECT_EQ(dmat->Info().num_col_, 5ul);
EXPECT_EQ(dmat->Info().num_nonzero_, kEntries);
EXPECT_EQ(dmat->Info().labels_.Size(), dmat->Info().num_row_);
delete dmat;
@ -30,13 +33,13 @@ TEST(SparsePageDMatrix, RowAccess) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(12, 64, filename);
xgboost::CreateSparsePageDMatrix(24, 4, filename);
// Test the data read into the first row
auto &batch = *dmat->GetBatches<xgboost::SparsePage>().begin();
auto first_row = batch[0];
ASSERT_EQ(first_row.size(), 3);
EXPECT_EQ(first_row[2].index, 2);
ASSERT_EQ(first_row.size(), 3ul);
EXPECT_EQ(first_row[2].index, 2u);
EXPECT_EQ(first_row[2].fvalue, 20);
}
@ -77,43 +80,56 @@ TEST(SparsePageDMatrix, ColAccess) {
TEST(SparsePageDMatrix, ExistingCacheFile) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(12, 64, filename);
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
EXPECT_ANY_THROW({
std::unique_ptr<xgboost::DMatrix> dmat2 =
xgboost::CreateSparsePageDMatrix(12, 64, filename);
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
});
}
#if defined(_OPENMP)
TEST(SparsePageDMatrix, ThreadSafetyException) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/test";
std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(12, 64, filename);
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
bool exception = false;
std::unique_ptr<xgboost::DMatrix> dmat =
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename);
std::atomic<bool> exception {false};
int threads = 1000;
#pragma omp parallel for num_threads(threads)
for (auto i = 0; i < threads; i++) {
try {
auto iter = dmat->GetBatches<SparsePage>().begin();
++iter;
} catch (...) {
exception = true;
}
std::vector<std::thread> waiting;
for (int32_t i = 0; i < threads; ++i) {
waiting.emplace_back([&]() {
try {
auto iter = dmat->GetBatches<SparsePage>().begin();
++iter;
} catch (...) {
exception = true;
}
});
}
for (auto& t : waiting) {
t.join();
}
EXPECT_TRUE(exception);
}
#endif
// Multi-batches access
TEST(SparsePageDMatrix, ColAccessBatches) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
size_t constexpr kPageSize = 1024, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
// Create multiple sparse pages
std::unique_ptr<xgboost::DMatrix> dmat{
xgboost::CreateSparsePageDMatrix(1024, 1024, filename)};
xgboost::CreateSparsePageDMatrix(kEntries, kPageSize, filename)};
auto n_threads = omp_get_max_threads();
omp_set_num_threads(16);
for (auto const &page : dmat->GetBatches<xgboost::CSCPage>()) {
@ -286,3 +302,70 @@ TEST(SparsePageDMatrix, FromFile) {
}
}
}
TEST(SparsePageDMatrix, Large) {
std::string filename = "test.libsvm";
CreateBigTestData(filename, 1 << 16);
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
data::FileAdapter adapter(parser.get());
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 16)};
std::unique_ptr<DMatrix> simple{DMatrix::Load(filename, true, true)};
std::vector<float> sparse_data;
std::vector<size_t> sparse_rptr;
std::vector<bst_feature_t> sparse_cids;
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
std::vector<float> simple_data;
std::vector<size_t> simple_rptr;
std::vector<bst_feature_t> simple_cids;
DMatrixToCSR(simple.get(), &simple_data, &simple_rptr, &simple_cids);
ASSERT_EQ(sparse_rptr.size(), sparse->Info().num_row_ + 1);
ASSERT_EQ(sparse_rptr.size(), simple->Info().num_row_ + 1);
ASSERT_EQ(sparse_data.size(), simple_data.size());
ASSERT_EQ(sparse_data, simple_data);
ASSERT_EQ(sparse_rptr.size(), simple_rptr.size());
ASSERT_EQ(sparse_rptr, simple_rptr);
ASSERT_EQ(sparse_cids, simple_cids);
}
auto TestSparsePageDMatrixDeterminism(int32_t threads, std::string const& filename) {
omp_set_num_threads(threads);
std::vector<float> sparse_data;
std::vector<size_t> sparse_rptr;
std::vector<bst_feature_t> sparse_cids;
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
data::FileAdapter adapter(parser.get());
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::unique_ptr<DMatrix> sparse{new data::SparsePageDMatrix(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1 << 8)};
DMatrixToCSR(sparse.get(), &sparse_data, &sparse_rptr, &sparse_cids);
std::string cache_name = tmp_file + ".row.page";
std::string cache = common::LoadSequentialFile(cache_name);
return cache;
}
TEST(SparsePageDMatrix, Determinism) {
std::string filename = "test.libsvm";
CreateBigTestData(filename, 1 << 16);
std::vector<std::string> caches;
for (size_t i = 1; i < 18; i += 2) {
caches.emplace_back(TestSparsePageDMatrixDeterminism(i, filename));
}
for (size_t i = 1; i < caches.size(); ++i) {
ASSERT_EQ(caches[i], caches.front());
}
}

View File

@ -28,7 +28,9 @@ TEST(SparsePageDMatrix, EllpackPage) {
TEST(SparsePageDMatrix, MultipleEllpackPages) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
// Loop over the batches and count the records
int64_t batch_count = 0;

View File

@ -373,12 +373,8 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
batch_count++;
row_count += batch.Size();
}
#if defined(_OPENMP)
EXPECT_GE(batch_count, 2);
EXPECT_EQ(row_count, dmat->Info().num_row_);
#else
#warning "External memory doesn't work with Non-OpenMP build "
#endif // defined(_OPENMP)
return dmat;
}
@ -495,6 +491,36 @@ std::unique_ptr<GradientBooster> CreateTrainedGBM(
return gbm;
}
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
std::vector<size_t> *p_row_ptr,
std::vector<bst_feature_t> *p_cids) {
auto &data = *p_data;
auto &row_ptr = *p_row_ptr;
auto &cids = *p_cids;
data.resize(dmat->Info().num_nonzero_);
cids.resize(data.size());
row_ptr.resize(dmat->Info().num_row_ + 1);
SparsePage page;
for (const auto &batch : dmat->GetBatches<SparsePage>()) {
page.Push(batch);
}
auto const& in_offset = page.offset.HostVector();
auto const& in_data = page.data.HostVector();
CHECK_EQ(in_offset.size(), row_ptr.size());
std::copy(in_offset.cbegin(), in_offset.cend(), row_ptr.begin());
ASSERT_EQ(in_data.size(), data.size());
std::transform(in_data.cbegin(), in_data.cend(), data.begin(), [](Entry const& e) {
return e.fvalue;
});
ASSERT_EQ(in_data.size(), cids.size());
std::transform(in_data.cbegin(), in_data.cend(), cids.begin(), [](Entry const& e) {
return e.index;
});
}
#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1
using CUDAMemoryResource = rmm::mr::cuda_memory_resource;

View File

@ -365,6 +365,10 @@ class CudaArrayIterForTest {
auto Proxy() -> decltype(proxy_) { return proxy_; }
};
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
std::vector<size_t> *p_row_ptr,
std::vector<bst_feature_t> *p_cids);
typedef void *DataIterHandle; // NOLINT(*)
inline void Reset(DataIterHandle self) {

View File

@ -16,8 +16,8 @@ TEST(CpuPredictor, Basic) {
std::unique_ptr<Predictor> cpu_predictor =
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
int kRows = 5;
int kCols = 5;
size_t constexpr kRows = 5;
size_t constexpr kCols = 5;
LearnerModelParam param;
param.num_feature = kCols;
@ -85,7 +85,11 @@ TEST(CpuPredictor, Basic) {
TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(12, 64, filename);
size_t constexpr kPageSize = 64, kEntriesPerCol = 3;
size_t constexpr kEntries = kPageSize * kEntriesPerCol * 2;
std::unique_ptr<DMatrix> dmat = CreateSparsePageDMatrix(kEntries, kPageSize, filename);
auto lparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Predictor> cpu_predictor =

View File

@ -105,9 +105,9 @@ TEST(GPUPredictor, ExternalMemoryTest) {
std::string file0 = tmpdir.path + "/big_0.libsvm";
std::string file1 = tmpdir.path + "/big_1.libsvm";
std::string file2 = tmpdir.path + "/big_2.libsvm";
dmats.push_back(CreateSparsePageDMatrix(9, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(128, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(1024, 1024UL, file2));
dmats.push_back(CreateSparsePageDMatrix(400, 64UL, file0));
dmats.push_back(CreateSparsePageDMatrix(800, 128UL, file1));
dmats.push_back(CreateSparsePageDMatrix(8000, 1024UL, file2));
for (const auto& dmat: dmats) {
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);

View File

@ -1,5 +1,6 @@
import numpy as np
import sys
import gc
import pytest
import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note
@ -118,7 +119,10 @@ class TestGPUUpdaters:
assume(len(dataset.y) > 0)
param['tree_method'] = 'gpu_hist'
param = dataset.set_params(param)
external_result = train_result(param, dataset.get_external_dmat(), num_rounds)
m = dataset.get_external_dmat()
external_result = train_result(param, m, num_rounds)
del m
gc.collect()
assert tm.non_increasing(external_result['train'][dataset.metric])
def test_empty_dmatrix_prediction(self):