Add qid like ranklib format (#2749)
* add qid for https://github.com/dmlc/xgboost/issues/2748 * change names * change spaces * change qid to bst_uint type * change qid type to size_t * change qid first to SIZE_MAX * change qid type from size_t to uint64_t * update dmlc-core * fix qids name error * fix group_ptr_ error * Style fix * Add qid handling logic to SparsePage * New MetaInfo format + backward compatibility fix Old MetaInfo format (1.0) doesn't contain qid field. We still want to be able to read from MetaInfo files saved in old format. Also, define a new format (2.0) that contains the qid field. This way, we can distinguish files that contain qid and those that do not. * Update MetaInfo test * Simply group assignment logic * Explicitly set qid=nullptr in NativeDataIter NativeDataIter's callback does not support qid field. Users of NativeDataIter will need to call setGroup() function separately to set group information. * Save qids_ in SaveBinary() * Upgrade dmlc-core submodule * Add a test for reading qid * Add contributor * Check the size of qids_ * Document qid format
This commit is contained in:
committed by
Philip Hyunsu Cho
parent
18813a26ab
commit
0cf88d036f
@@ -141,6 +141,8 @@ class NativeDataIter : public dmlc::Parser<uint32_t> {
|
||||
block_.offset = dmlc::BeginPtr(offset_);
|
||||
block_.label = dmlc::BeginPtr(label_);
|
||||
block_.weight = dmlc::BeginPtr(weight_);
|
||||
block_.qid = nullptr;
|
||||
block_.field = nullptr;
|
||||
block_.index = dmlc::BeginPtr(index_);
|
||||
block_.value = dmlc::BeginPtr(value_);
|
||||
bytes_read_ += offset_.size() * sizeof(size_t) +
|
||||
|
||||
@@ -28,6 +28,7 @@ void MetaInfo::Clear() {
|
||||
labels_.clear();
|
||||
root_index_.clear();
|
||||
group_ptr_.clear();
|
||||
qids_.clear();
|
||||
weights_.clear();
|
||||
base_margin_.clear();
|
||||
}
|
||||
@@ -40,6 +41,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
|
||||
fo->Write(labels_);
|
||||
fo->Write(group_ptr_);
|
||||
fo->Write(qids_);
|
||||
fo->Write(weights_);
|
||||
fo->Write(root_index_);
|
||||
fo->Write(base_margin_);
|
||||
@@ -48,13 +50,18 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
int version;
|
||||
CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version";
|
||||
CHECK_EQ(version, kVersion) << "MetaInfo: invalid format";
|
||||
CHECK(version >= 1 && version <= kVersion) << "MetaInfo: unsupported file version";
|
||||
CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_))
|
||||
<< "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&labels_)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
|
||||
if (version >= kVersionQidAdded) {
|
||||
CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format";
|
||||
} else { // old format doesn't contain qid field
|
||||
qids_.clear();
|
||||
}
|
||||
CHECK(fi->Read(&weights_)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&base_margin_)) << "MetaInfo: invalid format";
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
*/
|
||||
#include <dmlc/base.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <limits>
|
||||
#include "./simple_csr_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -26,6 +27,10 @@ void SimpleCSRSource::CopyFrom(DMatrix* src) {
|
||||
}
|
||||
|
||||
void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
// use qid to get group info
|
||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
bst_uint group_size = 0;
|
||||
this->Clear();
|
||||
while (parser->Next()) {
|
||||
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
|
||||
@@ -35,6 +40,19 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
if (batch.weight != nullptr) {
|
||||
info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.qid != nullptr) {
|
||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
||||
// get group
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
const uint64_t cur_group_id = batch.qid[i];
|
||||
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
last_group_id = cur_group_id;
|
||||
++group_size;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the assertion on batch.index, which can be null in the case that the data in this
|
||||
// batch is entirely sparse. Although it's true that this indicates a likely issue with the
|
||||
// user's data workflows, passing XGBoost entirely sparse data should not cause it to fail.
|
||||
@@ -56,7 +74,14 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
|
||||
}
|
||||
}
|
||||
if (last_group_id != default_max) {
|
||||
if (group_size > info.group_ptr_.back()) {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
}
|
||||
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
|
||||
// Either every row has query ID or none at all
|
||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
||||
}
|
||||
|
||||
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
|
||||
|
||||
@@ -122,6 +122,10 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
||||
constexpr double kStep = 4.0;
|
||||
size_t tick_expected = static_cast<double>(kStep);
|
||||
|
||||
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t last_group_id = default_max;
|
||||
bst_uint group_size = 0;
|
||||
|
||||
while (src->Next()) {
|
||||
const dmlc::RowBlock<uint32_t>& batch = src->Value();
|
||||
if (batch.label != nullptr) {
|
||||
@@ -130,6 +134,18 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
||||
if (batch.weight != nullptr) {
|
||||
info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
if (batch.qid != nullptr) {
|
||||
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
|
||||
// get group
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
const uint64_t cur_group_id = batch.qid[i];
|
||||
if (last_group_id == default_max || last_group_id != cur_group_id) {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
last_group_id = cur_group_id;
|
||||
++group_size;
|
||||
}
|
||||
}
|
||||
info.num_row_ += batch.size;
|
||||
info.num_nonzero_ += batch.offset[batch.size] - batch.offset[0];
|
||||
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
|
||||
@@ -153,6 +169,11 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
||||
}
|
||||
}
|
||||
}
|
||||
if (last_group_id != default_max) {
|
||||
if (group_size > info.group_ptr_.back()) {
|
||||
info.group_ptr_.push_back(group_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (page->data.size() != 0) {
|
||||
writer.PushWrite(std::move(page));
|
||||
@@ -162,6 +183,8 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
|
||||
dmlc::Stream::Create(name_info.c_str(), "w"));
|
||||
int tmagic = kMagic;
|
||||
fo->Write(&tmagic, sizeof(tmagic));
|
||||
// Either every row has query ID or none at all
|
||||
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
|
||||
info.SaveBinary(fo.get());
|
||||
}
|
||||
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
|
||||
|
||||
Reference in New Issue
Block a user