From 7874c2559b5df196c51233419aeec781c3ee5687 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 24 Aug 2014 17:25:17 -0700 Subject: [PATCH] add changes --- src/data.h | 7 +++++ src/gbm/gbtree-inl.hpp | 2 +- src/io/simple_dmatrix-inl.hpp | 14 +++++----- src/learner/dmatrix.h | 49 +++++++++++++++++++++++------------ src/learner/learner-inl.hpp | 8 +++--- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/src/data.h b/src/data.h index 6f8297311..f28bec056 100644 --- a/src/data.h +++ b/src/data.h @@ -44,6 +44,10 @@ struct bst_gpair { * these information are not necessarily presented, and can be empty */ struct BoosterInfo { + /*! \brief number of rows in the data */ + size_t num_row; + /*! \brief number of columns in the data */ + size_t num_col; /*! * \brief specified root index of each instance, * can be used for multi task setting @@ -51,6 +55,9 @@ struct BoosterInfo { std::vector root_index; /*! \brief set fold indicator */ std::vector fold_index; + /*! \brief number of rows, number of columns */ + BoosterInfo(void) : num_row(0), num_col(0) { + } /*! \brief get root of ith instance */ inline unsigned GetRoot(size_t i) const { return root_index.size() == 0 ? 0 : root_index[i]; diff --git a/src/gbm/gbtree-inl.hpp b/src/gbm/gbtree-inl.hpp index 0e001a4e8..3d5ca4492 100644 --- a/src/gbm/gbtree-inl.hpp +++ b/src/gbm/gbtree-inl.hpp @@ -135,7 +135,7 @@ class GBTree : public IGradBooster { const int tid = omp_get_thread_num(); tree::RegTree::FVec &feats = thread_temp[tid]; const size_t ridx = batch.base_rowid + i; - const unsigned root_idx = info.GetRoot(i); + const unsigned root_idx = info.GetRoot(ridx); // loop over output groups for (int gid = 0; gid < mparam.num_output_group; ++gid) { preds[ridx * mparam.num_output_group + gid] = diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index c0b98b789..6ceeb3714 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -62,10 +62,10 @@ class DMatrixSimple : public DataMatrix { inline size_t AddRow(const std::vector &feats) { for (size_t i = 0; i < feats.size(); ++i) { row_data_.push_back(feats[i]); - info.num_col = std::max(info.num_col, static_cast(feats[i].findex+1)); + info.info.num_col = std::max(info.info.num_col, static_cast(feats[i].findex+1)); } row_ptr_.push_back(row_ptr_.back() + feats.size()); - info.num_row += 1; + info.info.num_row += 1; return row_ptr_.size() - 2; } /*! @@ -99,19 +99,19 @@ class DMatrixSimple : public DataMatrix { if (!silent) { printf("%lux%lu matrix with %lu entries is loaded from %s\n", - info.num_row, info.num_col, row_data_.size(), fname); + info.num_row(), info.num_col(), row_data_.size(), fname); } fclose(file); // try to load in additional file std::string name = fname; std::string gname = name + ".group"; if (info.TryLoadGroup(gname.c_str(), silent)) { - utils::Check(info.group_ptr.back() == info.num_row, + utils::Check(info.group_ptr.back() == info.num_row(), "DMatrix: group data does not match the number of rows in features"); } std::string wname = name + ".weight"; if (info.TryLoadFloatInfo("weight", wname.c_str(), silent)) { - utils::Check(info.weights.size() == info.num_row, + utils::Check(info.weights.size() == info.num_row(), "DMatrix: weight data does not match the number of rows in features"); } std::string mname = name + ".base_margin"; @@ -139,7 +139,7 @@ class DMatrixSimple : public DataMatrix { if (!silent) { printf("%lux%lu matrix with %lu entries is loaded from %s\n", - info.num_row, info.num_col, row_data_.size(), fname); + info.num_row(), info.num_col(), row_data_.size(), fname); if (info.group_ptr.size() != 0) { printf("data contains %u groups\n", (unsigned)info.group_ptr.size()-1); } @@ -163,7 +163,7 @@ class DMatrixSimple : public DataMatrix { if (!silent) { printf("%lux%lu matrix with %lu entries is saved to %s\n", - info.num_row, info.num_col, row_data_.size(), fname); + info.num_row(), info.num_col(), row_data_.size(), fname); if (info.group_ptr.size() != 0) { printf("data contains %lu groups\n", info.group_ptr.size()-1); } diff --git a/src/learner/dmatrix.h b/src/learner/dmatrix.h index b66cf86d0..c7ad52777 100644 --- a/src/learner/dmatrix.h +++ b/src/learner/dmatrix.h @@ -15,10 +15,12 @@ namespace learner { * \brief meta information needed in training, including label, weight */ struct MetaInfo { - /*! \brief number of rows in the data */ - size_t num_row; - /*! \brief number of columns in the data */ - size_t num_col; + /*! + * \brief information needed by booster + * BoosterInfo does not implement save and load, + * all serialization is done in MetaInfo + */ + BoosterInfo info; /*! \brief label of each instance */ std::vector labels; /*! @@ -28,8 +30,6 @@ struct MetaInfo { std::vector group_ptr; /*! \brief weights of each instance, optional */ std::vector weights; - /*! \brief information needed by booster */ - BoosterInfo info; /*! * \brief initialized margins, * if specified, xgboost will start from this init margin @@ -39,7 +39,15 @@ struct MetaInfo { /*! \brief version flag, used to check version of this info */ static const int kVersion = 0; // constructor - MetaInfo(void) : num_row(0), num_col(0) {} + MetaInfo(void) {} + /*! \return number of rows in dataset */ + inline size_t num_row(void) const { + return info.num_row; + } + /*! \return number of columns in dataset */ + inline size_t num_col(void) const { + return info.num_col; + } /*! \brief clear all the information */ inline void Clear(void) { labels.clear(); @@ -47,7 +55,7 @@ struct MetaInfo { weights.clear(); info.root_index.clear(); base_margin.clear(); - num_row = num_col = 0; + info.num_row = info.num_col = 0; } /*! \brief get weight of each instances */ inline float GetWeight(size_t i) const { @@ -60,8 +68,8 @@ struct MetaInfo { inline void SaveBinary(utils::IStream &fo) const { int version = kVersion; fo.Write(&version, sizeof(version)); - fo.Write(&num_row, sizeof(num_row)); - fo.Write(&num_col, sizeof(num_col)); + fo.Write(&info.num_row, sizeof(info.num_row)); + fo.Write(&info.num_col, sizeof(info.num_col)); fo.Write(labels); fo.Write(group_ptr); fo.Write(weights); @@ -71,8 +79,8 @@ struct MetaInfo { inline void LoadBinary(utils::IStream &fi) { int version; utils::Check(fi.Read(&version, sizeof(version)), "MetaInfo: invalid format"); - utils::Check(fi.Read(&num_row, sizeof(num_row)), "MetaInfo: invalid format"); - utils::Check(fi.Read(&num_col, sizeof(num_col)), "MetaInfo: invalid format"); + utils::Check(fi.Read(&info.num_row, sizeof(info.num_row)), "MetaInfo: invalid format"); + utils::Check(fi.Read(&info.num_col, sizeof(info.num_col)), "MetaInfo: invalid format"); utils::Check(fi.Read(&labels), "MetaInfo: invalid format"); utils::Check(fi.Read(&group_ptr), "MetaInfo: invalid format"); utils::Check(fi.Read(&weights), "MetaInfo: invalid format"); @@ -94,19 +102,28 @@ struct MetaInfo { fclose(fi); return true; } - inline std::vector& GetInfo(const char *field) { + inline std::vector& GetFloatInfo(const char *field) { if (!strcmp(field, "label")) return labels; if (!strcmp(field, "weight")) return weights; if (!strcmp(field, "base_margin")) return base_margin; utils::Error("unknown field %s", field); return labels; } - inline const std::vector& GetInfo(const char *field) const { - return ((MetaInfo*)this)->GetInfo(field); + inline const std::vector& GetFloatInfo(const char *field) const { + return ((MetaInfo*)this)->GetFloatInfo(field); + } + inline std::vector &GetUIntInfo(const char *field) { + if (!strcmp(field, "root_index")) return info.root_index; + if (!strcmp(field, "fold_index")) return info.fold_index; + utils::Error("unknown field %s", field); + return info.root_index; + } + inline const std::vector &GetUIntInfo(const char *field) const { + return ((MetaInfo*)this)->GetUIntInfo(field); } // try to load weight information from file, if exists inline bool TryLoadFloatInfo(const char *field, const char* fname, bool silent = false) { - std::vector &weights = this->GetInfo(field); + std::vector &weights = this->GetFloatInfo(field); FILE *fi = fopen64(fname, "r"); if (fi == NULL) return false; float wt; diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 9150b5379..0f38febdc 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -58,9 +58,9 @@ class BoostLearner { if (dupilicate) continue; // set mats[i]'s cache learner pointer to this mats[i]->cache_learner_ptr_ = this; - cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row)); - buffer_size += mats[i]->info.num_row; - num_feature = std::max(num_feature, static_cast(mats[i]->info.num_col)); + cache_.push_back(CacheEntry(mats[i], buffer_size, mats[i]->info.num_row())); + buffer_size += mats[i]->info.num_row(); + num_feature = std::max(num_feature, static_cast(mats[i]->info.num_col())); } char str_temp[25]; if (num_feature > mparam.num_feature) { @@ -329,7 +329,7 @@ class BoostLearner { inline int64_t FindBufferOffset(const DMatrix &mat) const { for (size_t i = 0; i < cache_.size(); ++i) { if (cache_[i].mat_ == &mat && mat.cache_learner_ptr_ == this) { - if (cache_[i].num_row_ == mat.info.num_row) { + if (cache_[i].num_row_ == mat.info.num_row()) { return cache_[i].buffer_offset_; } }