Use adapters for SparsePageDMatrix (#5092)

This commit is contained in:
Rory Mitchell 2019-12-11 15:59:23 +13:00 committed by GitHub
parent e089e16e3d
commit c7cc657a4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 437 additions and 253 deletions

View File

@ -200,7 +200,7 @@ class SparsePage {
/*! \return Number of instances in the page. */
inline size_t Size() const {
return offset.Size() - 1;
return offset.Size() == 0 ? 0 : offset.Size() - 1;
}
/*! \return estimation of memory cost of this page */
@ -242,6 +242,20 @@ class SparsePage {
* \param batch the row batch.
*/
void Push(const dmlc::RowBlock<uint32_t>& batch);
/**
* \brief Pushes external data batch onto this page
*
* \tparam AdapterBatchT
* \param batch
* \param missing
* \param nthread
*
* \return The maximum number of columns encountered in this input batch. Useful when pushing many adapter batches to work out the total number of columns.
*/
template <typename AdapterBatchT>
uint64_t Push(const AdapterBatchT& batch, float missing, int nthread);
/*!
* \brief Push a sparse page
* \param batch the row page
@ -455,32 +469,20 @@ class DMatrix {
* \brief Creates a new DMatrix from an external data adapter.
*
* \tparam AdapterT Type of the adapter.
* \param adapter View onto an external data.
* \param [in,out] adapter View onto an external data.
* \param missing Values to count as missing.
* \param nthread Number of threads for construction.
* \param cache_prefix (Optional) The cache prefix for external memory.
* \param page_size (Optional) Size of the page.
*
* \return a Created DMatrix.
*/
template <typename AdapterT>
static DMatrix* Create(AdapterT* adapter, float missing, int nthread);
/*!
* \brief Create a DMatrix by loading data from parser.
* Parser can later be deleted after the DMatrix i created.
* \param parser The input data parser
* \param cache_prefix The path to prefix of temporary cache file of the DMatrix when used in external memory mode.
* This can be nullptr for common cases, and in-memory mode will be used.
* \param page_size Page size for external memory.
* \sa dmlc::Parser
* \note dmlc-core provides efficient distributed data parser for libsvm format.
* User can create and register customized parser to load their own format using DMLC_REGISTER_DATA_PARSER.
* See "dmlc-core/include/dmlc/data.h" for detail.
* \return A created DMatrix.
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
static DMatrix* Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix = "",
size_t page_size = kPageSize);
/*! \brief page size 32 MB */
static const size_t kPageSize = 32UL << 20UL;

View File

@ -196,7 +196,9 @@ int XGDMatrixCreateFromDataIter(
scache = cache_info;
}
NativeDataIter parser(data_handle, callback);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&parser, scache));
data::FileAdapter adapter(&parser);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, scache));
API_END();
}

View File

