remove the qids_ field in MetaInfo (#4744)
This commit is contained in:
@@ -27,7 +27,6 @@ void MetaInfo::Clear() {
|
||||
labels_.HostVector().clear();
|
||||
root_index_.clear();
|
||||
group_ptr_.clear();
|
||||
qids_.clear();
|
||||
weights_.HostVector().clear();
|
||||
base_margin_.HostVector().clear();
|
||||
}
|
||||
@@ -40,7 +39,6 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
|
||||
fo->Write(labels_.HostVector());
|
||||
fo->Write(group_ptr_);
|
||||
fo->Write(qids_);
|
||||
fo->Write(weights_.HostVector());
|
||||
fo->Write(root_index_);
|
||||
fo->Write(base_margin_.HostVector());
|
||||
@@ -56,10 +54,9 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
<< "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
|
||||
if (version >= kVersionQidAdded) {
|
||||
CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format";
|
||||
} else { // old format doesn't contain qid field
|
||||
qids_.clear();
|
||||
if (version == kVersionWithQid) {
|
||||
std::vector<uint64_t> qids;
|
||||
CHECK(fi->Read(&qids)) << "MetaInfo: invalid format";
|
||||
}
|
||||
CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
|
||||
|
||||
@@ -28,6 +28,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
bst_uint group_size = 0;
|
||||
std::vector<uint64_t> qids;
|
||||
this->Clear();
|
||||
while (parser->Next()) {
|
||||
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
|
||||
@@ -40,7 +41,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.qid != nullptr) {
|
||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
||||
qids.insert(qids.end(), batch.qid, batch.qid + batch.size);
|
||||
// get group
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
const uint64_t cur_group_id = batch.qid[i];
|
||||
@@ -82,7 +83,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
}
|
||||
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.Size());
|
||||
// Either every row has query ID or none at all
|
||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
||||
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||
}
|
||||
|
||||
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
|
||||
|
||||
@@ -172,6 +172,7 @@ class SparsePageSource : public DataSource<T> {
|
||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
bst_uint group_size = 0;
|
||||
std::vector<uint64_t> qids;
|
||||
|
||||
while (src->Next()) {
|
||||
const dmlc::RowBlock<uint32_t>& batch = src->Value();
|
||||
@@ -184,7 +185,7 @@ class SparsePageSource : public DataSource<T> {
|
||||
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.qid != nullptr) {
|
||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
||||
qids.insert(qids.end(), batch.qid, batch.qid + batch.size);
|
||||
// get group
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
const uint64_t cur_group_id = batch.qid[i];
|
||||
@@ -233,7 +234,7 @@ class SparsePageSource : public DataSource<T> {
|
||||
int tmagic = kMagic;
|
||||
fo->Write(&tmagic, sizeof(tmagic));
|
||||
// Either every row has query ID or none at all
|
||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
||||
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||
info.SaveBinary(fo.get());
|
||||
}
|
||||
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
|
||||
|
||||
Reference in New Issue
Block a user