diff --git a/include/xgboost/data.h b/include/xgboost/data.h index d8cca2ec6..6cae56349 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -60,8 +60,6 @@ class MetaInfo { std::vector group_ptr_; /*! \brief weights of each instance, optional */ HostDeviceVector weights_; - /*! \brief session-id of each instance, optional */ - std::vector qids_; /*! * \brief initialized margins, * if specified, xgboost will start from this init margin @@ -69,9 +67,9 @@ class MetaInfo { */ HostDeviceVector base_margin_; /*! \brief version flag, used to check version of this info */ - static const int kVersion = 2; - /*! \brief version that introduced qid field */ - static const int kVersionQidAdded = 2; + static const int kVersion = 3; + /*! \brief version that contains qid field */ + static const int kVersionWithQid = 2; /*! \brief default constructor */ MetaInfo() = default; /*! diff --git a/src/data/data.cc b/src/data/data.cc index 68ca0cfef..40084eaec 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -27,7 +27,6 @@ void MetaInfo::Clear() { labels_.HostVector().clear(); root_index_.clear(); group_ptr_.clear(); - qids_.clear(); weights_.HostVector().clear(); base_margin_.HostVector().clear(); } @@ -40,7 +39,6 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { fo->Write(&num_nonzero_, sizeof(num_nonzero_)); fo->Write(labels_.HostVector()); fo->Write(group_ptr_); - fo->Write(qids_); fo->Write(weights_.HostVector()); fo->Write(root_index_); fo->Write(base_margin_.HostVector()); @@ -56,10 +54,9 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) { << "MetaInfo: invalid format"; CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format"; CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format"; - if (version >= kVersionQidAdded) { - CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format"; - } else { // old format doesn't contain qid field - qids_.clear(); + if (version == kVersionWithQid) { + std::vector qids; + CHECK(fi->Read(&qids)) << "MetaInfo: invalid format"; } CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format"; CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format"; diff --git a/src/data/simple_csr_source.cc b/src/data/simple_csr_source.cc index fb52cf8d3..56d7223ce 100644 --- a/src/data/simple_csr_source.cc +++ b/src/data/simple_csr_source.cc @@ -28,6 +28,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { const uint64_t default_max = std::numeric_limits::max(); uint64_t last_group_id = default_max; bst_uint group_size = 0; + std::vector qids; this->Clear(); while (parser->Next()) { const dmlc::RowBlock& batch = parser->Value(); @@ -40,7 +41,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { weights.insert(weights.end(), batch.weight, batch.weight + batch.size); } if (batch.qid != nullptr) { - info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size); + qids.insert(qids.end(), batch.qid, batch.qid + batch.size); // get group for (size_t i = 0; i < batch.size; ++i) { const uint64_t cur_group_id = batch.qid[i]; @@ -82,7 +83,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { } this->info.num_nonzero_ = static_cast(page_.data.Size()); // Either every row has query ID or none at all - CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_); + CHECK(qids.empty() || qids.size() == info.num_row_); } void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) { diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 8bc008b15..a6f15173b 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -172,6 +172,7 @@ class SparsePageSource : public DataSource { const uint64_t default_max = std::numeric_limits::max(); uint64_t last_group_id = default_max; bst_uint group_size = 0; + std::vector qids; while (src->Next()) { const dmlc::RowBlock& batch = src->Value(); @@ -184,7 +185,7 @@ class SparsePageSource : public DataSource { weights.insert(weights.end(), batch.weight, batch.weight + batch.size); } if (batch.qid != nullptr) { - info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size); + qids.insert(qids.end(), batch.qid, batch.qid + batch.size); // get group for (size_t i = 0; i < batch.size; ++i) { const uint64_t cur_group_id = batch.qid[i]; @@ -233,7 +234,7 @@ class SparsePageSource : public DataSource { int tmagic = kMagic; fo->Write(&tmagic, sizeof(tmagic)); // Either every row has query ID or none at all - CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_); + CHECK(qids.empty() || qids.size() == info.num_row_); info.SaveBinary(fo.get()); } LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to " diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index 158c50bd5..38e157fbb 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -55,7 +55,7 @@ TEST(MetaInfo, SaveLoadBinary) { info.SaveBinary(fs); delete fs; - ASSERT_EQ(GetFileSize(tmp_file), 84) + ASSERT_EQ(GetFileSize(tmp_file), 76) << "Expected saved binary file size to be same as object size"; fs = dmlc::Stream::Create(tmp_file.c_str(), "r"); @@ -92,11 +92,8 @@ TEST(MetaInfo, LoadQid) { xgboost::DMatrix::Load(tmp_file, true, false, "libsvm")); const xgboost::MetaInfo& info = dmat->Info(); - const std::vector expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; const std::vector expected_group_ptr{0, 4, 8, 12}; - CHECK(info.qids_ == expected_qids); CHECK(info.group_ptr_ == expected_group_ptr); - CHECK_GE(info.kVersion, info.kVersionQidAdded); const std::vector expected_offset{ 0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60