xgboost/src/data/file_iterator.h
2021-11-27 13:42:36 +08:00

115 lines
3.5 KiB
C++

/*!
* Copyright 2021 XGBoost contributors
*/
#ifndef XGBOOST_DATA_FILE_ITERATOR_H_
#define XGBOOST_DATA_FILE_ITERATOR_H_
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include "dmlc/data.h"
#include "xgboost/c_api.h"
#include "xgboost/json.h"
#include "xgboost/linalg.h"
#include "array_interface.h"
namespace xgboost {
namespace data {
/**
* An iterator for implementing external memory support with file inputs. Users of
* external memory are encouraged to define their own file parsers/loaders so this one is
* just here for compatibility with old versions of XGBoost and CLI interface.
*/
class FileIterator {
// uri of input file, encodes parameters about whether it's 1-based index etc. dmlc
// parser will decode these information.
std::string uri_;
// Equals to rank_id in distributed training, used to split file into parts for each
// worker.
uint32_t part_idx_;
// Equals to total number of workers.
uint32_t n_parts_;
// Format of the input file, like "libsvm".
std::string type_;
DMatrixHandle proxy_;
std::unique_ptr<dmlc::Parser<uint32_t>> parser_;
// Temporary reference to stage the data.
dmlc::RowBlock<uint32_t, float> row_block_;
// Storage for the array interface strings.
std::string indptr_;
std::string values_;
std::string indices_;
public:
FileIterator(std::string uri, unsigned part_index, unsigned num_parts,
std::string type)
: uri_{std::move(uri)}, part_idx_{part_index}, n_parts_{num_parts},
type_{std::move(type)} {
XGProxyDMatrixCreate(&proxy_);
}
~FileIterator() {
XGDMatrixFree(proxy_);
}
int Next() {
CHECK(parser_);
if (parser_->Next()) {
row_block_ = parser_->Value();
using linalg::MakeVec;
indptr_ = ArrayInterfaceStr(MakeVec(row_block_.offset, row_block_.size + 1));
values_ = ArrayInterfaceStr(MakeVec(row_block_.value, row_block_.offset[row_block_.size]));
indices_ = ArrayInterfaceStr(MakeVec(row_block_.index, row_block_.offset[row_block_.size]));
size_t n_columns = *std::max_element(row_block_.index,
row_block_.index + row_block_.offset[row_block_.size]);
// dmlc parser converts 1-based indexing back to 0-based indexing so we can ignore
// this condition and just add 1 to n_columns
n_columns += 1;
XGProxyDMatrixSetDataCSR(proxy_, indptr_.c_str(), indices_.c_str(),
values_.c_str(), n_columns);
if (row_block_.label) {
XGDMatrixSetDenseInfo(proxy_, "label", row_block_.label, row_block_.size, 1);
}
if (row_block_.qid) {
XGDMatrixSetDenseInfo(proxy_, "qid", row_block_.qid, row_block_.size, 1);
}
if (row_block_.weight) {
XGDMatrixSetDenseInfo(proxy_, "weight", row_block_.weight, row_block_.size, 1);
}
// Continue iteration
return true;
} else {
// Stop iteration
return false;
}
}
auto Proxy() -> decltype(proxy_) { return proxy_; }
void Reset() {
CHECK(!type_.empty());
parser_.reset(dmlc::Parser<uint32_t>::Create(uri_.c_str(), part_idx_,
n_parts_, type_.c_str()));
}
};
namespace fileiter {
inline void Reset(DataIterHandle self) {
static_cast<FileIterator*>(self)->Reset();
}
inline int Next(DataIterHandle self) {
return static_cast<FileIterator*>(self)->Next();
}
} // namespace fileiter
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_FILE_ITERATOR_H_