From 46bcba7173b14202e3cee5c68fa718761ff6e9cc Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 30 Dec 2015 03:28:25 -0800 Subject: [PATCH] [DATA] basic data refactor done, basic version of csr source. --- .gitignore | 1 - include/xgboost/base.h | 76 +------ include/xgboost/data.h | 120 ++++++++--- old_src/io/simple_dmatrix-inl.hpp | 324 ------------------------------ src/data/data.cc | 51 ++++- src/data/simple_csr_source.cc | 101 ++++++++++ src/data/simple_csr_source.h | 81 ++++++++ 7 files changed, 337 insertions(+), 417 deletions(-) delete mode 100644 old_src/io/simple_dmatrix-inl.hpp create mode 100644 src/data/simple_csr_source.cc create mode 100644 src/data/simple_csr_source.h diff --git a/.gitignore b/.gitignore index 8b378c254..62f3b823f 100644 --- a/.gitignore +++ b/.gitignore @@ -25,7 +25,6 @@ *group *rar *vali -*data *sdf Release *exe* diff --git a/include/xgboost/base.h b/include/xgboost/base.h index b283c93cb..b1d0e453a 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -16,78 +16,20 @@ namespace xgboost { typedef uint32_t bst_uint; /*! \brief float type, used for storing statistics */ typedef float bst_float; + const float rt_eps = 1e-5f; // min gap between feature values to allow a split happen const float rt_2eps = rt_eps * 2.0f; -/*! \brief read-only sparse instance batch in CSR format */ -struct SparseBatch { - /*! \brief an entry of sparse vector */ - struct Entry { - /*! \brief feature index */ - bst_uint index; - /*! \brief feature value */ - bst_float fvalue; - /*! \brief default constructor */ - Entry() {} - /*! - * \brief constructor with index and value - * \param index The feature or row index. - * \param fvalue THe feature value. - */ - Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {} - /*! \brief reversely compare feature values */ - inline static bool CmpValue(const Entry &a, const Entry &b) { - return a.fvalue < b.fvalue; - } - }; - - /*! 
\brief an instance of sparse vector in the batch */ - struct Inst { - /*! \brief pointer to the elements*/ - const Entry *data; - /*! \brief length of the instance */ - bst_uint length; - /*! \brief constructor */ - Inst(const Entry *data, bst_uint length) : data(data), length(length) {} - /*! \brief get i-th pair in the sparse vector*/ - inline const Entry& operator[](size_t i) const { - return data[i]; - } - }; - - /*! \brief batch size */ - size_t size; -}; - -/*! \brief read-only row batch, used to access row continuously */ -struct RowBatch : public SparseBatch { - /*! \brief the offset of rowid of this batch */ - size_t base_rowid; - /*! \brief array[size+1], row pointer of each of the elements */ - const size_t *ind_ptr; - /*! \brief array[ind_ptr.back()], content of the sparse element */ - const Entry *data_ptr; - /*! \brief get i-th row from the batch */ - inline Inst operator[](size_t i) const { - return Inst(data_ptr + ind_ptr[i], static_cast(ind_ptr[i+1] - ind_ptr[i])); - } -}; - /*! - * \brief read-only column batch, used to access columns, - * the columns are not required to be continuous + * \brief define compatible keywords in g++ + * Used to support g++-4.6 and g++4.7 */ -struct ColBatch : public SparseBatch { - /*! \brief column index of each columns in the data */ - const bst_uint *col_index; - /*! \brief pointer to the column data */ - const Inst *col_data; - /*! \brief get i-th column from the batch */ - inline Inst operator[](size_t i) const { - return col_data[i]; - } -}; - +#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) +#if __GNUC__ == 4 && __GNUC_MINOR__ < 8 +#define override +#define final +#endif +#endif } // namespace xgboost #endif // XGBOOST_BASE_H_ diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 354563b5c..160f5bad6 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -1,7 +1,7 @@ /*! 
- * Copyright (c) 2014 by Contributors + * Copyright (c) 2015 by Contributors * \file data.h - * \brief The input data structure for gradient boosting. + * \brief The input data structure of xgboost. * \author Tianqi Chen */ #ifndef XGBOOST_DATA_H_ @@ -13,9 +13,6 @@ #include "./base.h" namespace xgboost { -// forward declare learner. -class Learner; - /*! \brief data type accepted by xgboost interface */ enum DataType { kFloat32 = 1, @@ -29,9 +26,11 @@ enum DataType { */ struct MetaInfo { /*! \brief number of rows in the data */ - size_t num_row; + uint64_t num_row; /*! \brief number of columns in the data */ - size_t num_col; + uint64_t num_col; + /*! \brief number of nonzero entries in the data */ + uint64_t num_nonzero; /*! \brief label of each instance */ std::vector labels; /*! @@ -53,7 +52,7 @@ struct MetaInfo { */ std::vector base_margin; /*! \brief version flag, used to check version of this info */ - static const int kVersion = 0; + static const int kVersion = 1; /*! \brief default constructor */ MetaInfo() : num_row(0), num_col(0) {} /*! @@ -78,12 +77,12 @@ struct MetaInfo { * \brief Load the Meta info from binary stream. * \param fi The input stream */ - void LoadBinary(dmlc::Stream *fi); + void LoadBinary(dmlc::Stream* fi); /*! * \brief Save the Meta info to binary stream * \param fo The output stream. */ - void SaveBinary(dmlc::Stream *fo) const; + void SaveBinary(dmlc::Stream* fo) const; /*! * \brief Set information in the meta info. * \param key The key of the information. @@ -102,36 +101,105 @@ struct MetaInfo { void GetInfo(const char* key, const void** dptr, DataType* dtype, size_t* num) const; }; +/*! \brief read-only sparse instance batch in CSR format */ +struct SparseBatch { + /*! \brief an entry of sparse vector */ + struct Entry { + /*! \brief feature index */ + bst_uint index; + /*! \brief feature value */ + bst_float fvalue; + /*! \brief default constructor */ + Entry() {} + /*! 
+ * \brief constructor with index and value + * \param index The feature or row index. + * \param fvalue THe feature value. + */ + Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {} + /*! \brief reversely compare feature values */ + inline static bool CmpValue(const Entry& a, const Entry& b) { + return a.fvalue < b.fvalue; + } + }; + + /*! \brief an instance of sparse vector in the batch */ + struct Inst { + /*! \brief pointer to the elements*/ + const Entry *data; + /*! \brief length of the instance */ + bst_uint length; + /*! \brief constructor */ + Inst(const Entry *data, bst_uint length) : data(data), length(length) {} + /*! \brief get i-th pair in the sparse vector*/ + inline const Entry& operator[](size_t i) const { + return data[i]; + } + }; + + /*! \brief batch size */ + size_t size; +}; + +/*! \brief read-only row batch, used to access row continuously */ +struct RowBatch : public SparseBatch { + /*! \brief the offset of rowid of this batch */ + size_t base_rowid; + /*! \brief array[size+1], row pointer of each of the elements */ + const size_t *ind_ptr; + /*! \brief array[ind_ptr.back()], content of the sparse element */ + const Entry *data_ptr; + /*! \brief get i-th row from the batch */ + inline Inst operator[](size_t i) const { + return Inst(data_ptr + ind_ptr[i], static_cast(ind_ptr[i + 1] - ind_ptr[i])); + } +}; + +/*! + * \brief read-only column batch, used to access columns, + * the columns are not required to be continuous + */ +struct ColBatch : public SparseBatch { + /*! \brief column index of each columns in the data */ + const bst_uint *col_index; + /*! \brief pointer to the column data */ + const Inst *col_data; + /*! \brief get i-th column from the batch */ + inline Inst operator[](size_t i) const { + return col_data[i]; + } +}; + /*! 
* \brief This is data structure that user can pass to DMatrix::Create * to create a DMatrix for training, user can create this data structure * for customized Data Loading on single machine. + * + * On distributed setting, usually an customized dmlc::Parser is needed instead. */ -struct DataSource { +class DataSource : public dmlc::DataIter { + public: /*! - * \brief Used to initialize the meta information of DMatrix - * The created DMatrix can change its own info later. + * \brief Meta information about the dataset + * The subclass need to be able to load this correctly from data. */ MetaInfo info; - /*! - * \brief Used for row based iteration of DMatrix, - */ - std::unique_ptr > row_iter; }; /*! * \brief Internal data structured used by XGBoost during training. * There are two ways to create a customized DMatrix that reads in user defined-format. * - * - Define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER; - * This works best for user defined data input source, such as data-base, filesystem. + * - Provide a dmlc::Parser and pass into the DMatrix::Create + * - Alternatively, if data can be represented by an URL, define a new dmlc::Parser and register by DMLC_REGISTER_DATA_PARSER; + * - This works best for user defined data input source, such as data-base, filesystem. * - Provdie a DataSource, that can be passed to DMatrix::Create * This can be used to re-use inmemory data structure into DMatrix. */ class DMatrix { public: - /*! \brief meta information that is always stored in DMatrix */ - MetaInfo info; + /*! \brief meta information of the dataset */ + virtual MetaInfo& info() = 0; /*! * \brief get the row iterator, reset to beginning position * \note Only either RowIterator or column Iterator can be active. @@ -163,12 +231,13 @@ class DMatrix { /*! \brief get column density */ virtual float GetColDensity(size_t cidx) const = 0; /*! 
\return reference of buffered rowset, in column access */ - virtual const std::vector &buffered_rowset() const = 0; + virtual const std::vector& buffered_rowset() const = 0; /*! \brief virtual destructor */ virtual ~DMatrix() {} /*! * \brief Save DMatrix to local file. * The saved file only works for non-sharded dataset(single machine training). + * This API is deprecated and dis-encouraged to use. * \param fname The file name to be saved. * \return The created DMatrix. */ @@ -191,7 +260,7 @@ class DMatrix { * This can be nullptr for common cases, and in-memory mode will be used. * \return a Created DMatrix. */ - static DMatrix* Create(DataSource&& source, + static DMatrix* Create(std::unique_ptr&& source, const char* cache_prefix=nullptr); /*! * \brief Create a DMatrix by loaidng data from parser. @@ -208,5 +277,10 @@ class DMatrix { static DMatrix* Create(dmlc::Parser* parser, const char* cache_prefix=nullptr); }; + } // namespace xgboost + +namespace dmlc { +DMLC_DECLARE_TRAITS(is_pod, xgboost::SparseBatch::Entry, true); +} #endif // XGBOOST_DATA_H_ diff --git a/old_src/io/simple_dmatrix-inl.hpp b/old_src/io/simple_dmatrix-inl.hpp deleted file mode 100644 index 063b01665..000000000 --- a/old_src/io/simple_dmatrix-inl.hpp +++ /dev/null @@ -1,324 +0,0 @@ -/*! 
- * Copyright 2014 by Contributors - * \file simple_dmatrix-inl.hpp - * \brief simple implementation of DMatrixS that can be used - * the data format of xgboost is templatized, which means it can accept - * any data structure that implements the function defined by FMatrix - * this file is a specific implementation of input data structure that can be used by BoostLearner - * \author Tianqi Chen - */ -#ifndef XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_ -#define XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_ - -#include -#include -#include -#include -#include -#include "../data.h" -#include "../utils/utils.h" -#include "../learner/dmatrix.h" -#include "./io.h" -#include "./simple_fmatrix-inl.hpp" -#include "../sync/sync.h" -#include "./libsvm_parser.h" - -namespace xgboost { -namespace io { -/*! \brief implementation of DataMatrix, in CSR format */ -class DMatrixSimple : public DataMatrix { - public: - // constructor - DMatrixSimple(void) : DataMatrix(kMagic) { - fmat_ = new FMatrixS(new OneBatchIter(this), this->info); - this->Clear(); - } - // virtual destructor - virtual ~DMatrixSimple(void) { - delete fmat_; - } - virtual IFMatrix *fmat(void) const { - return fmat_; - } - /*! \brief clear the storage */ - inline void Clear(void) { - row_ptr_.clear(); - row_ptr_.push_back(0); - row_data_.clear(); - info.Clear(); - } - /*! \brief copy content data from source matrix */ - inline void CopyFrom(const DataMatrix &src) { - this->Clear(); - this->info = src.info; - // clone data contents from src matrix - utils::IIterator *iter = src.fmat()->RowIterator(); - iter->BeforeFirst(); - while (iter->Next()) { - const RowBatch &batch = iter->Value(); - for (size_t i = 0; i < batch.size; ++i) { - RowBatch::Inst inst = batch[i]; - row_data_.resize(row_data_.size() + inst.length); - if (inst.length != 0) { - std::memcpy(&row_data_[row_ptr_.back()], inst.data, - sizeof(RowBatch::Entry) * inst.length); - } - row_ptr_.push_back(row_ptr_.back() + inst.length); - } - } - } - /*! 
- * \brief add a row to the matrix - * \param feats features - * \return the index of added row - */ - inline size_t AddRow(const std::vector &feats) { - for (size_t i = 0; i < feats.size(); ++i) { - row_data_.push_back(feats[i]); - info.info.num_col = std::max(info.info.num_col, - static_cast(feats[i].index+1)); - } - row_ptr_.push_back(row_ptr_.back() + feats.size()); - info.info.num_row += 1; - return row_ptr_.size() - 2; - } - /*! - * \brief load split of input, used in distributed mode - * \param uri the uri of input - * \param loadsplit whether loadsplit of data or all the data - * \param silent whether print information or not - */ - inline void LoadText(const char *uri, bool silent = false, bool loadsplit = false) { - int rank = 0, npart = 1; - if (loadsplit) { - rank = rabit::GetRank(); - npart = rabit::GetWorldSize(); - } - LibSVMParser parser( - dmlc::InputSplit::Create(uri, rank, npart, "text"), 16); - this->Clear(); - while (parser.Next()) { - const LibSVMPage &batch = parser.Value(); - size_t nlabel = info.labels.size(); - info.labels.resize(nlabel + batch.label.size()); - if (batch.label.size() != 0) { - std::memcpy(BeginPtr(info.labels) + nlabel, - BeginPtr(batch.label), - batch.label.size() * sizeof(float)); - } - size_t ndata = row_data_.size(); - row_data_.resize(ndata + batch.data.size()); - if (batch.data.size() != 0) { - std::memcpy(BeginPtr(row_data_) + ndata, - BeginPtr(batch.data), - batch.data.size() * sizeof(RowBatch::Entry)); - } - row_ptr_.resize(row_ptr_.size() + batch.label.size()); - for (size_t i = 0; i < batch.label.size(); ++i) { - row_ptr_[nlabel + i + 1] = row_ptr_[nlabel] + batch.offset[i + 1]; - } - info.info.num_row += batch.Size(); - for (size_t i = 0; i < batch.data.size(); ++i) { - info.info.num_col = std::max(info.info.num_col, - static_cast(batch.data[i].index+1)); - } - } - if (!silent) { - utils::Printf("%lux%lu matrix with %lu entries is loaded from %s\n", - static_cast(info.num_row()), // NOLINT(*) - 
static_cast(info.num_col()), // NOLINT(*) - static_cast(row_data_.size()), uri); // NOLINT(*) - } - // try to load in additional file - if (!loadsplit) { - std::string name = uri; - std::string gname = name + ".group"; - if (info.TryLoadGroup(gname.c_str(), silent)) { - utils::Check(info.group_ptr.back() == info.num_row(), - "DMatrix: group data does not match the number of rows in features"); - } - std::string wname = name + ".weight"; - if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) { - utils::Check(info.weights.size() == info.num_row(), - "DMatrix: weight data does not match the number of rows in features"); - } - std::string mname = name + ".base_margin"; - if (info.TryLoadFloatInfo("base_margin", mname.c_str(), silent)) { - } - } - } - /*! - * \brief load from binary file - * \param fname name of binary data - * \param silent whether print information or not - * \return whether loading is success - */ - inline bool LoadBinary(const char* fname, bool silent = false) { - std::FILE *fp = fopen64(fname, "rb"); - if (fp == NULL) return false; - utils::FileStream fs(fp); - this->LoadBinary(fs, silent, fname); - fs.Close(); - return true; - } - /*! - * \brief load from binary stream - * \param fs input file stream - * \param silent whether print information during loading - * \param fname file name, used to print message - */ - inline void LoadBinary(utils::IStream &fs, bool silent = false, const char *fname = NULL) { // NOLINT(*) - int tmagic; - utils::Check(fs.Read(&tmagic, sizeof(tmagic)) != 0, "invalid input file format"); - utils::Check(tmagic == kMagic, "\"%s\" invalid format, magic number mismatch", - fname == NULL ? 
"" : fname); - - info.LoadBinary(fs); - LoadBinary(fs, &row_ptr_, &row_data_); - fmat_->LoadColAccess(fs); - - if (!silent) { - utils::Printf("%lux%lu matrix with %lu entries is loaded", - static_cast(info.num_row()), // NOLINT(*) - static_cast(info.num_col()), // NOLINT(*) - static_cast(row_data_.size())); // NOLINT(*) - if (fname != NULL) { - utils::Printf(" from %s\n", fname); - } else { - utils::Printf("\n"); - } - if (info.group_ptr.size() != 0) { - utils::Printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); - } - } - } - /*! - * \brief save to binary file - * \param fname name of binary data - * \param silent whether print information or not - */ - inline void SaveBinary(const char* fname, bool silent = false) const { - utils::FileStream fs(utils::FopenCheck(fname, "wb")); - int tmagic = kMagic; - fs.Write(&tmagic, sizeof(tmagic)); - info.SaveBinary(fs); - SaveBinary(fs, row_ptr_, row_data_); - fmat_->SaveColAccess(fs); - fs.Close(); - - if (!silent) { - utils::Printf("%lux%lu matrix with %lu entries is saved to %s\n", - static_cast(info.num_row()), // NOLINT(*) - static_cast(info.num_col()), // NOLINT(*) - static_cast(row_data_.size()), fname); // NOLINT(*) - if (info.group_ptr.size() != 0) { - utils::Printf("data contains %u groups\n", - static_cast(info.group_ptr.size()-1)); - } - } - } - /*! 
- * \brief cache load data given a file name, if filename ends with .buffer, direct load binary - * otherwise the function will first check if fname + '.buffer' exists, - * if binary buffer exists, it will reads from binary buffer, otherwise, it will load from text file, - * and try to create a buffer file - * \param fname name of binary data - * \param silent whether print information or not - * \param savebuffer whether do save binary buffer if it is text - */ - inline void CacheLoad(const char *fname, bool silent = false, bool savebuffer = true) { - using namespace std; - size_t len = strlen(fname); - if (len > 8 && !strcmp(fname + len - 7, ".buffer")) { - if (!this->LoadBinary(fname, silent)) { - utils::Error("can not open file \"%s\"", fname); - } - return; - } - char bname[1024]; - utils::SPrintf(bname, sizeof(bname), "%s.buffer", fname); - if (!this->LoadBinary(bname, silent)) { - this->LoadText(fname, silent); - if (savebuffer) this->SaveBinary(bname, silent); - } - } - // data fields - /*! \brief row pointer of CSR sparse storage */ - std::vector row_ptr_; - /*! \brief data in the row */ - std::vector row_data_; - /*! \brief the real fmatrix */ - FMatrixS *fmat_; - /*! \brief magic number used to identify DMatrix */ - static const int kMagic = 0xffffab01; - - protected: - /*! - * \brief save data to binary stream - * \param fo output stream - * \param ptr pointer data - * \param data data content - */ - inline static void SaveBinary(utils::IStream &fo, // NOLINT(*) - const std::vector &ptr, - const std::vector &data) { - size_t nrow = ptr.size() - 1; - fo.Write(&nrow, sizeof(size_t)); - fo.Write(BeginPtr(ptr), ptr.size() * sizeof(size_t)); - if (data.size() != 0) { - fo.Write(BeginPtr(data), data.size() * sizeof(RowBatch::Entry)); - } - } - /*! 
- * \brief load data from binary stream - * \param fi input stream - * \param out_ptr pointer data - * \param out_data data content - */ - inline static void LoadBinary(utils::IStream &fi, // NOLINT(*) - std::vector *out_ptr, - std::vector *out_data) { - size_t nrow; - utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format"); - out_ptr->resize(nrow + 1); - utils::Check(fi.Read(BeginPtr(*out_ptr), out_ptr->size() * sizeof(size_t)) != 0, - "invalid input file format"); - out_data->resize(out_ptr->back()); - if (out_data->size() != 0) { - utils::Assert(fi.Read(BeginPtr(*out_data), out_data->size() * sizeof(RowBatch::Entry)) != 0, - "invalid input file format"); - } - } - // one batch iterator that return content in the matrix - struct OneBatchIter: utils::IIterator { - explicit OneBatchIter(DMatrixSimple *parent) - : at_first_(true), parent_(parent) {} - virtual ~OneBatchIter(void) {} - virtual void BeforeFirst(void) { - at_first_ = true; - } - virtual bool Next(void) { - if (!at_first_) return false; - at_first_ = false; - batch_.size = parent_->row_ptr_.size() - 1; - batch_.base_rowid = 0; - batch_.ind_ptr = BeginPtr(parent_->row_ptr_); - batch_.data_ptr = BeginPtr(parent_->row_data_); - return true; - } - virtual const RowBatch &Value(void) const { - return batch_; - } - - private: - // whether is at first - bool at_first_; - // pointer to parent - DMatrixSimple *parent_; - // temporal space for batch - RowBatch batch_; - }; -}; -} // namespace io -} // namespace xgboost -#endif // namespace XGBOOST_IO_SIMPLE_DMATRIX_INL_HPP_ diff --git a/src/data/data.cc b/src/data/data.cc index 290ef523b..d020ff0de 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -1,9 +1,14 @@ +/*! 
+ * Copyright 2015 by Contributors + * \file data.cc + */ +#include #include namespace xgboost { // implementation of inline functions void MetaInfo::Clear() { - num_row = num_col = 0; + num_row = num_col = num_nonzero = 0; labels.clear(); root_index.clear(); group_ptr.clear(); @@ -16,6 +21,7 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { fo->Write(&version, sizeof(version)); fo->Write(&num_row, sizeof(num_row)); fo->Write(&num_col, sizeof(num_col)); + fo->Write(&num_nonzero, sizeof(num_nonzero)); fo->Write(labels); fo->Write(group_ptr); fo->Write(weights); @@ -25,14 +31,55 @@ void MetaInfo::SaveBinary(dmlc::Stream *fo) const { void MetaInfo::LoadBinary(dmlc::Stream *fi) { int version; - CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid format"; + CHECK(fi->Read(&version, sizeof(version)) == sizeof(version)) << "MetaInfo: invalid version"; CHECK_EQ(version, kVersion) << "MetaInfo: invalid format"; CHECK(fi->Read(&num_row, sizeof(num_row)) == sizeof(num_row)) << "MetaInfo: invalid format"; CHECK(fi->Read(&num_col, sizeof(num_col)) == sizeof(num_col)) << "MetaInfo: invalid format"; + CHECK(fi->Read(&num_nonzero, sizeof(num_nonzero)) == sizeof(num_nonzero)) << "MetaInfo: invalid format"; CHECK(fi->Read(&labels)) << "MetaInfo: invalid format"; CHECK(fi->Read(&group_ptr)) << "MetaInfo: invalid format"; CHECK(fi->Read(&weights)) << "MetaInfo: invalid format"; CHECK(fi->Read(&root_index)) << "MetaInfo: invalid format"; CHECK(fi->Read(&base_margin)) << "MetaInfo: invalid format"; } + +// macro to dispatch according to specified pointer types +#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ + switch(dtype) { \ + case kFloat32: { \ + const float* cast_ptr = reinterpret_cast(old_ptr); proc; break; \ + } \ + case kDouble: { \ + const double* cast_ptr = reinterpret_cast(old_ptr); proc; break; \ + } \ + case kUInt32: { \ + const uint32_t* cast_ptr = reinterpret_cast(old_ptr); proc; break; \ + } \ + case kUInt64: { \ + const 
uint64_t* cast_ptr = reinterpret_cast(old_ptr); proc; break; \ + } \ + default: LOG(FATAL) << "Unknown data type" << dtype; \ + } \ + + +void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { + if (!std::strcmp(key, "root_index")) { + root_index.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, root_index.begin())); + } else if (!std::strcmp(key, "label")) { + labels.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, labels.begin())); + } else if (!std::strcmp(key, "weight")) { + weights.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, weights.begin())); + } else if (!std::strcmp(key, "base_margin")) { + base_margin.resize(num); + DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, + std::copy(cast_dptr, cast_dptr + num, base_margin.begin())); + } +} + } // namespace xgboost diff --git a/src/data/simple_csr_source.cc b/src/data/simple_csr_source.cc new file mode 100644 index 000000000..daad2c21d --- /dev/null +++ b/src/data/simple_csr_source.cc @@ -0,0 +1,101 @@ +/*! 
+ * Copyright 2015 by Contributors + * \file simple_csr_source.cc + */ +#include +#include +#include "./simple_csr_source.h" + +namespace xgboost { +namespace data { + +void SimpleCSRSource::Clear() { + row_data_.clear(); + row_ptr_.resize(1); + row_ptr_[0] = 0; + this->info.Clear(); +} + +void SimpleCSRSource::CopyFrom(DMatrix* src) { + this->Clear(); + this->info = src->info(); + dmlc::DataIter* iter = src->RowIterator(); + iter->BeforeFirst(); + while (iter->Next()) { + const RowBatch &batch = iter->Value(); + for (size_t i = 0; i < batch.size; ++i) { + RowBatch::Inst inst = batch[i]; + row_data_.insert(row_data_.end(), inst.data, inst.data + inst.length); + row_ptr_.push_back(row_ptr_.back() + inst.length); + } + } +} + +void SimpleCSRSource::CopyFrom(dmlc::Parser* parser) { + this->Clear(); + while (parser->Next()) { + const dmlc::RowBlock& batch = parser->Value(); + if (batch.label != nullptr) { + info.labels.insert(info.labels.end(), batch.label, batch.label + batch.size); + } + if (batch.weight != nullptr) { + info.weights.insert(info.weights.end(), batch.weight, batch.weight + batch.size); + } + row_data_.reserve(row_data_.size() + batch.offset[batch.size] - batch.offset[0]); + CHECK(batch.index != nullptr); + // update information + this->info.num_row += batch.size; + // copy the data over + for (size_t i = batch.offset[0]; i < batch.offset[batch.size]; ++i) { + uint32_t index = batch.index[i]; + bst_float fvalue = batch.value == nullptr ? 
1.0f : batch.value[i]; + row_data_.push_back(SparseBatch::Entry(index, fvalue)); + this->info.num_col = std::max(this->info.num_col, + static_cast(index + 1)); + } + size_t top = row_ptr_.size(); + row_ptr_.resize(top + batch.size); + for (size_t i = 0; i < batch.size; ++i) { + row_ptr_[top + i] = row_ptr_[top - 1] + batch.offset[i + 1] - batch.offset[0]; + } + } + this->info.num_nonzero = static_cast(row_data_.size()); +} + +void SimpleCSRSource::LoadBinary(dmlc::Stream* fi) { + int tmagic; + CHECK(fi->Read(&tmagic, sizeof(tmagic)) == sizeof(tmagic)) << "invalid input file format"; + CHECK_EQ(tmagic, kMagic) << "invalid format, magic number mismatch"; + info.LoadBinary(fi); + CHECK(fi->Read(&row_ptr_)) << "invalid input file format"; + CHECK(fi->Read(&row_data_)) << "invalid input file format"; +} + +void SimpleCSRSource::SaveBinary(dmlc::Stream* fo) const { + int tmagic = kMagic; + fo->Write(&tmagic, sizeof(tmagic)); + info.SaveBinary(fo); + fo->Write(row_ptr_); + fo->Write(row_data_); +} + +void SimpleCSRSource::BeforeFirst() { + at_first_ = true;  // rewind: the next call to Next() must yield the single batch again +} + +bool SimpleCSRSource::Next() { + if (!at_first_) return false; + at_first_ = false; + batch_.size = row_ptr_.size() - 1; + batch_.base_rowid = 0; + batch_.ind_ptr = dmlc::BeginPtr(row_ptr_); + batch_.data_ptr = dmlc::BeginPtr(row_data_); + return true; +} + +const RowBatch& SimpleCSRSource::Value() const { + return batch_; +} + +} // namespace data +} // namespace xgboost diff --git a/src/data/simple_csr_source.h b/src/data/simple_csr_source.h new file mode 100644 index 000000000..3832ba852 --- /dev/null +++ b/src/data/simple_csr_source.h @@ -0,0 +1,81 @@ +/*! + * Copyright 2015 by Contributors + * \file simple_csr_source.h + * \brief The simplest form of data source, can be used to create DMatrix. + * This is an in-memory data structure that holds the data in row oriented format. + * \author Tianqi Chen + */ +#ifndef XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_ +#define XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_ + +#include +#include +#include +#include + +namespace xgboost { +/*! 
\brief namespace of internal data structures*/ +namespace data { +/*! + * \brief The simplest form of data holder, can be used to create DMatrix. + * This is an in-memory data structure that holds the data in row oriented format. + * \code + * std::unique_ptr source(new SimpleCSRSource()); + * // add data to source + * DMatrix* dmat = DMatrix::Create(std::move(source)); + * \endcode + */ +class SimpleCSRSource : public DataSource { + public: + // public data members + // MetaInfo info; // inherited from DataSource + /*! \brief row pointer of CSR sparse storage */ + std::vector row_ptr_; + /*! \brief data in the CSR sparse storage */ + std::vector row_data_; + // functions + /*! \brief default constructor */ + SimpleCSRSource() : row_ptr_(1, 0), at_first_(true) {} + /*! \brief destructor */ + virtual ~SimpleCSRSource() {} + /*! \brief clear the data structure */ + void Clear(); + /*! + * \brief copy the meta info and all row batches from src + * \param src source DMatrix, read through its row iterator. + */ + void CopyFrom(DMatrix* src); + /*! + * \brief copy content of data from parser, also set the additional information. + * \param src the parser to read row blocks from. + * \note labels and weights supplied by the parser are appended to info. + */ + void CopyFrom(dmlc::Parser* src); + /*! + * \brief Load data from binary stream. + * \param fi the pointer to load data from. + */ + void LoadBinary(dmlc::Stream* fi); + /*! + * \brief Save data into binary stream + * \param fo The output stream. + */ + void SaveBinary(dmlc::Stream* fo) const; + // implement Next + bool Next() override; + // implement BeforeFirst + void BeforeFirst() override; + // implement Value + const RowBatch &Value() const override; + /*! \brief magic number used to identify SimpleCSRSource */ + static const int kMagic = 0xffffab01; + + private: + /*! \brief internal variable, used to support iterator interface */ + bool at_first_; + /*! \brief the single row batch returned by Value(), filled in by Next() */ + RowBatch batch_; +}; +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_SIMPLE_CSR_ROW_ITER_H_