[PLUGIN] Add densify parser

This commit is contained in:
tqchen 2016-01-25 11:56:16 -08:00
parent 88e362732f
commit b27b51f60e
5 changed files with 90 additions and 2 deletions

@ -1 +1 @@
Subproject commit c66d2ab2d30f55303b65b5ed9dc1f9ee04260f7e
Subproject commit e0a18eb45cb9c6e7314dbd3328dda158e3a3486f

View File

@ -31,3 +31,4 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server
#
# Each line appends one plugin makefile fragment to the XGB_PLUGINS list.
XGB_PLUGINS += plugin/example/plugin.mk
XGB_PLUGINS += plugin/lz4/plugin.mk
XGB_PLUGINS += plugin/dense_libsvm/plugin.mk

View File

@ -0,0 +1,85 @@
/*!
* Copyright 2015 by Contributors
* \file dense_libsvm.cc
* \brief Plugin to load in libsvm, but fill all the missing entries with zeros.
* This plugin is mainly used for benchmark purposes and does not need to be included.
*/
#include <dmlc/data.h>
#include <memory>
namespace dmlc {
namespace data {
/*!
 * \brief Parser adaptor that converts the sparse row blocks produced by a
 *  wrapped parser into dense ones: every output row has exactly num_col
 *  entries (indices 0 .. num_col-1), and features that are absent from the
 *  input row are stored explicitly with value zero.
 * \tparam IndexType type of the feature index
 */
template<typename IndexType>
class DensifyParser : public dmlc::Parser<IndexType> {
 public:
  /*!
   * \brief Wrap an existing parser.
   * \param parser source parser; stored in a unique_ptr, so it is deleted
   *  together with this object
   * \param num_col number of entries every output row is expanded to
   */
  DensifyParser(dmlc::Parser<IndexType>* parser, uint32_t num_col)
      : parser_(parser), num_col_(num_col) {
  }
  /*! \brief Rewind the wrapped parser. */
  void BeforeFirst() override {
    parser_->BeforeFirst();
  }
  /*!
   * \brief Fetch the next batch from the wrapped parser and densify it.
   * \return false when the wrapped parser has no more batches, true otherwise
   */
  bool Next() override {
    if (!parser_->Next()) return false;
    const RowBlock<IndexType>& batch = parser_->Value();
    const size_t num_entry = static_cast<size_t>(num_col_) * batch.size;
    dense_index_.resize(num_entry);
    // assign() rewrites every slot, so values from the previous batch
    // can never leak into this one.
    dense_value_.assign(num_entry, 0.0f);
    offset_.resize(batch.size + 1);
    offset_[0] = 0;
    for (size_t i = 0; i < batch.size; ++i) {
      const size_t row_begin = i * num_col_;
      offset_[i + 1] = row_begin + num_col_;
      for (uint32_t j = 0; j < num_col_; ++j) {
        dense_index_[row_begin + j] = j;
      }
      Row<IndexType> row = batch[i];
      for (size_t k = 0; k < row.length; ++k) {
        // Keep the index in IndexType: narrowing a wider index to 32 bits
        // before the range check could wrap it into the valid range and
        // silently write to the wrong column.
        const IndexType index = row.get_index(k);
        CHECK_LT(index, num_col_)
            << "Feature index must be smaller than num_col";
        dense_value_[row_begin + index] = row.get_value(k);
      }
    }
    // Start from a copy of the source block so every field that is not
    // overridden below is carried over, then point index/value/offset at
    // the dense buffers owned by this object.
    out_ = batch;
    out_.index = dmlc::BeginPtr(dense_index_);
    out_.value = dmlc::BeginPtr(dense_value_);
    out_.offset = dmlc::BeginPtr(offset_);
    return true;
  }
  /*! \brief The dense block built by the last successful call to Next(). */
  const dmlc::RowBlock<IndexType>& Value() const override {
    return out_;
  }
  /*! \brief Number of bytes read, as reported by the wrapped parser. */
  size_t BytesRead() const override {
    return parser_->BytesRead();
  }

 private:
  /*! \brief block handed out by Value(); points into the buffers below */
  RowBlock<IndexType> out_;
  /*! \brief the wrapped source parser */
  std::unique_ptr<Parser<IndexType> > parser_;
  /*! \brief number of entries per dense row */
  uint32_t num_col_;
  /*! \brief row offsets of the dense block; offset_[i] == i * num_col_ */
  std::vector<size_t> offset_;
  /*! \brief feature indices of the dense block (0 .. num_col_-1 for each row) */
  std::vector<IndexType> dense_index_;
  /*! \brief feature values of the dense block, zero where the input had no entry */
  std::vector<float> dense_value_;
};
/*!
 * \brief Factory for the dense_libsvm format: creates a "libsvm" parser for
 *  the given path/partition and wraps it in a DensifyParser.
 * \param path path passed on to Parser::Create
 * \param args extra arguments; must contain "num_col" with a positive integer
 * \param part_index index of the partition to read
 * \param num_parts total number of partitions
 * \return newly allocated parser; the caller takes ownership
 */
template<typename IndexType>
Parser<IndexType> *
CreateDenseLibSVMParser(const std::string& path,
                        const std::map<std::string, std::string>& args,
                        unsigned part_index,
                        unsigned num_parts) {
  const std::map<std::string, std::string>::const_iterator it = args.find("num_col");
  CHECK_NE(it == args.end(), true) << "expect num_col in dense_libsvm";
  // atoi yields 0 for non-numeric text and may yield a negative number; cast
  // blindly to uint32_t, a negative value becomes a huge column count.
  // Reject both here, where the error message can name the real problem.
  const int num_col = atoi(it->second.c_str());
  const bool num_col_is_positive = num_col > 0;
  CHECK_NE(num_col_is_positive, false)
      << "num_col in dense_libsvm must be a positive integer, got \""
      << it->second << "\"";
  return new DensifyParser<IndexType>(
      Parser<IndexType>::Create(path.c_str(), part_index, num_parts, "libsvm"),
      static_cast<uint32_t>(num_col));
}
} // namespace data
// Register the uint32_t instantiation of the factory under the name dense_libsvm.
// NOTE(review): presumably this makes "dense_libsvm" selectable as a parser format
// through dmlc's parser registry — confirm against dmlc-core.
DMLC_REGISTER_DATA_PARSER(uint32_t, dense_libsvm, data::CreateDenseLibSVMParser<uint32_t>);
} // namespace dmlc

View File

@ -0,0 +1,2 @@
# Object file of the dense libsvm parser plugin, appended to the plugin object list.
# NOTE(review): the object path is under dense_parser/, while the plugin list in this
# change references plugin/dense_libsvm/plugin.mk — confirm the directory name matches.
PLUGIN_OBJS += build_plugin/dense_parser/dense_libsvm.o
# This plugin adds no extra linker flags.
PLUGIN_LDFLAGS +=

View File

@ -181,7 +181,7 @@ DMatrix* DMatrix::Load(const std::string& uri,
std::string ftype = file_format;
if (file_format == "auto") ftype = "libsvm";
std::unique_ptr<dmlc::Parser<uint32_t> > parser(
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, ftype.c_str()));
dmlc::Parser<uint32_t>::Create(fname.c_str(), partid, npart, file_format.c_str()));
DMatrix* dmat = DMatrix::Create(parser.get(), cache_file);
if (!silent) {
LOG(CONSOLE) << dmat->info().num_row << 'x' << dmat->info().num_col << " matrix with "