@ -359,7 +359,9 @@ void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_num_bins) {
// The number of threads is pegged to the batch size. If the OMP
// block is parallelized on anything other than the batch/block size,
// it should be reassigned
const size_t batch_threads = std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads()));
const size_t batch_threads = std::max(
size_t(1),
std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
MemStackAllocator<size_t, 128> partial_sums(batch_threads);
size_t* p_part = partial_sums.Get();

View File

@ -124,9 +124,7 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
: row_ptr(row_ptr),
feature_idx(feature_idx),
values(values),
num_rows(num_rows),
num_elements(num_elements),
num_features(num_features) {}
num_rows(num_rows) {}
const Line GetLine(size_t idx) const {
size_t begin_offset = row_ptr[idx];
size_t end_offset = row_ptr[idx + 1];
@ -139,9 +137,7 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
const size_t* row_ptr;
const unsigned* feature_idx;
const float* values;
size_t num_elements;
size_t num_rows;
size_t num_features;
};
class CSRAdapter : public detail::SingleBatchDataIter<CSRAdapterBatch> {

View File

@ -224,10 +224,12 @@ DMatrix* DMatrix::Load(const std::string& uri,
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
data::FileAdapter adapter(parser.get());
DMatrix* dmat {nullptr};
try {
dmat = DMatrix::Create(parser.get(), cache_file, page_size);
dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,
cache_file, page_size);
} catch (dmlc::Error& e) {
std::vector<std::string> splited = common::Split(fname, '#');
std::vector<std::string> args = common::Split(splited.front(), '?');
@ -282,27 +284,6 @@ DMatrix* DMatrix::Load(const std::string& uri,
return dmat;
}
DMatrix* DMatrix::Create(dmlc::Parser<uint32_t>* parser,
const std::string& cache_prefix,
const size_t page_size) {
if (cache_prefix.length() == 0) {
data::FileAdapter adapter(parser);
return DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
1);
} else {
#if DMLC_ENABLE_STD_THREAD
if (!data::SparsePageSource<SparsePage>::CacheExist(cache_prefix, ".row.page")) {
data::SparsePageSource<SparsePage>::CreateRowPage(parser, cache_prefix, page_size);
}
std::unique_ptr<data::SparsePageSource<SparsePage>> source(
new data::SparsePageSource<SparsePage>(cache_prefix, ".row.page"));
return DMatrix::Create(std::move(source), cache_prefix);
#else
LOG(FATAL) << "External memory is not enabled in mingw";
return nullptr;
#endif // DMLC_ENABLE_STD_THREAD
}
}
void DMatrix::SaveToLocalFile(const std::string& fname) {
data::SimpleCSRSource source;
@ -352,20 +333,36 @@ DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
}
template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread) {
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size ) {
if (cache_prefix.length() == 0) {
return new data::SimpleDMatrix(adapter, missing, nthread);
} else {
#if DMLC_ENABLE_STD_THREAD
return new data::SparsePageDMatrix(adapter, missing, nthread, cache_prefix,
page_size);
#else
LOG(FATAL) << "External memory is not enabled in mingw";
return nullptr;
#endif // DMLC_ENABLE_STD_THREAD
}
}
template DMatrix* DMatrix::Create<data::DenseAdapter>(data::DenseAdapter* adapter,
float missing, int nthread);
template DMatrix* DMatrix::Create<data::CSRAdapter>(data::CSRAdapter* adapter,
float missing, int nthread);
template DMatrix* DMatrix::Create<data::CSCAdapter>(data::CSCAdapter* adapter,
float missing, int nthread);
template DMatrix* DMatrix::Create<data::DenseAdapter>(
data::DenseAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::CSRAdapter>(
data::CSRAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::CSCAdapter>(
data::CSCAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::DataTableAdapter>(
data::DataTableAdapter* adapter, float missing, int nthread);
template DMatrix* DMatrix::Create<data::FileAdapter>(data::FileAdapter* adapter,
float missing, int nthread);
data::DataTableAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
template DMatrix* DMatrix::Create<data::FileAdapter>(
data::FileAdapter* adapter, float missing, int nthread,
const std::string& cache_prefix, size_t page_size);
SparsePage SparsePage::GetTranspose(int num_columns) const {
SparsePage transpose;
@ -413,21 +410,72 @@ void SparsePage::Push(const SparsePage &batch) {
}
}
void SparsePage::Push(const dmlc::RowBlock<uint32_t>& batch) {
auto& data_vec = data.HostVector();
template <typename AdapterBatchT>
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
const int nthreadmax = omp_get_max_threads();
if (nthread <= 0) nthread = nthreadmax;
int nthread_original = omp_get_max_threads();
omp_set_num_threads(nthread);
auto& offset_vec = offset.HostVector();
data_vec.reserve(data.Size() + batch.offset[batch.size] - batch.offset[0]);
offset_vec.reserve(offset.Size() + batch.size);
CHECK(batch.index != nullptr);
for (size_t i = 0; i < batch.size; ++i) {
offset_vec.push_back(offset_vec.back() + batch.offset[i + 1] - batch.offset[i]);
auto& data_vec = data.HostVector();
size_t builder_base_row_offset = this->Size();
common::ParallelGroupBuilder<
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type>
builder(&offset_vec, &data_vec, builder_base_row_offset);
// Estimate expected number of rows by using last element in batch
// This is not required to be exact but prevents unnecessary resizing
size_t expected_rows = 0;
if (batch.Size() > 0) {
auto last_line = batch.GetLine(batch.Size() - 1);
if (last_line.Size() > 0) {
expected_rows =
last_line.GetElement(last_line.Size() - 1).row_idx - base_rowid;
}
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
data_vec.emplace_back(index, fvalue);
}
CHECK_EQ(offset_vec.back(), data.Size());
builder.InitBudget(expected_rows, nthread);
uint64_t max_columns = 0;
// First-pass over the batch counting valid elements
size_t num_lines = batch.Size();
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < static_cast<omp_ulong>(num_lines);
++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
max_columns =
std::max(max_columns, static_cast<uint64_t>(element.column_idx + 1));
if (!common::CheckNAN(element.value) && element.value != missing) {
size_t key = element.row_idx -
base_rowid; // Adapter row index is absolute, here we want
// it relative to current page
CHECK_GE(key, builder_base_row_offset);
builder.AddBudget(element.row_idx - base_rowid, tid);
}
}
}
builder.InitStorage();
// Second pass over batch, placing elements in correct position
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < static_cast<omp_ulong>(num_lines);
++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
if (!common::CheckNAN(element.value) && element.value != missing) {
size_t key = element.row_idx -
base_rowid; // Adapter row index is absolute, here we want
// it relative to current page
builder.Push(key, Entry(element.column_idx, element.value), tid);
}
}
}
omp_set_num_threads(nthread_original);
return max_columns;
}
void SparsePage::PushCSC(const SparsePage &batch) {

View File

@ -50,57 +50,9 @@ class SimpleDMatrix : public DMatrix {
adapter->BeforeFirst();
// Iterate over batches of input data
while (adapter->Next()) {
auto &batch = adapter->Value();
size_t base_row_offset = offset_vec.empty() ? 0 : offset_vec.size() - 1;
common::ParallelGroupBuilder<
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type>
builder(&offset_vec, &data_vec, base_row_offset);
// Estimate expected number of rows by using last element in batch
// This is not required to be exact but prevents unnecessary resizing
size_t expected_rows = 0;
if (batch.Size() > 0) {
auto last_line = batch.GetLine(batch.Size() - 1);
if (last_line.Size() > 0) {
expected_rows = last_line.GetElement(last_line.Size() - 1).row_idx;
}
}
builder.InitBudget(expected_rows, nthread);
// First-pass over the batch counting valid elements
size_t num_lines = batch.Size();
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < static_cast<omp_ulong>(num_lines);
++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
inferred_num_columns =
std::max(inferred_num_columns,
static_cast<uint64_t>(element.column_idx + 1));
if (!common::CheckNAN(element.value) && element.value != missing) {
builder.AddBudget(element.row_idx, tid);
}
}
}
builder.InitStorage();
// Second pass over batch, placing elements in correct position
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < static_cast<omp_ulong>(num_lines);
++i) { // NOLINT(*)
int tid = omp_get_thread_num();
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
if (!common::CheckNAN(element.value) && element.value != missing) {
builder.Push(element.row_idx, Entry(element.column_idx, element.value),
tid);
}
}
}
auto& batch = adapter->Value();
auto batch_max_columns = mat.page_.Push(batch, missing, nthread);
inferred_num_columns = std::max(batch_max_columns, inferred_num_columns);
// Append meta information if available
if (batch.Labels() != nullptr) {
auto& labels = mat.info.labels_.HostVector();

View File

@ -25,6 +25,21 @@ class SparsePageDMatrix : public DMatrix {
explicit SparsePageDMatrix(std::unique_ptr<DataSource<SparsePage>>&& source,
std::string cache_info)
: row_source_(std::move(source)), cache_info_(std::move(cache_info)) {}
template <typename AdapterT>
explicit SparsePageDMatrix(AdapterT* adapter, float missing, int nthread,
const std::string& cache_prefix,
size_t page_size = kPageSize)
: cache_info_(std::move(cache_prefix)) {
if (!data::SparsePageSource<SparsePage>::CacheExist(cache_prefix,
".row.page")) {
data::SparsePageSource<SparsePage>::CreateRowPage(
adapter, missing, nthread, cache_prefix, page_size);
}
row_source_.reset(
new data::SparsePageSource<SparsePage>(cache_prefix, ".row.page"));
}
// Set number of threads but keep old value so we can reset it after
~SparsePageDMatrix() override = default;
MetaInfo& Info() override;

View File

@ -21,6 +21,7 @@
#include "xgboost/base.h"
#include "xgboost/data.h"
#include "adapter.h"
#include "sparse_page_writer.h"
#include "../common/common.h"
@ -182,22 +183,21 @@ class SparsePageSource : public DataSource<T> {
return *page_;
}
/*!
* \brief Create source by taking data from parser.
* \param src source parser.
* \param cache_info The cache_info of cache file location.
* \param page_size Page size for external memory.
*/
static void CreateRowPage(dmlc::Parser<uint32_t>* src,
template <typename AdapterT>
static void CreateRowPage(AdapterT* adapter, float missing, int nthread,
const std::string& cache_info,
const size_t page_size = DMatrix::kPageSize) {
const std::string page_type = ".row.page";
auto cinfo = ParseCacheInfo(cache_info, page_type);
{
SparsePageWriter<SparsePage> writer(cinfo.name_shards, cinfo.format_shards, 6);
SparsePageWriter<SparsePage> writer(cinfo.name_shards,
cinfo.format_shards, 6);
std::shared_ptr<SparsePage> page;
writer.Alloc(&page); page->Clear();
writer.Alloc(&page);
page->Clear();
uint64_t inferred_num_columns = 0;
uint64_t inferred_num_rows = 0;
MetaInfo info;
size_t bytes_write = 0;
double tstart = dmlc::GetTime();
@ -209,22 +209,24 @@ class SparsePageSource : public DataSource<T> {
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
std::vector<uint64_t> qids;
while (src->Next()) {
const dmlc::RowBlock<uint32_t>& batch = src->Value();
if (batch.label != nullptr) {
adapter->BeforeFirst();
while (adapter->Next()) {
auto& batch = adapter->Value();
if (batch.Labels() != nullptr) {
auto& labels = info.labels_.HostVector();
labels.insert(labels.end(), batch.label, batch.label + batch.size);
labels.insert(labels.end(), batch.Labels(),
batch.Labels() + batch.Size());
}
if (batch.weight != nullptr) {
if (batch.Weights() != nullptr) {
auto& weights = info.weights_.HostVector();
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
weights.insert(weights.end(), batch.Weights(),
batch.Weights() + batch.Size());
}
if (batch.qid != nullptr) {
qids.insert(qids.end(), batch.qid, batch.qid + batch.size);
if (batch.Qid() != nullptr) {
qids.insert(qids.end(), batch.Qid(), batch.Qid() + batch.Size());
// get group
for (size_t i = 0; i < batch.size; ++i) {
const uint64_t cur_group_id = batch.qid[i];
for (size_t i = 0; i < batch.Size(); ++i) {
const uint64_t cur_group_id = batch.Qid()[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
info.group_ptr_.push_back(group_size);
}
@ -232,49 +234,77 @@ class SparsePageSource : public DataSource<T> {
++group_size;
}
}
info.num_row_ += batch.size;
info.num_nonzero_ += batch.offset[batch.size] - batch.offset[0];
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
uint32_t index = batch.index[i];
info.num_col_ = std::max(info.num_col_,
static_cast<uint64_t>(index + 1));
}
page->Push(batch);
auto batch_max_columns = page->Push(batch, missing, nthread);
inferred_num_columns =
std::max(batch_max_columns, inferred_num_columns);
if (page->MemCostBytes() >= page_size) {
inferred_num_rows += page->Size();
info.num_nonzero_ += page->offset.HostVector().back();
bytes_write += page->MemCostBytes();
writer.PushWrite(std::move(page));
writer.Alloc(&page);
page->Clear();
page->SetBaseRowId(inferred_num_rows);
double tdiff = dmlc::GetTime() - tstart;
if (tdiff >= tick_expected) {
LOG(CONSOLE) << "Writing " << page_type << " to " << cache_info
<< " in " << ((bytes_write >> 20UL) / tdiff) << " MB/s, "
<< (bytes_write >> 20UL) << " written";
<< " in " << ((bytes_write >> 20UL) / tdiff)
<< " MB/s, " << (bytes_write >> 20UL) << " written";
tick_expected += static_cast<size_t>(kStep);
}
}
}
if (last_group_id != default_max) {
if (group_size > info.group_ptr_.back()) {
info.group_ptr_.push_back(group_size);
}
}
if (page->data.Size() != 0) {
writer.PushWrite(std::move(page));
inferred_num_rows += page->Size();
if (!page->offset.HostVector().empty()) {
info.num_nonzero_ += page->offset.HostVector().back();
}
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
// Deal with empty rows/columns if necessary
if (adapter->NumColumns() == kAdapterUnknownSize) {
info.num_col_ = inferred_num_columns;
} else {
info.num_col_ = adapter->NumColumns();
}
// Synchronise worker columns
rabit::Allreduce<rabit::op::Max>(&info.num_col_, 1);
if (adapter->NumRows() == kAdapterUnknownSize) {
info.num_row_ = inferred_num_rows;
} else {
if (page->offset.HostVector().empty()) {
page->offset.HostVector().emplace_back(0);
}
while (inferred_num_rows < adapter->NumRows()) {
page->offset.HostVector().emplace_back(
page->offset.HostVector().back());
inferred_num_rows++;
}
info.num_row_ = adapter->NumRows();
}
// Make sure we have at least one page if the dataset is empty
if (page->data.Size() > 0 || info.num_row_ == 0) {
writer.PushWrite(std::move(page));
}
std::unique_ptr<dmlc::Stream> fo(
dmlc::Stream::Create(cinfo.name_info.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
// Either every row has query ID or none at all
CHECK(qids.empty() || qids.size() == info.num_row_);
info.SaveBinary(fo.get());
}
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to " << cinfo.name_info;
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
<< cinfo.name_info;
}
/*!
* \brief Create source cache by copy content from DMatrix.
* Creates transposed column page, may be sorted or not.

View File

@ -69,6 +69,7 @@ inline SparsePageFormat<T>* CreatePageFormat(const std::string& name) {
auto *e = ::dmlc::Registry<SparsePageFormatReg<T>>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown format type " << name;
return nullptr;
}
return (e->body)();
}

View File

@ -1,8 +1,6 @@
// Copyright (c) 2019 by Contributors
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <xgboost/data.h>
#include <xgboost/version_config.h>
#include "../../../src/data/adapter.h"
#include "../../../src/data/simple_dmatrix.h"
#include "../../../src/common/timer.h"
@ -29,71 +27,6 @@ TEST(c_api, CSRAdapter) {
EXPECT_EQ(line2 .GetElement(0).value, 5);
EXPECT_EQ(line2 .GetElement(0).row_idx, 2);
EXPECT_EQ(line2 .GetElement(0).column_idx, 1);
data::SimpleDMatrix dmat(&adapter, std::nan(""), -1);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 5);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for(auto j = 0ull; j < inst.size(); j++)
{
EXPECT_EQ(inst[j].fvalue, data[row_ptr[i] + j]);
EXPECT_EQ(inst[j].index, feature_idx[row_ptr[i] + j]);
}
}
}
}
TEST(c_api, DenseAdapter) {
int m = 3;
int n = 2;
std::vector<float> data = {1, 2, 3, 4, 5, 6};
data::DenseAdapter adapter(data.data(), m, m*n, n);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), -1);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 6);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for(auto j = 0ull; j < inst.size(); j++)
{
EXPECT_EQ(inst[j].fvalue, data[i*n+j]);
EXPECT_EQ(inst[j].index, j);
}
}
}
}
TEST(c_api, CSCAdapter) {
std::vector<float> data = {1, 3, 2, 4, 5};
std::vector<unsigned> row_idx = {0, 1, 0, 1, 2};
std::vector<size_t> col_ptr = {0, 2, 5};
data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), -1);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 5);
auto &batch = *dmat.GetBatches<SparsePage>().begin();
auto inst = batch[0];
EXPECT_EQ(inst[0].fvalue, 1);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 2);
EXPECT_EQ(inst[1].index, 1);
inst = batch[1];
EXPECT_EQ(inst[0].fvalue, 3);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 4);
EXPECT_EQ(inst[1].index, 1);
inst = batch[2];
EXPECT_EQ(inst[0].fvalue, 5);
EXPECT_EQ(inst[0].index, 1);
}
TEST(c_api, CSCAdapterColsMoreThanRows) {
@ -128,10 +61,3 @@ TEST(c_api, CSCAdapterColsMoreThanRows) {
EXPECT_EQ(inst[3].fvalue, 8);
EXPECT_EQ(inst[3].index, 3);
}
TEST(c_api, FileAdapter) {
std::string filename = "test.libsvm";
CreateBigTestData(filename, 10);
std::unique_ptr<dmlc::Parser<uint32_t>> parser(dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1,"auto"));
data::FileAdapter adapter(parser.get());
}

View File

@ -1,10 +1,10 @@
// Copyright by Contributors
#include <xgboost/data.h>
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include "../../../src/data/simple_dmatrix.h"
#include "../helpers.h"
#include "../../../src/data/adapter.h"
#include "../helpers.h"
using namespace xgboost; // NOLINT
@ -12,7 +12,7 @@ TEST(SimpleDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false);
// Test the metadata that was parsed
EXPECT_EQ(dmat->Info().num_row_, 2);
@ -27,7 +27,7 @@ TEST(SimpleDMatrix, RowAccess) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, false, false);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, false, false);
// Loop over the batches and count the records
int64_t row_count = 0;
@ -49,7 +49,7 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
CreateSimpleTestData(tmp_file);
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
xgboost::DMatrix *dmat = xgboost::DMatrix::Load(tmp_file, true, false);
// Sorted column access
EXPECT_EQ(dmat->GetColDensity(0), 1);
@ -72,7 +72,8 @@ TEST(SimpleDMatrix, Empty) {
std::vector<unsigned> feature_idx = {};
std::vector<size_t> row_ptr = {};
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0);
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(),
0, 0, 0);
data::SimpleDMatrix dmat(&csr_adapter,
std::numeric_limits<float>::quiet_NaN(), 1);
CHECK_EQ(dmat.Info().num_nonzero_, 0);
@ -108,8 +109,10 @@ TEST(SimpleDMatrix, MissingData) {
std::vector<unsigned> feature_idx = {0, 1, 0};
std::vector<size_t> row_ptr = {0, 2, 3};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1);
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
3, 2);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
1);
CHECK_EQ(dmat.Info().num_nonzero_, 2);
dmat = data::SimpleDMatrix(&adapter, 1.0, 1);
CHECK_EQ(dmat.Info().num_nonzero_, 1);
@ -120,13 +123,66 @@ TEST(SimpleDMatrix, EmptyRow) {
std::vector<unsigned> feature_idx = {0, 1};
std::vector<size_t> row_ptr = {0, 2, 2};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1);
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2,
2, 2);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
1);
CHECK_EQ(dmat.Info().num_nonzero_, 2);
CHECK_EQ(dmat.Info().num_row_, 2);
CHECK_EQ(dmat.Info().num_col_, 2);
}
TEST(SimpleDMatrix, FromDense) {
int m = 3;
int n = 2;
std::vector<float> data = {1, 2, 3, 4, 5, 6};
data::DenseAdapter adapter(data.data(), m, m * n, n);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
-1);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 6);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for (auto j = 0ull; j < inst.size(); j++) {
EXPECT_EQ(inst[j].fvalue, data[i * n + j]);
EXPECT_EQ(inst[j].index, j);
}
}
}
}
TEST(SimpleDMatrix, FromCSC) {
std::vector<float> data = {1, 3, 2, 4, 5};
std::vector<unsigned> row_idx = {0, 1, 0, 1, 2};
std::vector<size_t> col_ptr = {0, 2, 5};
data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3);
data::SimpleDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(),
-1);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 5);
auto &batch = *dmat.GetBatches<SparsePage>().begin();
auto inst = batch[0];
EXPECT_EQ(inst[0].fvalue, 1);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 2);
EXPECT_EQ(inst[1].index, 1);
inst = batch[1];
EXPECT_EQ(inst[0].fvalue, 3);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 4);
EXPECT_EQ(inst[1].index, 1);
inst = batch[2];
EXPECT_EQ(inst[0].fvalue, 5);
EXPECT_EQ(inst[0].index, 1);
}
TEST(SimpleDMatrix, FromFile) {
std::string filename = "test.libsvm";
CreateBigTestData(filename, 3 * 5);
@ -142,12 +198,11 @@ TEST(SimpleDMatrix, FromFile) {
EXPECT_EQ(batch.base_rowid, 0);
for (auto i = 0ull; i < batch.Size(); i++) {
if (i%2== 0) {
if (i % 2 == 0) {
EXPECT_EQ(batch[i][0].index, 0);
EXPECT_EQ(batch[i][1].index, 1);
EXPECT_EQ(batch[i][2].index, 2);
}
else {
} else {
EXPECT_EQ(batch[i][0].index, 0);
EXPECT_EQ(batch[i][1].index, 3);
EXPECT_EQ(batch[i][2].index, 4);

View File

@ -1,12 +1,12 @@
// Copyright by Contributors
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include <dmlc/filesystem.h>
#include <cinttypes>
#include "../../../src/data/sparse_page_dmatrix.h"
#include "../../../src/data/adapter.h"
#include "../helpers.h"
#include <gtest/gtest.h>
using namespace xgboost; // NOLINT
TEST(SparsePageDMatrix, MetaInfo) {
dmlc::TemporaryDirectory tempdir;
@ -87,3 +87,158 @@ TEST(SparsePageDMatrix, ColAccessBatches) {
}
omp_set_num_threads(n_threads);
}
TEST(SparsePageDMatrix, Empty) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::vector<float> data{};
std::vector<unsigned> feature_idx = {};
std::vector<size_t> row_ptr = {};
data::CSRAdapter csr_adapter(row_ptr.data(), feature_idx.data(), data.data(), 0, 0, 0);
data::SparsePageDMatrix dmat(&csr_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 0);
EXPECT_EQ(dmat.Info().num_row_, 0);
EXPECT_EQ(dmat.Info().num_col_, 0);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
data::DenseAdapter dense_adapter(nullptr, 0, 0, 0);
data::SparsePageDMatrix dmat2(&dense_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
EXPECT_EQ(dmat2.Info().num_nonzero_, 0);
EXPECT_EQ(dmat2.Info().num_row_, 0);
EXPECT_EQ(dmat2.Info().num_col_, 0);
for (auto &batch : dmat2.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
data::CSCAdapter csc_adapter(nullptr, nullptr, nullptr, 0, 0);
data::SparsePageDMatrix dmat3(&csc_adapter,
std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
EXPECT_EQ(dmat3.Info().num_nonzero_, 0);
EXPECT_EQ(dmat3.Info().num_row_, 0);
EXPECT_EQ(dmat3.Info().num_col_, 0);
for (auto &batch : dmat3.GetBatches<SparsePage>()) {
EXPECT_EQ(batch.Size(), 0);
}
}
TEST(SparsePageDMatrix, MissingData) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::vector<float> data{0.0, std::nanf(""), 1.0};
std::vector<unsigned> feature_idx = {0, 1, 0};
std::vector<size_t> row_ptr = {0, 2, 3};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 3, 2);
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
const std::string tmp_file2 = tempdir.path + "/simple2.libsvm";
data::SparsePageDMatrix dmat2(&adapter, 1.0, 1,tmp_file2);
EXPECT_EQ(dmat2.Info().num_nonzero_, 1);
}
TEST(SparsePageDMatrix, EmptyRow) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::vector<float> data{0.0, 1.0};
std::vector<unsigned> feature_idx = {0, 1};
std::vector<size_t> row_ptr = {0, 2, 2};
data::CSRAdapter adapter(row_ptr.data(), feature_idx.data(), data.data(), 2, 2, 2);
data::SparsePageDMatrix dmat(&adapter, std::numeric_limits<float>::quiet_NaN(), 1,tmp_file);
EXPECT_EQ(dmat.Info().num_nonzero_, 2);
EXPECT_EQ(dmat.Info().num_row_, 2);
EXPECT_EQ(dmat.Info().num_col_, 2);
}
TEST(SparsePageDMatrix, FromDense) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
int m = 3;
int n = 2;
std::vector<float> data = {1, 2, 3, 4, 5, 6};
data::DenseAdapter adapter(data.data(), m, m * n, n);
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), 1, tmp_file);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 6);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
for (auto i = 0ull; i < batch.Size(); i++) {
auto inst = batch[i];
for(auto j = 0ull; j < inst.size(); j++)
{
EXPECT_EQ(inst[j].fvalue, data[i*n+j]);
EXPECT_EQ(inst[j].index, j);
}
}
}
}
TEST(SparsePageDMatrix, FromCSC) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
std::vector<float> data = {1, 3, 2, 4, 5};
std::vector<unsigned> row_idx = {0, 1, 0, 1, 2};
std::vector<size_t> col_ptr = {0, 2, 5};
data::CSCAdapter adapter(col_ptr.data(), row_idx.data(), data.data(), 2, 3);
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file);
EXPECT_EQ(dmat.Info().num_col_, 2);
EXPECT_EQ(dmat.Info().num_row_, 3);
EXPECT_EQ(dmat.Info().num_nonzero_, 5);
auto &batch = *dmat.GetBatches<SparsePage>().begin();
auto inst = batch[0];
EXPECT_EQ(inst[0].fvalue, 1);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 2);
EXPECT_EQ(inst[1].index, 1);
inst = batch[1];
EXPECT_EQ(inst[0].fvalue, 3);
EXPECT_EQ(inst[0].index, 0);
EXPECT_EQ(inst[1].fvalue, 4);
EXPECT_EQ(inst[1].index, 1);
inst = batch[2];
EXPECT_EQ(inst[0].fvalue, 5);
EXPECT_EQ(inst[0].index, 1);
}
TEST(SparsePageDMatrix, FromFile) {
std::string filename = "test.libsvm";
CreateBigTestData(filename,20);
std::unique_ptr<dmlc::Parser<uint32_t>> parser(
dmlc::Parser<uint32_t>::Create(filename.c_str(), 0, 1, "auto"));
data::FileAdapter adapter(parser.get());
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
data::SparsePageDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1, tmp_file, 1);
for (auto &batch : dmat.GetBatches<SparsePage>()) {
std::vector<bst_row_t> expected_offset(batch.Size() + 1);
int n = -3;
std::generate(expected_offset.begin(), expected_offset.end(),
[&n] { return n += 3; });
EXPECT_EQ(batch.offset.HostVector(), expected_offset);
if (batch.base_rowid % 2 == 0) {
EXPECT_EQ(batch[0][0].index, 0);
EXPECT_EQ(batch[0][1].index, 1);
EXPECT_EQ(batch[0][2].index, 2);
} else {
EXPECT_EQ(batch[0][0].index, 0);
EXPECT_EQ(batch[0][1].index, 3);
EXPECT_EQ(batch[0][2].index, 4);
}
}
}