[DATA] basic data refactor done, basic version of csr source.
This commit is contained in:
@@ -1,9 +1,14 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file data.cc
|
||||
*/
|
||||
#include <cstring>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
namespace xgboost {
|
||||
// implementation of inline functions
|
||||
void MetaInfo::Clear() {
|
||||
num_row = num_col = 0;
|
||||
num_row = num_col = num_nonzero = 0;
|
||||
labels.clear();
|
||||
root_index.clear();
|
||||
group_ptr.clear();
|
||||
@@ -16,6 +21,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
fo->Write(&version, sizeof(version));
|
||||
fo->Write(&num_row, sizeof(num_row));
|
||||
fo->Write(&num_col, sizeof(num_col));
|
||||
fo->Write(&num_nonzero, sizeof(num_nonzero));
|
||||
fo->Write(labels);
|
||||
fo->Write(group_ptr);
|
||||
fo->Write(weights);
|
||||
@@ -25,14 +31,55 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const {
|
||||
|
||||
void MetaInfo::LoadBinary(dmlc::Stream *fi) {
|
||||
int version;
|
||||
CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version";
|
||||
CHECK_EQ(version, kVersion) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&num_row, sizeof(num_row)) == sizeof(num_row)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&num_col, sizeof(num_col)) == sizeof(num_col)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&num_nonzero, sizeof(num_nonzero)) == sizeof(num_nonzero)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&labels)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&group_ptr)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&weights)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&root_index)) << "MetaInfo: invalid format";
|
||||
CHECK(fi->Read(&base_margin)) << "MetaInfo: invalid format";
|
||||
}
|
||||
|
||||
// macro to dispatch according to specified pointer types
|
||||
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
||||
switch(dtype) { \
|
||||
case kFloat32: { \
|
||||
const float* cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
|
||||
} \
|
||||
case kDouble: { \
|
||||
const double* cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
|
||||
} \
|
||||
case kUInt32: { \
|
||||
const uint32_t* cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
|
||||
} \
|
||||
case kUInt64: { \
|
||||
const uint64_t* cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
|
||||
} \
|
||||
default: LOG(FATAL) << "Unknown data type" << dtype; \
|
||||
} \
|
||||
|
||||
|
||||
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
|
||||
if (!std::strcmp(key, "root_index")) {
|
||||
root_index.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, root_index.begin()));
|
||||
} else if (!std::strcmp(key, "label")) {
|
||||
labels.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
|
||||
} else if (!std::strcmp(key, "weight")) {
|
||||
weights.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||
} else if (!std::strcmp(key, "base_margin")) {
|
||||
base_margin.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
101
src/data/simple_csr_source.cc
Normal file
101
src/data/simple_csr_source.cc
Normal file
@@ -0,0 +1,101 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file simple_csr_source.cc
|
||||
*/
|
||||
#include <dmlc/base.h>
|
||||
#include <dmlc/logging.h>
|
||||
#include "./simple_csr_source.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
void SimpleCSRSource::Clear() {
|
||||
row_data_.clear();
|
||||
row_ptr_.resize(1);
|
||||
row_ptr_[0] = 0;
|
||||
this->info.Clear();
|
||||
}
|
||||
|
||||
void SimpleCSRSource::CopyFrom(DMatrix* src) {
|
||||
this->Clear();
|
||||
this->info = src->info();
|
||||
dmlc::DataIter<RowBatch>* iter = src->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
while (iter->Next()) {
|
||||
const RowBatch &batch = iter->Value();
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
RowBatch::Inst inst = batch[i];
|
||||
row_data_.insert(row_data_.end(), inst.data, inst.data + inst.length);
|
||||
row_ptr_.push_back(row_ptr_.back() + inst.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SimpleCSRSource::CopyFrom(dmlc::Parser<uint32_t>* parser) {
|
||||
this->Clear();
|
||||
while (parser->Next()) {
|
||||
const dmlc::RowBlock<uint32_t>& batch = parser->Value();
|
||||
if (batch.label != nullptr) {
|
||||
info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size);
|
||||
}
|
||||
if (batch.weight != nullptr) {
|
||||
info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size);
|
||||
}
|
||||
row_data_.reserve(row_data_.size() + batch.offset[batch.size] - batch.offset[0]);
|
||||
CHECK(batch.index != nullptr);
|
||||
// update information
|
||||
this->info.num_row += batch.size;
|
||||
// copy the data over
|
||||
for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) {
|
||||
uint32_t index = batch.index[i];
|
||||
bst_float fvalue = batch.value == nullptr ? 1.0f : batch.value[i];
|
||||
row_data_.push_back(SparseBatch::Entry(index, fvalue));
|
||||
this->info.num_col = std::max(this->info.num_col,
|
||||
static_cast<size_t>(index + 1));
|
||||
}
|
||||
size_t top = row_ptr_.size();
|
||||
row_ptr_.resize(top + batch.size);
|
||||
for (size_t i = 0; i < batch.size; ++i) {
|
||||
row_ptr_[top + i] = row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0];
|
||||
}
|
||||
}
|
||||
this->info.num_nonzero = static_cast<uint64_t>(row_data_.size());
|
||||
}
|
||||
|
||||
void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) {
|
||||
int tmagic;
|
||||
CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format";
|
||||
CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch";
|
||||
info.LoadBinary(fi);
|
||||
fi->Read(&row_ptr_);
|
||||
fi->Read(&row_data_);
|
||||
}
|
||||
|
||||
void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const {
|
||||
int tmagic = kMagic;
|
||||
fo->Write(&tmagic, sizeof(tmagic));
|
||||
info.SaveBinary(fo);
|
||||
fo->Write(row_ptr_);
|
||||
fo->Write(row_data_);
|
||||
}
|
||||
|
||||
void SimpleCSRSource::BeforeFirst() {
|
||||
at_first_ = false;
|
||||
}
|
||||
|
||||
bool SimpleCSRSource::Next() {
|
||||
if (!at_first_) return false;
|
||||
at_first_ = false;
|
||||
batch_.size = row_ptr_.size() - 1;
|
||||
batch_.base_rowid = 0;
|
||||
batch_.ind_ptr = dmlc::BeginPtr(row_ptr_);
|
||||
batch_.data_ptr = dmlc::BeginPtr(row_data_);
|
||||
return true;
|
||||
}
|
||||
|
||||
const RowBatch& SimpleCSRSource::Value() const {
|
||||
return batch_;
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
81
src/data/simple_csr_source.h
Normal file
81
src/data/simple_csr_source.h
Normal file
@@ -0,0 +1,81 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file simple_csr_source.h
|
||||
* \brief The simplest form of data source, can be used to create DMatrix.
|
||||
* This is an in-memory data structure that holds the data in row oriented format.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_
|
||||
#define XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
|
||||
namespace xgboost {
|
||||
/*! \brief namespace of internal data structures*/
|
||||
namespace data {
|
||||
/*!
|
||||
* \brief The simplest form of data holder, can be used to create DMatrix.
|
||||
* This is an in-memory data structure that holds the data in row oriented format.
|
||||
* \code
|
||||
* std::unique_ptr<DataSource> source(new SimpleCSRSource());
|
||||
* // add data to source
|
||||
* DMatrix* dmat = DMatrix::Create(std::move(source));
|
||||
* \encode
|
||||
*/
|
||||
class SimpleCSRSource : public DataSource {
|
||||
public:
|
||||
// public data members
|
||||
// MetaInfo info; // inheritated from DataSource
|
||||
/*! \brief row pointer of CSR sparse storage */
|
||||
std::vector<size_t> row_ptr_;
|
||||
/*! \brief data in the CSR sparse storage */
|
||||
std::vector<RowBatch::Entry> row_data_;
|
||||
// functions
|
||||
/*! \brief default constructor */
|
||||
SimpleCSRSource() : row_ptr_(1, 0), at_first_(true) {}
|
||||
/*! \brief destructor */
|
||||
virtual ~SimpleCSRSource() {}
|
||||
/*! \brief clear the data structure */
|
||||
void Clear();
|
||||
/*!
|
||||
* \brief copy content of data from src
|
||||
* \param src source data iter.
|
||||
*/
|
||||
void CopyFrom(DMatrix* src);
|
||||
/*!
|
||||
* \brief copy content of data from parser, also set the additional information.
|
||||
* \param src source data iter.
|
||||
* \param info The additional information reflected in the parser.
|
||||
*/
|
||||
void CopyFrom(dmlc::Parser<uint32_t>* src);
|
||||
/*!
|
||||
* \brief Load data from binary stream.
|
||||
* \param fi the pointer to load data from.
|
||||
*/
|
||||
void LoadBinary(dmlc::Stream* fi);
|
||||
/*!
|
||||
* \brief Save data into binary stream
|
||||
* \param fo The output stream.
|
||||
*/
|
||||
void SaveBinary(dmlc::Stream* fo) const;
|
||||
// implement Next
|
||||
bool Next() override;
|
||||
// implement BeforeFirst
|
||||
void BeforeFirst() override;
|
||||
// implement Value
|
||||
const RowBatch &Value() const override;
|
||||
/*! \brief magic number used to identify SimpleCSRSource */
|
||||
static const int kMagic = 0xffffab01;
|
||||
|
||||
private:
|
||||
/*! \brief internal variable, used to support iterator interface */
|
||||
bool at_first_;
|
||||
/*! \brief */
|
||||
RowBatch batch_;
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_
|
||||
Reference in New Issue
Block a user