Add qid like ranklib format (#2749)
* add qid for https://github.com/dmlc/xgboost/issues/2748
* change names
* change spaces
* change qid to bst_uint type
* change qid type to size_t
* change qid first to SIZE_MAX
* change qid type from size_t to uint64_t
* update dmlc-core
* fix qids name error
* fix group_ptr_ error
* Style fix
* Add qid handling logic to SparsePage
* New MetaInfo format + backward compatibility fix

  The old MetaInfo format (1.0) doesn't contain the qid field. We still want to be able to read MetaInfo files saved in the old format. Also, define a new format (2.0) that contains the qid field. This way, we can distinguish files that contain qid from those that do not.

* Update MetaInfo test
* Simplify group assignment logic
* Explicitly set qid=nullptr in NativeDataIter

  NativeDataIter's callback does not support the qid field. Users of NativeDataIter will need to call the setGroup() function separately to set group information.

* Save qids_ in SaveBinary()
* Upgrade dmlc-core submodule
* Add a test for reading qid
* Add contributor
* Check the size of qids_
* Document qid format
This commit is contained in:
parent 18813a26ab
commit 0cf88d036f
@@ -76,3 +76,5 @@ List of Contributors
 * [Andy Adinets](https://github.com/canonizer)
 * [Henry Gouk](https://github.com/henrygouk)
 * [Pierre de Sahb](https://github.com/pdesahb)
+* [liuliang01](https://github.com/liuliang01)
+  - liuliang01 added support for the qid column for the LibSVM input format. This makes ranking tasks easier in a distributed setting.
@@ -1 +1 @@
-Subproject commit 1c422a99786ba677db953a565807973154f6d374
+Subproject commit 459ab734d15acd68fd437abf845c7c1730b5a38f
@@ -5,6 +5,7 @@ on the demo dataset on a binary classification task.

 ## Links to Helpful Other Resources
 - See [Installation Guide](../build.md) on how to install xgboost.
+- See [Text Input Format](../input_format.md) on using text format for specifying training/testing data.
 - See [How to pages](../how_to/index.md) on various tips on using xgboost.
 - See [Tutorials](../tutorials/index.md) on tutorials on specific tasks.
 - See [Learning to use XGBoost by Examples](../../demo) for more code examples.
|||||||
@ -6,6 +6,7 @@ This page contains guidelines to use and develop XGBoost.
|
|||||||
- [How to Install XGBoost](../build.md)
|
- [How to Install XGBoost](../build.md)
|
||||||
|
|
||||||
## Use XGBoost in Specific Ways
|
## Use XGBoost in Specific Ways
|
||||||
|
- [Text input format](../input_format.md)
|
||||||
- [Parameter tuning guide](param_tuning.md)
|
- [Parameter tuning guide](param_tuning.md)
|
||||||
- [Use out of core computation for large dataset](external_memory.md)
|
- [Use out of core computation for large dataset](external_memory.md)
|
||||||
- [Use XGBoost GPU algorithms](../gpu/index.md)
|
- [Use XGBoost GPU algorithms](../gpu/index.md)
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ Text Input Format of DMatrix
 ============================

 ## Basic Input Format
-As we have mentioned, XGBoost takes LibSVM format. For training or predicting, XGBoost takes an instance file with the format as below:
+XGBoost currently supports two text formats for ingesting data: LibSVM and CSV. The rest of this document will describe the LibSVM format. (See [here](https://en.wikipedia.org/wiki/Comma-separated_values) for a description of the CSV format.)
+
+For training or predicting, XGBoost takes an instance file with the format as below:

 train.txt
 ```
@@ -14,13 +16,12 @@ train.txt
 ```
 Each line represents a single instance: in the first line, '1' is the instance label, '101' and '102' are feature indices, and '1.2' and '0.03' are feature values. In the binary classification case, '1' is used to indicate positive samples, and '0' is used to indicate negative samples. We also support probability values in [0,1] as labels, to indicate the probability of the instance being positive.

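For illustration, a minimal sketch of loading this format through the Python interface (assumes the standard `xgboost` package; this example is not part of the commit):

```python
import xgboost as xgb

# XGBoost parses LibSVM-format text directly from a file path.
dtrain = xgb.DMatrix('train.txt')
print(dtrain.num_row(), dtrain.num_col())  # rows and features detected
```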
-Additional Information
-----------------------
-Note: these additional information are only applicable to single machine version of the package.
+Auxiliary Files for Additional Information
+------------------------------------------
+**Note: all information below is applicable only to the single-node version of the package.** If you'd like to perform distributed training with multiple nodes, skip to the next section.

 ### Group Input Format
-As XGBoost supports accomplishing [ranking task](../demo/rank), we support the group input format. In ranking task, instances are categorized into different groups in real world scenarios, for example, in the learning to rank web pages scenario, the web page instances are grouped by their queries. Except the instance file mentioned in the group input format, XGBoost need an file indicating the group information. For example, if the instance file is the "train.txt" shown above,
+For [ranking tasks](../demo/rank), XGBoost supports the group input format. In ranking tasks, real-world instances are categorized into *query groups*; for example, in the learning-to-rank web pages scenario, web page instances are grouped by their queries. XGBoost requires a file that indicates the group information. For example, if the instance file is the "train.txt" shown above, the group file should be named "train.txt.group" and have the following format:
-and the group file is as below:

 train.txt.group
 ```
@@ -28,10 +29,10 @@ train.txt.group
 3
 ```
 This means that the data set contains 5 instances: the first two are in one group and the other three are in another. The numbers in the group file give the number of instances in each group, in the order they appear in the instance file.
-While configuration, you do not have to indicate the path of the group file. If the instance file name is "xxx", XGBoost will check whether there is a file named "xxx.group" in the same directory and decides whether to read the data as group input format.
+At the time of configuration, you do not have to indicate the path of the group file. If the instance file name is "xxx", XGBoost will check whether there is a file named "xxx.group" in the same directory.

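To make the auto-detection concrete, a hedged Python sketch (assumes `train.txt` and `train.txt.group` sit in the same directory; `set_group()` is the equivalent programmatic route; not part of this commit):

```python
import xgboost as xgb

# 'train.txt.group' next to 'train.txt' is picked up automatically.
dtrain = xgb.DMatrix('train.txt')

# Equivalent programmatic route: assign the group sizes yourself.
dtrain.set_group([2, 3])  # 5 instances: a group of 2 followed by a group of 3

params = {'objective': 'rank:pairwise', 'eta': 0.1}
bst = xgb.train(params, dtrain, num_boost_round=10)
```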
 ### Instance Weight File
-XGBoost supports providing each instance an weight to differentiate the importance of instances. For example, if we provide an instance weight file for the "train.txt" file in the example as below:
+Instances in the training data may be assigned weights to differentiate their relative importance. For example, if we provide an instance weight file for the "train.txt" file in the example as below:

 train.txt.weight
 ```
@@ -41,10 +42,12 @@ train.txt.weight
 1
 0.5
 ```
-It means that XGBoost will emphasize more on the first and fourth instance, that is to say positive instances while training.
-The configuration is similar to configuring the group information. If the instance file name is "xxx", XGBoost will check whether there is a file named "xxx.weight" in the same directory and if there is, will use the weights while training models. Weights will be included into an "xxx.buffer" file that is created by XGBoost automatically. If you want to update the weights, you need to delete the "xxx.buffer" file prior to launching XGBoost.
+It means that XGBoost will put more emphasis on the first and fourth instances (i.e. the positive instances) during training.
+The configuration is similar to configuring the group information. If the instance file name is "xxx", XGBoost will look for a file named "xxx.weight" in the same directory. If the file exists, the instance weights will be extracted and used at the time of training.

-### Initial Margin file
+NOTE: If you choose to save the training data as a binary buffer (using [save_binary()](http://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.DMatrix.save_binary)), keep in mind that the resulting binary buffer file will include the instance weights. To update the weights, use [the set_weight() function](http://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.DMatrix.set_weight).
+
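To illustrate the NOTE above, a hedged Python sketch of refreshing weights on a saved binary buffer (uses only the `save_binary()` and `set_weight()` calls linked above; not part of this commit):

```python
import numpy as np
import xgboost as xgb

dtrain = xgb.DMatrix('train.txt')    # picks up train.txt.weight if present
dtrain.save_binary('train.buffer')   # the buffer bakes in the weights

# Later: reload the buffer and overwrite the stored weights in memory.
dtrain2 = xgb.DMatrix('train.buffer')
dtrain2.set_weight(np.array([1.0, 0.5, 0.5, 1.0, 0.5], dtype=np.float32))
```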
+### Initial Margin File
 XGBoost supports providing each instance with an initial margin prediction. For example, if we have an initial prediction using logistic regression for the "train.txt" file, we can create the following file:

 train.txt.base_margin
@@ -54,3 +57,37 @@ train.txt.base_margin
 3.4
 ```
 XGBoost will take these values as the initial margin prediction and boost from them. An important note about base_margin is that it should be the margin prediction *before* transformation, so if you are doing logistic loss, you will need to put in values before the logistic transformation. If you are using the XGBoost predictor, use pred_margin=1 to output margin values.

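For concreteness, a hedged Python sketch of producing margin values to use as base_margin (`output_margin=True` is the Python-side counterpart of pred_margin=1; not part of this commit):

```python
import xgboost as xgb

dtrain = xgb.DMatrix('train.txt')  # train.txt.base_margin is auto-detected
bst = xgb.train({'objective': 'binary:logistic'}, dtrain, num_boost_round=10)

# Margin (pre-sigmoid) predictions, suitable as base_margin for a later model.
margins = bst.predict(dtrain, output_margin=True)
```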
+Embedding additional information inside LibSVM file
+---------------------------------------------------
+**This section is applicable to both single- and multiple-node settings.**
+
+### Query ID Columns
+This is most useful for [ranking tasks](../demo/rank), where the instances are grouped into query groups. You may embed the query group ID of each instance in the LibSVM file by adding a token of the form `qid:xx` to each row:
+
+train.txt
+```
+1 qid:1 101:1.2 102:0.03
+0 qid:1 1:2.1 10001:300 10002:400
+0 qid:2 0:1.3 1:0.3
+1 qid:2 0:0.01 1:0.3
+0 qid:3 0:0.2 1:0.3
+1 qid:3 3:-0.1 10:-0.3
+0 qid:3 6:0.2 10:0.15
+```
+Keep in mind the following restrictions:
+* It is not allowed to specify query IDs for some instances but not for others. Either every row is assigned a query ID, or none at all.
+* The rows have to be sorted in ascending order by query ID. So, for instance, no row may have a larger query ID than any of the rows that follow it.
+
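As a usage illustration (a hedged sketch; the group sizes below follow from the qid:1/qid:2/qid:3 runs in the sample file above; not part of this commit):

```python
import xgboost as xgb

# qid tokens are parsed straight from the LibSVM file; query groups are
# derived automatically, so no train.txt.group file or set_group() is needed.
dtrain = xgb.DMatrix('train.txt')  # groups of sizes 2, 2, 3 from the qid runs
bst = xgb.train({'objective': 'rank:pairwise'}, dtrain, num_boost_round=10)
```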
+### Instance weights
+You may specify instance weights in the LibSVM file by appending the corresponding weight to each instance label, in the form `[label]:[weight]`, as shown in the following example:
+
+train.txt
+```
+1:1.0 101:1.2 102:0.03
+0:0.5 1:2.1 10001:300 10002:400
+0:0.5 0:1.3 1:0.3
+1:1.0 0:0.01 1:0.3
+0:0.5 0:0.2 1:0.3
+```
+where the negative instances are assigned half the weight of the positive instances.
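A hedged sketch of verifying that inline weights were picked up (`get_weight()`/`get_label()` are standard DMatrix accessors; not part of this commit):

```python
import xgboost as xgb

dtrain = xgb.DMatrix('train.txt')  # '[label]:[weight]' weights parsed inline
print(dtrain.get_weight())  # expected: [1.0, 0.5, 0.5, 1.0, 0.5]
print(dtrain.get_label())   # expected: [1.0, 0.0, 0.0, 1.0, 0.0]
```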
@@ -53,6 +53,8 @@ class MetaInfo {
   std::vector<bst_uint> group_ptr_;
   /*! \brief weights of each instance, optional */
   std::vector<bst_float> weights_;
+  /*! \brief qid (query/session id) of each instance, optional */
+  std::vector<uint64_t> qids_;
   /*!
    * \brief initialized margins,
    * if specified, xgboost will start from this init margin
@@ -60,7 +62,9 @@ class MetaInfo {
    */
   std::vector<bst_float> base_margin_;
   /*! \brief version flag, used to check version of this info */
-  static const int kVersion = 1;
+  static const int kVersion = 2;
+  /*! \brief version that introduced qid field */
+  static const int kVersionQidAdded = 2;
   /*! \brief default constructor */
   MetaInfo() = default;
   /*!
@@ -136,6 +140,9 @@ struct Entry {
   inline static bool CmpValue(const Entry& a, const Entry& b) {
     return a.fvalue < b.fvalue;
   }
+  inline bool operator==(const Entry& other) const {
+    return (this->index == other.index && this->fvalue == other.fvalue);
+  }
 };

 /*!
@@ -141,6 +141,8 @@ class NativeDataIter : public dmlc::Parser<uint32_t> {
     block_.offset = dmlc::BeginPtr(offset_);
     block_.label = dmlc::BeginPtr(label_);
     block_.weight = dmlc::BeginPtr(weight_);
+    block_.qid = nullptr;    // callback supplies no qid; set groups via setGroup() instead
+    block_.field = nullptr;  // field column is likewise unused here
     block_.index = dmlc::BeginPtr(index_);
     block_.value = dmlc::BeginPtr(value_);
     bytes_read_ += offset_.size() * sizeof(size_t) +
@@ -28,6 +28,7 @@ void MetaInfo::Clear() {
   labels_.clear();
   root_index_.clear();
   group_ptr_.clear();
+  qids_.clear();
   weights_.clear();
   base_margin_.clear();
 }
@@ -40,6 +41,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
   fo->Write(&num_nonzero_, sizeof(num_nonzero_));
   fo->Write(labels_);
   fo->Write(group_ptr_);
+  fo->Write(qids_);
   fo->Write(weights_);
   fo->Write(root_index_);
   fo->Write(base_margin_);
@@ -48,13 +50,18 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
 void MetaInfo::LoadBinary(dmlc::Stream *fi) {
   int version;
   CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version";
-  CHECK_EQ(version, kVersion) << "MetaInfo: invalid format";
+  CHECK(version >= 1 && version <= kVersion) << "MetaInfo: unsupported file version";
   CHECK(fi->Read(&num_row_, sizeof(num_row_)) == sizeof(num_row_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&num_col_, sizeof(num_col_)) == sizeof(num_col_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&num_nonzero_, sizeof(num_nonzero_)) == sizeof(num_nonzero_))
       << "MetaInfo: invalid format";
   CHECK(fi->Read(&labels_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&group_ptr_)) << "MetaInfo: invalid format";
+  if (version >= kVersionQidAdded) {
+    CHECK(fi->Read(&qids_)) << "MetaInfo: invalid format";
+  } else {  // old format doesn't contain qid field
+    qids_.clear();
+  }
   CHECK(fi->Read(&weights_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&root_index_)) << "MetaInfo: invalid format";
   CHECK(fi->Read(&base_margin_)) << "MetaInfo: invalid format";
@@ -4,6 +4,7 @@
  */
 #include <dmlc/base.h>
 #include <xgboost/logging.h>
+#include <limits>
 #include "./simple_csr_source.h"

 namespace xgboost {
@@ -26,6 +27,10 @@ void SimpleCSRSource::CopyFrom(DMatrix* src) {
 }

 void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
+  // use qid to get group info
+  const uint64_t default_max = std::numeric_limits<uint64_t>::max();
+  uint64_t last_group_id = default_max;
+  bst_uint group_size = 0;
   this->Clear();
   while (parser->Next()) {
     const dmlc::RowBlock<uint32_t>& batch = parser->Value();
@@ -35,6 +40,19 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
     if (batch.weight != nullptr) {
       info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
     }
+    if (batch.qid != nullptr) {
+      info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
+      // get group: open a new group whenever the query ID changes
+      for (size_t i = 0; i < batch.size; ++i) {
+        const uint64_t cur_group_id = batch.qid[i];
+        if (last_group_id == default_max || last_group_id != cur_group_id) {
+          info.group_ptr_.push_back(group_size);
+        }
+        last_group_id = cur_group_id;
+        ++group_size;
+      }
+    }

     // Remove the assertion on batch.index, which can be null in the case that the data in this
     // batch is entirely sparse. Although it's true that this indicates a likely issue with the
     // user's data workflows, passing XGBoost entirely sparse data should not cause it to fail.
@@ -56,7 +74,14 @@ void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
       page_.offset.push_back(page_.offset[top - 1] + batch.offset[i + 1] - batch.offset[0]);
     }
   }
+  if (last_group_id != default_max) {
+    // close the final group
+    if (group_size > info.group_ptr_.back()) {
+      info.group_ptr_.push_back(group_size);
+    }
+  }
   this->info.num_nonzero_ = static_cast<uint64_t>(page_.data.size());
+  // Either every row has a query ID or none at all
+  CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
 }

 void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
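To see what this group-derivation loop computes, here is a hedged Python rendering of the same algorithm (`qids_to_group_ptr` is a hypothetical name; the expected values match the LoadQid test below):

```python
def qids_to_group_ptr(qids):
    """Mirror of the qid-to-group logic in SimpleCSRSource::CopyFrom."""
    group_ptr = []
    for i, qid in enumerate(qids):
        if i == 0 or qids[i - 1] != qid:  # query ID changed: open a new group
            group_ptr.append(i)
    if qids:                              # close the final group
        group_ptr.append(len(qids))
    return group_ptr

# Twelve rows with qid runs 1,1,1,1 / 2,2,2,2 / 3,3,3,3 -> boundaries 0,4,8,12
assert qids_to_group_ptr([1]*4 + [2]*4 + [3]*4) == [0, 4, 8, 12]
```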
@@ -122,6 +122,10 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
   constexpr double kStep = 4.0;
   size_t tick_expected = static_cast<double>(kStep);

+  const uint64_t default_max = std::numeric_limits<uint64_t>::max();
+  uint64_t last_group_id = default_max;
+  bst_uint group_size = 0;
+
   while (src->Next()) {
     const dmlc::RowBlock<uint32_t>& batch = src->Value();
     if (batch.label != nullptr) {
@@ -130,6 +134,18 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
     if (batch.weight != nullptr) {
      info.weights_.insert(info.weights_.end(), batch.weight, batch.weight + batch.size);
     }
+    if (batch.qid != nullptr) {
+      info.qids_.insert(info.qids_.end(), batch.qid, batch.qid + batch.size);
+      // get group: open a new group whenever the query ID changes
+      for (size_t i = 0; i < batch.size; ++i) {
+        const uint64_t cur_group_id = batch.qid[i];
+        if (last_group_id == default_max || last_group_id != cur_group_id) {
+          info.group_ptr_.push_back(group_size);
+        }
+        last_group_id = cur_group_id;
+        ++group_size;
+      }
+    }
     info.num_row_ += batch.size;
     info.num_nonzero_ += batch.offset[batch.size] - batch.offset[0];
     for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
@@ -153,6 +169,11 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
       }
     }
   }
+  if (last_group_id != default_max) {
+    // close the final group
+    if (group_size > info.group_ptr_.back()) {
+      info.group_ptr_.push_back(group_size);
+    }
+  }

   if (page->data.size() != 0) {
     writer.PushWrite(std::move(page));
@@ -162,6 +183,8 @@ void SparsePageSource::Create(dmlc::Parser<uint32_t>* src,
         dmlc::Stream::Create(name_info.c_str(), "w"));
     int tmagic = kMagic;
     fo->Write(&tmagic, sizeof(tmagic));
+    // Either every row has a query ID or none at all
+    CHECK(info.qids_.empty() || info.qids_.size() == info.num_row_);
     info.SaveBinary(fo.get());
   }
   LOG(CONSOLE) << "SparsePageSource: Finished writing to " << name_info;
@@ -1,5 +1,9 @@
 // Copyright by Contributors
+#include <dmlc/io.h>
 #include <xgboost/data.h>
+#include <string>
+#include <memory>
+#include "../../../src/data/simple_csr_source.h"

 #include "../helpers.h"

@@ -49,7 +53,7 @@ TEST(MetaInfo, SaveLoadBinary) {
   info.SaveBinary(fs);
   delete fs;

-  ASSERT_EQ(GetFileSize(tmp_file), 76)
+  ASSERT_EQ(GetFileSize(tmp_file), 84)
     << "Expected saved binary file size to be same as object size";

   fs = dmlc::Stream::Create(tmp_file.c_str(), "r");
@@ -61,3 +65,61 @@ TEST(MetaInfo, SaveLoadBinary) {

   std::remove(tmp_file.c_str());
 }
+
+TEST(MetaInfo, LoadQid) {
+  std::string tmp_file = TempFileName();
+  {
+    std::unique_ptr<dmlc::Stream> fs(
+        dmlc::Stream::Create(tmp_file.c_str(), "w"));
+    dmlc::ostream os(fs.get());
+    os << R"qid(3 qid:1 1:1 2:1 3:0 4:0.2 5:0
+2 qid:1 1:0 2:0 3:1 4:0.1 5:1
+1 qid:1 1:0 2:1 3:0 4:0.4 5:0
+1 qid:1 1:0 2:0 3:1 4:0.3 5:0
+1 qid:2 1:0 2:0 3:1 4:0.2 5:0
+2 qid:2 1:1 2:0 3:1 4:0.4 5:0
+1 qid:2 1:0 2:0 3:1 4:0.1 5:0
+1 qid:2 1:0 2:0 3:1 4:0.2 5:0
+2 qid:3 1:0 2:0 3:1 4:0.1 5:1
+3 qid:3 1:1 2:1 3:0 4:0.3 5:0
+4 qid:3 1:1 2:0 3:0 4:0.4 5:1
+1 qid:3 1:0 2:1 3:1 4:0.5 5:0)qid";
+    os.set_stream(nullptr);
+  }
+  std::unique_ptr<xgboost::DMatrix> dmat(
+      xgboost::DMatrix::Load(tmp_file, true, false, "libsvm"));
+  std::remove(tmp_file.c_str());
+
+  const xgboost::MetaInfo& info = dmat->Info();
+  // 12 rows; query IDs 1, 2, 3 each span a group of 4 rows
+  const std::vector<uint64_t> expected_qids{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
+  const std::vector<xgboost::bst_uint> expected_group_ptr{0, 4, 8, 12};
+  CHECK(info.qids_ == expected_qids);
+  CHECK(info.group_ptr_ == expected_group_ptr);
+  CHECK_GE(info.kVersion, info.kVersionQidAdded);
+
+  // every row has 5 entries, so offsets advance in steps of 5
+  const std::vector<size_t> expected_offset{
+    0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60
+  };
+  const std::vector<xgboost::Entry> expected_data{
+    {1, 1}, {2, 1}, {3, 0}, {4, 0.2}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
+    {1, 0}, {2, 1}, {3, 0}, {4, 0.4}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.3}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
+    {1, 1}, {2, 0}, {3, 1}, {4, 0.4}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.2}, {5, 0},
+    {1, 0}, {2, 0}, {3, 1}, {4, 0.1}, {5, 1},
+    {1, 1}, {2, 1}, {3, 0}, {4, 0.3}, {5, 0},
+    {1, 1}, {2, 0}, {3, 0}, {4, 0.4}, {5, 1},
+    {1, 0}, {2, 1}, {3, 1}, {4, 0.5}, {5, 0}
+  };
+  dmlc::DataIter<xgboost::SparsePage>* iter = dmat->RowIterator();
+  iter->BeforeFirst();
+  CHECK(iter->Next());
+  const xgboost::SparsePage& batch = iter->Value();
+  CHECK_EQ(batch.base_rowid, 0);
+  CHECK(batch.offset == expected_offset);
+  CHECK(batch.data == expected_data);
+  CHECK(!iter->Next());
+}