Add google test for a column sampling, restore metainfo tests (#3637)

* Add google test for a column sampling, restore metainfo tests

* Update metainfo test for visual studio

* Fix multi-GPU bug introduced in #3635
This commit is contained in:
Rory Mitchell
2018-08-28 16:10:26 +12:00
committed by GitHub
parent 7ef2b599c7
commit 78bea0d204
3 changed files with 102 additions and 3 deletions

View File

@@ -69,4 +69,66 @@ TEST(MetaInfo, SaveLoadBinary) {
}
TEST(MetaInfo, LoadQid) {
std::string tmp_file = TempFileName();
{
std::unique_ptr<dmlc::Stream> fs(
dmlc::Stream::Create(tmp_file.c_str(), "w"));
dmlc::ostream os(fs.get());
os << R"qid(3 qid:1 1:1 2:1 3:0 4:0.2 5:0
2 qid:1 1:0 2:0 3:1 4:0.1 5:1
1 qid:1 1:0 2:1 3:0 4:0.4 5:0
1 qid:1 1:0 2:0 3:1 4:0.3 5:0
1 qid:2 1:0 2:0 3:1 4:0.2 5:0
2 qid:2 1:1 2:0 3:1 4:0.4 5:0
1 qid:2 1:0 2:0 3:1 4:0.1 5:0
1 qid:2 1:0 2:0 3:1 4:0.2 5:0
2 qid:3 1:0 2:0 3:1 4:0.1 5:1
3 qid:3 1:1 2:1 3:0 4:0.3 5:0
4 qid:3 1:1 2:0 3:0 4:0.4 5:1
1 qid:3 1:0 2:1 3:1 4:0.5 5:0)qid";
os.set_stream(nullptr);
}
std::unique_ptr<xgboost::DMatrix> dmat(
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
std::remove(tmp_file.c_str());
const xgboost::MetaInfo& info = dmat->Info();
const std::vector<uint64_t> expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
CHECK(info.qids_ == expected_qids);
CHECK(info.group_ptr_ == expected_group_ptr);
CHECK_GE(info.kVersion, info.kVersionQidAdded);
const std::vector<size_t> expected_offset{
0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
};
const std::vector<xgboost::Entry> expected_data{
xgboost::Entry(1, 1), xgboost::Entry(2, 1), xgboost::Entry(3, 0),
xgboost::Entry(4, 0.2), xgboost::Entry(5, 0), xgboost::Entry(1, 0),
xgboost::Entry(2, 0), xgboost::Entry(3, 1), xgboost::Entry(4, 0.1),
xgboost::Entry(5, 1), xgboost::Entry(1, 0), xgboost::Entry(2, 1),
xgboost::Entry(3, 0), xgboost::Entry(4, 0.4), xgboost::Entry(5, 0),
xgboost::Entry(1, 0), xgboost::Entry(2, 0), xgboost::Entry(3, 1),
xgboost::Entry(4, 0.3), xgboost::Entry(5, 0), xgboost::Entry(1, 0),
xgboost::Entry(2, 0), xgboost::Entry(3, 1), xgboost::Entry(4, 0.2),
xgboost::Entry(5, 0), xgboost::Entry(1, 1), xgboost::Entry(2, 0),
xgboost::Entry(3, 1), xgboost::Entry(4, 0.4), xgboost::Entry(5, 0),
xgboost::Entry(1, 0), xgboost::Entry(2, 0), xgboost::Entry(3, 1),
xgboost::Entry(4, 0.1), xgboost::Entry(5, 0), xgboost::Entry(1, 0),
xgboost::Entry(2, 0), xgboost::Entry(3, 1), xgboost::Entry(4, 0.2),
xgboost::Entry(5, 0), xgboost::Entry(1, 0), xgboost::Entry(2, 0),
xgboost::Entry(3, 1), xgboost::Entry(4, 0.1), xgboost::Entry(5, 1),
xgboost::Entry(1, 1), xgboost::Entry(2, 1), xgboost::Entry(3, 0),
xgboost::Entry(4, 0.3), xgboost::Entry(5, 0), xgboost::Entry(1, 1),
xgboost::Entry(2, 0), xgboost::Entry(3, 0), xgboost::Entry(4, 0.4),
xgboost::Entry(5, 1), xgboost::Entry(1, 0), xgboost::Entry(2, 1),
xgboost::Entry(3, 1), xgboost::Entry(4, 0.5), {5, 0}};
dmlc::DataIter<xgboost::SparsePage>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next());
const xgboost::SparsePage& batch = iter->Value();
CHECK_EQ(batch.base_rowid, 0);
CHECK(batch.offset == expected_offset);
CHECK(batch.data == expected_data);
CHECK(!iter->Next());
}