remove the qids_ field in MetaInfo (#4744)
This commit is contained in:
parent
f22b1c0348
commit
19f9fd5de9
@ -60,8 +60,6 @@ class MetaInfo {
|
|||||||
std::vector<bst_uint> group_ptr_;
|
std::vector<bst_uint> group_ptr_;
|
||||||
/*! \brief weights of each instance, optional */
|
/*! \brief weights of each instance, optional */
|
||||||
HostDeviceVector<bst_float> weights_;
|
HostDeviceVector<bst_float> weights_;
|
||||||
/*! \brief session-id of each instance, optional */
|
|
||||||
std::vector<uint64_t> qids_;
|
|
||||||
/*!
|
/*!
|
||||||
* \brief initialized margins,
|
* \brief initialized margins,
|
||||||
* if specified, xgboost will start from this init margin
|
* if specified, xgboost will start from this init margin
|
||||||
@ -69,9 +67,9 @@ class MetaInfo {
|
|||||||
*/
|
*/
|
||||||
HostDeviceVector<bst_float> base_margin_;
|
HostDeviceVector<bst_float> base_margin_;
|
||||||
/*! \brief version flag, used to check version of this info */
|
/*! \brief version flag, used to check version of this info */
|
||||||
static const int kVersion = 2;
|
static const int kVersion = 3;
|
||||||
/*! \brief version that introduced qid field */
|
/*! \brief version that contains qid field */
|
||||||
static const int kVersionQidAdded = 2;
|
static const int kVersionWithQid = 2;
|
||||||
/*! \brief default constructor */
|
/*! \brief default constructor */
|
||||||
MetaInfo() = default;
|
MetaInfo() = default;
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -27,7 +27,6 @@ void MetaInfo::Clear() {
|
|||||||
labels_.HostVector().clear();
|
labels_.HostVector().clear();
|
||||||
root_index_.clear();
|
root_index_.clear();
|
||||||
group_ptr_.clear();
|
group_ptr_.clear();
|
||||||
qids_.clear();
|
|
||||||
weights_.HostVector().clear();
|
weights_.HostVector().clear();
|
||||||
base_margin_.HostVector().clear();
|
base_margin_.HostVector().clear();
|
||||||
}
|
}
|
||||||
@ -40,7 +39,6 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
|||||||
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
|
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
|
||||||
fo->Write(labels_.HostVector());
|
fo->Write(labels_.HostVector());
|
||||||
fo->Write(group_ptr_);
|
fo->Write(group_ptr_);
|
||||||
fo->Write(qids_);
|
|
||||||
fo->Write(weights_.HostVector());
|
fo->Write(weights_.HostVector());
|
||||||
fo->Write(root_index_);
|
fo->Write(root_index_);
|
||||||
fo->Write(base_margin_.HostVector());
|
fo->Write(base_margin_.HostVector());
|
||||||
@ -56,10 +54,9 @@ void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
|||||||
<< "MetaInfo: invalid format";
|
<< "MetaInfo: invalid format";
|
||||||
CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format";
|
CHECK(fi->Read(&labels_.HostVector())) << "MetaInfo: invalid format";
|
||||||
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
|
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
|
||||||
if (version >= kVersionQidAdded) {
|
if (version == kVersionWithQid) {
|
||||||
CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format";
|
std::vector<uint64_t> qids;
|
||||||
} else { // old format doesn't contain qid field
|
CHECK(fi->Read(&qids)) << "MetaInfo: invalid format";
|
||||||
qids_.clear();
|
|
||||||
}
|
}
|
||||||
CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
|
CHECK(fi->Read(&weights_.HostVector())) << "MetaInfo: invalid format";
|
||||||
CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
|
CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
|
||||||
|
|||||||
@ -28,6 +28,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
|||||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||||
uint64_t last_group_id = default_max;
|
uint64_t last_group_id = default_max;
|
||||||
bst_uint group_size = 0;
|
bst_uint group_size = 0;
|
||||||
|
std::vector<uint64_t> qids;
|
||||||
this->Clear();
|
this->Clear();
|
||||||
while (parser->Next()) {
|
while (parser->Next()) {
|
||||||
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
|
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
|
||||||
@ -40,7 +41,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
|||||||
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
||||||
}
|
}
|
||||||
if (batch.qid != nullptr) {
|
if (batch.qid != nullptr) {
|
||||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
qids.insert(qids.end(), batch.qid, batch.qid + batch.size);
|
||||||
// get group
|
// get group
|
||||||
for (size_t i = 0; i < batch.size; ++i) {
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
const uint64_t cur_group_id = batch.qid[i];
|
const uint64_t cur_group_id = batch.qid[i];
|
||||||
@ -82,7 +83,7 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
|||||||
}
|
}
|
||||||
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.Size());
|
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.Size());
|
||||||
// Either every row has query ID or none at all
|
// Either every row has query ID or none at all
|
||||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
|
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
|
||||||
|
|||||||
@ -172,6 +172,7 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||||
uint64_t last_group_id = default_max;
|
uint64_t last_group_id = default_max;
|
||||||
bst_uint group_size = 0;
|
bst_uint group_size = 0;
|
||||||
|
std::vector<uint64_t> qids;
|
||||||
|
|
||||||
while (src->Next()) {
|
while (src->Next()) {
|
||||||
const dmlc::RowBlock<uint32_t>& batch = src->Value();
|
const dmlc::RowBlock<uint32_t>& batch = src->Value();
|
||||||
@ -184,7 +185,7 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
weights.insert(weights.end(), batch.weight, batch.weight + batch.size);
|
||||||
}
|
}
|
||||||
if (batch.qid != nullptr) {
|
if (batch.qid != nullptr) {
|
||||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
qids.insert(qids.end(), batch.qid, batch.qid + batch.size);
|
||||||
// get group
|
// get group
|
||||||
for (size_t i = 0; i < batch.size; ++i) {
|
for (size_t i = 0; i < batch.size; ++i) {
|
||||||
const uint64_t cur_group_id = batch.qid[i];
|
const uint64_t cur_group_id = batch.qid[i];
|
||||||
@ -233,7 +234,7 @@ class SparsePageSource : public DataSource<T> {
|
|||||||
int tmagic = kMagic;
|
int tmagic = kMagic;
|
||||||
fo->Write(&tmagic, sizeof(tmagic));
|
fo->Write(&tmagic, sizeof(tmagic));
|
||||||
// Either every row has query ID or none at all
|
// Either every row has query ID or none at all
|
||||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
CHECK(qids.empty() || qids.size() == info.num_row_);
|
||||||
info.SaveBinary(fo.get());
|
info.SaveBinary(fo.get());
|
||||||
}
|
}
|
||||||
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
|
LOG(INFO) << "SparsePageSource::CreateRowPage Finished writing to "
|
||||||
|
|||||||
@ -55,7 +55,7 @@ TEST(MetaInfo, SaveLoadBinary) {
|
|||||||
info.SaveBinary(fs);
|
info.SaveBinary(fs);
|
||||||
delete fs;
|
delete fs;
|
||||||
|
|
||||||
ASSERT_EQ(GetFileSize(tmp_file), 84)
|
ASSERT_EQ(GetFileSize(tmp_file), 76)
|
||||||
<< "Expected saved binary file size to be same as object size";
|
<< "Expected saved binary file size to be same as object size";
|
||||||
|
|
||||||
fs = dmlc::Stream::Create(tmp_file.c_str(), "r");
|
fs = dmlc::Stream::Create(tmp_file.c_str(), "r");
|
||||||
@ -92,11 +92,8 @@ TEST(MetaInfo, LoadQid) {
|
|||||||
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
|
xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
|
||||||
|
|
||||||
const xgboost::MetaInfo& info = dmat->Info();
|
const xgboost::MetaInfo& info = dmat->Info();
|
||||||
const std::vector<uint64_t> expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
|
|
||||||
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
|
const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
|
||||||
CHECK(info.qids_ == expected_qids);
|
|
||||||
CHECK(info.group_ptr_ == expected_group_ptr);
|
CHECK(info.group_ptr_ == expected_group_ptr);
|
||||||
CHECK_GE(info.kVersion, info.kVersionQidAdded);
|
|
||||||
|
|
||||||
const std::vector<size_t> expected_offset{
|
const std::vector<size_t> expected_offset{
|
||||||
0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
|
0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user