diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc new file mode 100644 index 000000000..82957dcee --- /dev/null +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -0,0 +1,98 @@ +// Copyright by Contributors +#include +#include "../../../src/data/sparse_page_dmatrix.h" + +#include "../helpers.h" + +TEST(SparsePageDMatrix, MetaInfo) { + std::string tmp_file = CreateSimpleTestData(); + xgboost::DMatrix * dmat = xgboost::DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", true, false); + std::remove(tmp_file.c_str()); + EXPECT_TRUE(FileExists(tmp_file + ".cache")); + + // Test the metadata that was parsed + EXPECT_EQ(dmat->info().num_row, 2); + EXPECT_EQ(dmat->info().num_col, 5); + EXPECT_EQ(dmat->info().num_nonzero, 6); + EXPECT_EQ(dmat->info().labels.size(), dmat->info().num_row); + + // Clean up of external memory files + std::remove((tmp_file + ".cache").c_str()); + std::remove((tmp_file + ".cache.row.page").c_str()); +} + +TEST(SparsePageDMatrix, RowAccess) { + std::string tmp_file = CreateSimpleTestData(); + xgboost::DMatrix * dmat = xgboost::DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", true, false); + std::remove(tmp_file.c_str()); + EXPECT_TRUE(FileExists(tmp_file + ".cache.row.page")); + + dmlc::DataIter * row_iter = dmat->RowIterator(); + // Loop over the batches and count the records + long row_count = 0; + row_iter->BeforeFirst(); + while (row_iter->Next()) row_count += row_iter->Value().size; + EXPECT_EQ(row_count, dmat->info().num_row); + // Test the data read into the first row + row_iter->BeforeFirst(); + row_iter->Next(); + xgboost::SparseBatch::Inst first_row = row_iter->Value()[0]; + ASSERT_EQ(first_row.length, 3); + EXPECT_EQ(first_row[2].index, 2); + EXPECT_EQ(first_row[2].fvalue, 20); + row_iter = nullptr; + + // Clean up of external memory files + std::remove((tmp_file + ".cache").c_str()); + std::remove((tmp_file + ".cache.row.page").c_str()); +} + +TEST(SparsePageDMatrix, ColAcess) { + std::string tmp_file = CreateSimpleTestData(); + xgboost::DMatrix * dmat = xgboost::DMatrix::Load( + tmp_file + "#" + tmp_file + ".cache", true, false); + std::remove(tmp_file.c_str()); + EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page")); + + EXPECT_EQ(dmat->HaveColAccess(), false); + const std::vector enable(dmat->info().num_col, true); + dmat->InitColAccess(enable, 1, 1); // Max 1 row per patch + ASSERT_EQ(dmat->HaveColAccess(), true); + EXPECT_TRUE(FileExists(tmp_file + ".cache.col.page")); + + EXPECT_EQ(dmat->GetColSize(0), 2); + EXPECT_EQ(dmat->GetColSize(1), 1); + EXPECT_EQ(dmat->GetColDensity(0), 1); + EXPECT_EQ(dmat->GetColDensity(1), 0.5); + + dmlc::DataIter * col_iter = dmat->ColIterator(); + // Loop over the batches and assert the data is as expected + long num_col_batch = 0; + col_iter->BeforeFirst(); + while (col_iter->Next()) { + num_col_batch += 1; + EXPECT_EQ(col_iter->Value().size, dmat->info().num_col) + << "Expected batch size to be same as num_cols as max_row_perbatch is 1."; + } + EXPECT_EQ(num_col_batch, dmat->info().num_row) + << "Expected num batches to be same as num_rows as max_row_perbatch is 1"; + col_iter = nullptr; + + std::vector sub_feats = {4, 3}; + dmlc::DataIter * sub_col_iter = dmat->ColIterator(sub_feats); + // Loop over the batches and assert the data is as expected + sub_col_iter->BeforeFirst(); + while (sub_col_iter->Next()) { + EXPECT_EQ(sub_col_iter->Value().size, sub_feats.size()) + << "Expected size of a batch to be same as number of columns " + << "as max_row_perbatch was set to 1."; + } + sub_col_iter = nullptr; + + // Clean up of external memory files + std::remove((tmp_file + ".cache").c_str()); + std::remove((tmp_file + ".cache.col.page").c_str()); + std::remove((tmp_file + ".cache.row.page").c_str()); +} diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index 263d29541..24dee0c06 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -4,6 +4,11 @@ std::string TempFileName() { return std::tmpnam(nullptr); } +bool FileExists(const std::string name) { + struct stat st; + return stat(name.c_str(), &st) == 0; +} + long GetFileSize(const std::string filename) { struct stat st; stat(filename.c_str(), &st); diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 9b916ceb5..1d842a4c5 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -12,6 +12,8 @@ std::string TempFileName(); +bool FileExists(const std::string name); + long GetFileSize(const std::string filename); std::string CreateSimpleTestData();