Add qid like ranklib format (#2749)

* add qid for https://github.com/dmlc/xgboost/issues/2748

* change names

* change spaces

* change qid to bst_uint type

* change qid type to size_t

* change qid first to SIZE_MAX

* change qid type from size_t to uint64_t

* update dmlc-core

* fix qids name error

* fix group_ptr_ error

* Style fix

* Add qid handling logic to SparsePage

* New MetaInfo format + backward compatibility fix

The old MetaInfo format (1.0) doesn't contain the qid field. We still want to be
able to read MetaInfo files saved in the old format. Also, define a new format
(2.0) that contains the qid field. This way, we can distinguish files that
contain qid from those that do not.

* Update MetaInfo test

* Simplify group assignment logic

* Explicitly set qid=nullptr in NativeDataIter

NativeDataIter's callback does not support the qid field. Users of NativeDataIter
will need to call the setGroup() function separately to set group information.

* Save qids_ in SaveBinary()

* Upgrade dmlc-core submodule

* Add a test for reading qid

* Add contributor

* Check the size of qids_

* Document qid format
This commit is contained in:
liuliang01
2018-07-01 04:24:03 +08:00
committed by Philip Hyunsu Cho
parent 18813a26ab
commit 0cf88d036f
11 changed files with 182 additions and 15 deletions

View File

@@ -141,6 +141,8 @@ class NativeDataIter : public dmlc::Parser<uint32_t> {
block_.offset = dmlc::BeginPtr(offset_);
block_.label = dmlc::BeginPtr(label_);
block_.weight = dmlc::BeginPtr(weight_);
block_.qid = nullptr;
block_.field = nullptr;
block_.index = dmlc::BeginPtr(index_);
block_.value = dmlc::BeginPtr(value_);
bytes_read_ += offset_.size() * sizeof(size_t) +

View File

@@ -28,6 +28,7 @@ void MetaInfo::Clear() {
labels_.clear();
root_index_.clear();
group_ptr_.clear();
qids_.clear();
weights_.clear();
base_margin_.clear();
}
@@ -40,6 +41,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
fo->Write(&num_nonzero_, sizeof(num_nonzero_));
fo->Write(labels_);
fo->Write(group_ptr_);
fo->Write(qids_);
fo->Write(weights_);
fo->Write(root_index_);
fo->Write(base_margin_);
@@ -48,13 +50,18 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
int version;
CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version";
CHECK_EQ(version, kVersion) << "MetaInfo: invalid format";
CHECK(version >= 1 && version <= kVersion) << "MetaInfo: unsupported file version";
CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_))
<< "MetaInfo: invalid format";
CHECK(fi->Read(&labels_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
if (version >= kVersionQidAdded) {
CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format";
} else { // old format doesn't contain qid field
qids_.clear();
}
CHECK(fi->Read(&weights_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
CHECK(fi->Read(&base_margin_)) << "MetaInfo: invalid format";

View File

@@ -4,6 +4,7 @@
*/
#include <dmlc/base.h>
#include <xgboost/logging.h>
#include <limits>
#include "./simple_csr_source.h"
namespace xgboost {
@@ -26,6 +27,10 @@ void SimpleCSRSource::CopyFrom(DMatrix* src) {
}
void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
// use qid to get group info
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
this->Clear();
while (parser->Next()) {
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
@@ -35,6 +40,19 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
if (batch.weight != nullptr) {
info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
}
if (batch.qid != nullptr) {
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
// get group
for (size_t i = 0; i < batch.size; ++i) {
const uint64_t cur_group_id = batch.qid[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
++group_size;
}
}
// Remove the assertion on batch.index, which can be null in the case that the data in this
// batch is entirely sparse. Although it's true that this indicates a likely issue with the
// user's data workflows, passing XGBoost entirely sparse data should not cause it to fail.
@@ -56,7 +74,14 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
}
}
if (last_group_id != default_max) {
if (group_size > info.group_ptr_.back()) {
info.group_ptr_.push_back(group_size);
}
}
this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
// Either every row has query ID or none at all
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
}
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {

View File

@@ -122,6 +122,10 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
constexpr double kStep = 4.0;
size_t tick_expected = static_cast<double>(kStep);
const uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
bst_uint group_size = 0;
while (src->Next()) {
const dmlc::RowBlock<uint32_t>& batch = src->Value();
if (batch.label != nullptr) {
@@ -130,6 +134,18 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
if (batch.weight != nullptr) {
info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
}
if (batch.qid != nullptr) {
info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
// get group
for (size_t i = 0; i < batch.size; ++i) {
const uint64_t cur_group_id = batch.qid[i];
if (last_group_id == default_max || last_group_id != cur_group_id) {
info.group_ptr_.push_back(group_size);
}
last_group_id = cur_group_id;
++group_size;
}
}
info.num_row_ += batch.size;
info.num_nonzero_ += batch.offset[batch.size] - batch.offset[0];
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
@@ -153,6 +169,11 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
}
}
}
if (last_group_id != default_max) {
if (group_size > info.group_ptr_.back()) {
info.group_ptr_.push_back(group_size);
}
}
if (page->data.size() != 0) {
writer.PushWrite(std::move(page));
@@ -162,6 +183,8 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
dmlc::Stream::Create(name_info.c_str(), "w"));
int tmagic = kMagic;
fo->Write(&tmagic, sizeof(tmagic));
// Either every row has query ID or none at all
CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
info.SaveBinary(fo.get());
}
LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;