diff --git a/src/data.h b/src/data.h
index fa815a4ee..fea3f7a48 100644
--- a/src/data.h
+++ b/src/data.h
@@ -14,6 +14,7 @@
 #include "utils/io.h"
 #include "utils/utils.h"
 #include "utils/iterator.h"
+#include "utils/random.h"
 #include "utils/matrix_csr.h"
 
 namespace xgboost {
@@ -184,7 +185,6 @@ class FMatrixS : public FMatrixInterface{
   /*! \brief constructor */
   FMatrixS(void) {
     iter_ = NULL;
-    num_buffered_row_ = 0;
   }
   // destructor
   ~FMatrixS(void) {
@@ -199,11 +199,15 @@ class FMatrixS : public FMatrixInterface{
     utils::Check(this->HaveColAccess(), "NumCol:need column access");
     return col_ptr_.size() - 1;
   }
+  /*! \brief get the list of row indices that are buffered */
+  inline const std::vector<bst_uint> &buffered_rowset(void) const {
+    return buffered_rowset_;
+  }
   /*! \brief get col sorted iterator */
   inline ColIter GetSortedCol(size_t cidx) const {
     utils::Assert(cidx < this->NumCol(), "col id exceed bound");
-    return ColIter(&col_data_[col_ptr_[cidx]] - 1,
-                   &col_data_[col_ptr_[cidx + 1]] - 1);
+    return ColIter(&col_data_[0] + col_ptr_[cidx] - 1,
+                   &col_data_[0] + col_ptr_[cidx + 1] - 1);
   }
   /*!
    * \brief get reversed col iterator,
@@ -211,8 +215,8 @@ class FMatrixS : public FMatrixInterface{
    */
   inline ColBackIter GetReverseSortedCol(size_t cidx) const {
     utils::Assert(cidx < this->NumCol(), "col id exceed bound");
-    return ColBackIter(&col_data_[col_ptr_[cidx + 1]],
-                       &col_data_[col_ptr_[cidx]]);
+    return ColBackIter(&col_data_[0] + col_ptr_[cidx + 1],
+                       &col_data_[0] + col_ptr_[cidx]);
   }
   /*! \brief get col size */
   inline size_t GetColSize(size_t cidx) const {
@@ -220,12 +224,12 @@ class FMatrixS : public FMatrixInterface{
   }
   /*! \brief get column density */
   inline float GetColDensity(size_t cidx) const {
-    size_t nmiss = num_buffered_row_ - (col_ptr_[cidx+1] - col_ptr_[cidx]);
-    return 1.0f - (static_cast<float>(nmiss)) / num_buffered_row_;
+    size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
+    return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
   }
-  inline void InitColAccess(size_t max_nrow = ULONG_MAX) {
+  inline void InitColAccess(float pkeep = 1.0f) {
     if (this->HaveColAccess()) return;
-    this->InitColData(max_nrow);
+    this->InitColData(pkeep);
   }
   /*!
    * \brief get the row iterator associated with FMatrix
@@ -244,8 +248,8 @@ class FMatrixS : public FMatrixInterface{
    * \param fo output stream to save to
    */
   inline void SaveColAccess(utils::IStream &fo) const {
-    fo.Write(&num_buffered_row_, sizeof(num_buffered_row_));
-    if (num_buffered_row_ != 0) {
+    fo.Write(buffered_rowset_);
+    if (buffered_rowset_.size() != 0) {
       SaveBinary(fo, col_ptr_, col_data_);
     }
   }
@@ -254,9 +258,8 @@ class FMatrixS : public FMatrixInterface{
    * \param fi input stream to load from
    */
   inline void LoadColAccess(utils::IStream &fi) {
-    utils::Check(fi.Read(&num_buffered_row_, sizeof(num_buffered_row_)) != 0,
-                 "invalid input file format");
-    if (num_buffered_row_ != 0) {
+    utils::Check(fi.Read(&buffered_rowset_), "invalid input file format");
+    if (buffered_rowset_.size() != 0) {
       LoadBinary(fi, &col_ptr_, &col_data_);
     }
   }
@@ -300,39 +303,43 @@ class FMatrixS : public FMatrixInterface{
 protected:
   /*!
    * \brief initialize column data
-   * \param max_nrow maximum number of rows supported
+   * \param pkeep probability to keep a row
    */
-  inline void InitColData(size_t max_nrow) {
+  inline void InitColData(float pkeep) {
+    buffered_rowset_.clear();
     // note: this part of code is serial, todo, parallelize this transformer
     utils::SparseCSRMBuilder<Entry> builder(col_ptr_, col_data_);
     builder.InitBudget(0);
     // start working
     iter_->BeforeFirst();
-    num_buffered_row_ = 0;
     while (iter_->Next()) {
       const SparseBatch &batch = iter_->Value();
-      if (batch.base_rowid >= max_nrow) break;
-      const size_t nbatch = std::min(batch.size, max_nrow - batch.base_rowid);
-      for (size_t i = 0; i < nbatch; ++i, ++num_buffered_row_) {
-        SparseBatch::Inst inst = batch[i];
-        for (bst_uint j = 0; j < inst.length; ++j) {
-          builder.AddBudget(inst[j].findex);
+      for (size_t i = 0; i < batch.size; ++i) {
+        if (pkeep == 1.0f || random::SampleBinary(pkeep)) {
+          buffered_rowset_.push_back(batch.base_rowid + i);
+          SparseBatch::Inst inst = batch[i];
+          for (bst_uint j = 0; j < inst.length; ++j) {
+            builder.AddBudget(inst[j].findex);
+          }
         }
       }
     }
     builder.InitStorage();
     iter_->BeforeFirst();
+    size_t ktop = 0;
     while (iter_->Next()) {
       const SparseBatch &batch = iter_->Value();
-      if (batch.base_rowid >= max_nrow) break;
-      const size_t nbatch = std::min(batch.size, max_nrow - batch.base_rowid);
-      for (size_t i = 0; i < nbatch; ++i) {
-        SparseBatch::Inst inst = batch[i];
-        for (bst_uint j = 0; j < inst.length; ++j) {
-          builder.PushElem(inst[j].findex,
-                           Entry((bst_uint)(batch.base_rowid+i),
-                                 inst[j].fvalue));
+      for (size_t i = 0; i < batch.size; ++i) {
+        if (ktop < buffered_rowset_.size() &&
+            buffered_rowset_[ktop] == batch.base_rowid + i) {
+          ++ktop;
+          SparseBatch::Inst inst = batch[i];
+          for (bst_uint j = 0; j < inst.length; ++j) {
+            builder.PushElem(inst[j].findex,
+                             Entry((bst_uint)(batch.base_rowid + i),
+                                   inst[j].fvalue));
+          }
         }
       }
     }
@@ -349,8 +356,8 @@ class FMatrixS : public FMatrixInterface{
 private:
   // --- data structure used to support InitColAccess --
   utils::IIterator<SparseBatch> *iter_;
-  /*! \brief number */
-  size_t num_buffered_row_;
+  /*! \brief list of row indices that are buffered for column access */
+  std::vector<bst_uint> buffered_rowset_;
   /*! \brief column pointer of CSC format */
   std::vector<size_t> col_ptr_;
   /*! \brief column data in CSC format */
   std::vector<Entry> col_data_;
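The substantive change in this file: `InitColData` no longer truncates the column buffer at `max_nrow`. It keeps each row with probability `pkeep`, records the kept indices (in ascending order) in `buffered_rowset_`, and replays exactly those rows in the second pass by advancing the `ktop` cursor through the rowset. Below is a minimal self-contained sketch of that two-pass scheme; `Cell`, `sample_binary`, and `buffer_rows` are illustrative stand-ins, not xgboost's actual types.

```cpp
// Two-pass row sampling, mirroring InitColData: pass 1 draws the rowset,
// pass 2 replays the stream and copies only the sampled rows.
#include <cstdlib>
#include <vector>

struct Cell { unsigned row, col; float value; };

// Bernoulli draw with probability p (stand-in for random::SampleBinary).
static bool sample_binary(float p) {
  return static_cast<float>(std::rand()) / RAND_MAX < p;
}

std::vector<Cell> buffer_rows(const std::vector<std::vector<Cell> > &rows,
                              float pkeep,
                              std::vector<unsigned> *buffered_rowset) {
  // pass 1: decide which rows to keep; indices come out in ascending order
  buffered_rowset->clear();
  for (unsigned r = 0; r < rows.size(); ++r) {
    if (pkeep == 1.0f || sample_binary(pkeep)) buffered_rowset->push_back(r);
  }
  // pass 2: walk the stream again, advancing a cursor through the rowset
  std::vector<Cell> out;
  size_t ktop = 0;
  for (unsigned r = 0; r < rows.size(); ++r) {
    if (ktop < buffered_rowset->size() && (*buffered_rowset)[ktop] == r) {
      ++ktop;
      out.insert(out.end(), rows[r].begin(), rows[r].end());
    }
  }
  return out;
}
```

The cursor is what lets the real code work over a streaming `iter_` that can only be replayed from the start: membership in the sampled set is tested in O(1) per row, with no hash set needed.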
diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp
index 54e17301f..717a7b1a8 100644
--- a/src/learner/learner-inl.hpp
+++ b/src/learner/learner-inl.hpp
@@ -9,6 +9,7 @@
 #include 
 #include 
 #include 
+#include 
 #include "./objective.h"
 #include "./evaluation.h"
 #include "../gbm/gbm.h"
@@ -28,6 +29,8 @@ class BoostLearner {
     gbm_ = NULL;
     name_obj_ = "reg:linear";
     name_gbm_ = "gbtree";
+    silent = 0;
+    prob_buffer_row = 1.0f;
   }
   ~BoostLearner(void) {
     if (obj_ != NULL) delete obj_;
@@ -77,6 +80,7 @@ class BoostLearner {
    */
   inline void SetParam(const char *name, const char *val) {
     if (!strcmp(name, "silent")) silent = atoi(val);
+    if (!strcmp(name, "prob_buffer_row")) prob_buffer_row = static_cast<float>(atof(val));
     if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
     if (!strcmp("seed", name)) random::Seed(atoi(val));
     if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
@@ -90,7 +94,9 @@ class BoostLearner {
     }
     if (gbm_ != NULL) gbm_->SetParam(name, val);
     if (obj_ != NULL) obj_->SetParam(name, val);
-    cfg_.push_back(std::make_pair(std::string(name), std::string(val)));
+    if (gbm_ == NULL || obj_ == NULL) {
+      cfg_.push_back(std::make_pair(std::string(name), std::string(val)));
+    }
   }
   /*!
    * \brief initialize the model
@@ -147,8 +153,8 @@ class BoostLearner {
    * if not, initialize it
    * \param p_train pointer to the matrix used by training
    */
-  inline void CheckInit(DMatrix *p_train) const {
-    p_train->fmat.InitColAccess();
+  inline void CheckInit(DMatrix *p_train) {
+    p_train->fmat.InitColAccess(prob_buffer_row);
   }
   /*!
    * \brief update the model for one iteration
@@ -289,6 +295,8 @@ class BoostLearner {
   // data fields
   // silent during training
   int silent;
+  // probability of keeping a row when buffering it for column access
+  float prob_buffer_row;
   // evaluation set
   EvalSet evaluator_;
   // model parameter
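Reviewer note on parameter flow: `prob_buffer_row` defaults to `1.0f`, so existing configurations keep buffering every row; `CheckInit` (which had to drop its `const`) simply forwards the value to `FMatrixS::InitColAccess`, which is a no-op once column access exists. A self-contained illustration of that flow with stand-in types, not the real `BoostLearner`/`DMatrix`:

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>

struct FakeFMatrix {
  bool has_col_access;
  FakeFMatrix() : has_col_access(false) {}
  void InitColAccess(float pkeep) {
    if (has_col_access) return;  // idempotent, like FMatrixS
    has_col_access = true;
    std::printf("buffering rows with pkeep = %g\n", pkeep);
  }
};

struct FakeLearner {
  float prob_buffer_row;
  FakeLearner() : prob_buffer_row(1.0f) {}  // default: buffer everything
  void SetParam(const char *name, const char *val) {
    if (!std::strcmp(name, "prob_buffer_row"))
      prob_buffer_row = static_cast<float>(std::atof(val));
  }
  void CheckInit(FakeFMatrix *fmat) { fmat->InitColAccess(prob_buffer_row); }
};

int main() {
  FakeLearner learner;
  learner.SetParam("prob_buffer_row", "0.5");  // keep roughly half the rows
  FakeFMatrix train;
  learner.CheckInit(&train);  // prints: buffering rows with pkeep = 0.5
  return 0;
}
```

The related `SetParam` change (appending to `cfg_` only while `gbm_`/`obj_` are still NULL) stops the configuration replay log from growing with every post-initialization call.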
diff --git a/src/learner/objective-inl.hpp b/src/learner/objective-inl.hpp
index e45250950..41af8b605 100644
--- a/src/learner/objective-inl.hpp
+++ b/src/learner/objective-inl.hpp
@@ -105,19 +105,22 @@ class RegLossObj : public IObjFunction{
       scale_pos_weight = static_cast<float>(atof(val));
     }
   }
-  virtual void GetGradient(const std::vector<float>& preds,
+  virtual void GetGradient(const std::vector<float> &preds,
                            const MetaInfo &info,
                            int iter,
                            std::vector<bst_gpair> *out_gpair) {
-    utils::Check(preds.size() == info.labels.size(),
+    utils::Check(info.labels.size() != 0, "label set cannot be empty");
+    utils::Check(preds.size() % info.labels.size() == 0,
                  "labels are not correctly provided");
     std::vector<bst_gpair> &gpair = *out_gpair;
     gpair.resize(preds.size());
     // start calculating gradient
+    const unsigned nstep = static_cast<unsigned>(info.labels.size());
     const unsigned ndata = static_cast<unsigned>(preds.size());
     #pragma omp parallel for schedule(static)
-    for (unsigned j = 0; j < ndata; ++j) {
-      float p = loss.PredTransform(preds[j]);
+    for (unsigned i = 0; i < ndata; ++i) {
+      const unsigned j = i % nstep;
+      float p = loss.PredTransform(preds[i]);
       float w = info.GetWeight(j);
       if (info.labels[j] == 1.0f) w *= scale_pos_weight;
-      gpair[j] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
+      gpair[i] = bst_gpair(loss.FirstOrderGradient(p, info.labels[j]) * w,
@@ -155,25 +158,28 @@ class SoftmaxMultiClassObj : public IObjFunction {
   virtual void SetParam(const char *name, const char *val) {
     if (!strcmp( "num_class", name )) nclass = atoi(val);
   }
-  virtual void GetGradient(const std::vector<float>& preds,
+  virtual void GetGradient(const std::vector<float> &preds,
                            const MetaInfo &info,
                            int iter,
                            std::vector<bst_gpair> *out_gpair) {
     utils::Check(nclass != 0, "must set num_class to use softmax");
-    utils::Check(preds.size() == static_cast<size_t>(nclass) * info.labels.size(),
+    utils::Check(info.labels.size() != 0, "label set cannot be empty");
+    utils::Check(preds.size() % (static_cast<size_t>(nclass) * info.labels.size()) == 0,
                  "SoftmaxMultiClassObj: label size and pred size does not match");
     std::vector<bst_gpair> &gpair = *out_gpair;
     gpair.resize(preds.size());
-    const unsigned ndata = static_cast<unsigned>(info.labels.size());
+    const unsigned nstep = static_cast<unsigned>(info.labels.size() * nclass);
+    const unsigned ndata = static_cast<unsigned>(preds.size() / nclass);
     #pragma omp parallel
     {
       std::vector<float> rec(nclass);
       #pragma omp for schedule(static)
-      for (unsigned j = 0; j < ndata; ++j) {
+      for (unsigned i = 0; i < ndata; ++i) {
         for (int k = 0; k < nclass; ++k) {
-          rec[k] = preds[j * nclass + k];
+          rec[k] = preds[i * nclass + k];
         }
         Softmax(&rec);
+        const unsigned j = i % nstep;
         int label = static_cast<int>(info.labels[j]);
         utils::Check(label < nclass, "SoftmaxMultiClassObj: label exceed num_class");
         const float wt = info.GetWeight(j);
@@ -181,9 +187,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
           float p = rec[k];
           const float h = 2.0f * p * (1.0f - p) * wt;
           if (label == k) {
-            gpair[j * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
+            gpair[i * nclass + k] = bst_gpair((p - 1.0f) * wt, h);
           } else {
-            gpair[j * nclass + k] = bst_gpair(p* wt, h);
+            gpair[i * nclass + k] = bst_gpair(p * wt, h);
           }
         }
       }
@@ -203,7 +209,9 @@ class SoftmaxMultiClassObj : public IObjFunction {
   inline void Transform(std::vector<float> *io_preds, int prob) {
     utils::Check(nclass != 0, "must set num_class to use softmax");
     std::vector<float> &preds = *io_preds;
+    std::vector<float> tmp;
     const unsigned ndata = static_cast<unsigned>(preds.size()/nclass);
+    if (prob == 0) tmp.resize(ndata);
     #pragma omp parallel
     {
       std::vector<float> rec(nclass);
@@ -213,7 +221,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
           rec[k] = preds[j * nclass + k];
         }
         if (prob == 0) {
-          preds[j] = FindMaxIndex(rec);
+          tmp[j] = FindMaxIndex(rec);
         } else {
           Softmax(&rec);
           for (int k = 0; k < nclass; ++k) {
@@ -222,9 +230,7 @@ class SoftmaxMultiClassObj : public IObjFunction {
         }
       }
     }
-    if (prob == 0) {
-      preds.resize(ndata);
-    }
+    if (prob == 0) preds = tmp;
   }
   // data field
   int nclass;
@@ -245,17 +251,17 @@ class LambdaRankObj : public IObjFunction {
     if (!strcmp( "fix_list_weight", name)) fix_list_weight = static_cast<float>(atof(val));
     if (!strcmp( "num_pairsample", name)) num_pairsample = atoi(val);
   }
-  virtual void GetGradient(const std::vector<float>& preds,
+  virtual void GetGradient(const std::vector<float> &preds,
                            const MetaInfo &info,
                            int iter,
                            std::vector<bst_gpair> *out_gpair) {
-    utils::Assert(preds.size() == info.labels.size(), "label size predict size not match");
+    utils::Check(preds.size() == info.labels.size(), "label size predict size not match");
     std::vector<bst_gpair> &gpair = *out_gpair;
     gpair.resize(preds.size());
     // quick consistency check when group is not available
-    std::vector<unsigned> tgptr(2, 0); tgptr[1] = preds.size();
+    std::vector<unsigned> tgptr(2, 0); tgptr[1] = info.labels.size();
     const std::vector<unsigned> &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr;
-    utils::Check(gptr.size() != 0 && gptr.back() == preds.size(),
+    utils::Check(gptr.size() != 0 && gptr.back() == info.labels.size(),
                  "group structure not consistent with #rows");
     const unsigned ngroup = static_cast<unsigned>(gptr.size() - 1);
     #pragma omp parallel
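One subtlety in the `Transform` change: with `prob == 0`, the old code wrote the argmax into `preds[j]` inside the parallel loop, which can clobber entries that another thread is still reading as `preds[j * nclass + k]`. Staging results in `tmp` and assigning once after the parallel region removes that hazard. A condensed sketch of the fixed pattern, with the argmax inlined in place of `FindMaxIndex` (assumes an OpenMP build; without OpenMP the pragma is ignored and the code stays correct):

```cpp
#include <vector>

// Race-free argmax transform: results go to a separate buffer so no thread
// overwrites preds while others may still be reading it.
void ArgmaxTransform(std::vector<float> *io_preds, int nclass) {
  std::vector<float> &preds = *io_preds;
  const int ndata = static_cast<int>(preds.size() / nclass);
  std::vector<float> tmp(ndata);
  #pragma omp parallel for schedule(static)
  for (int j = 0; j < ndata; ++j) {
    int best = 0;  // argmax over the j-th row of class scores
    for (int k = 1; k < nclass; ++k) {
      if (preds[j * nclass + k] > preds[j * nclass + best]) best = k;
    }
    tmp[j] = static_cast<float>(best);
  }
  preds = tmp;  // single assignment after all threads are done
}
```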
diff --git a/src/learner/objective.h b/src/learner/objective.h
index 513219093..d741ba61f 100644
--- a/src/learner/objective.h
+++ b/src/learner/objective.h
@@ -27,7 +27,7 @@ class IObjFunction{
    * \param iter current iteration number
    * \param out_gpair output of get gradient, saves gradient and second order gradient in
    */
-  virtual void GetGradient(const std::vector<float>& preds,
+  virtual void GetGradient(const std::vector<float> &preds,
                            const MetaInfo &info,
                            int iter,
                            std::vector<bst_gpair> *out_gpair) = 0;
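The interface change itself is cosmetic (`&` placement), but the contract the objectives above now implement is worth spelling out: `preds` may contain several stacked prediction sets over the same labels (hence the `preds.size() % info.labels.size() == 0` checks), labels are indexed with `i % nstep`, and one gradient pair is written per prediction. A free-standing sketch of that contract using a squared-error loss; `GpairLike` and `SquaredErrorGradient` are hypothetical stand-ins, not part of the `IObjFunction` hierarchy:

```cpp
#include <cstddef>
#include <vector>

struct GpairLike {
  float grad, hess;
  GpairLike(float g, float h) : grad(g), hess(h) {}
};

// One gradient pair per prediction; labels are reused cyclically across
// stacked prediction sets, as in the patched GetGradient implementations.
void SquaredErrorGradient(const std::vector<float> &preds,
                          const std::vector<float> &labels,
                          std::vector<GpairLike> *out_gpair) {
  out_gpair->clear();
  out_gpair->reserve(preds.size());
  for (std::size_t i = 0; i < preds.size(); ++i) {
    const std::size_t j = i % labels.size();  // label index for preds[i]
    // squared error: grad = p - y, hess = 1
    out_gpair->push_back(GpairLike(preds[i] - labels[j], 1.0f));
  }
}
```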
tree"); + const std::vector &rowset = fmat.buffered_rowset(); {// setup position position.resize(gpair.size()); if (root_index.size() == 0) { - std::fill(position.begin(), position.end(), 0); + for (size_t i = 0; i < rowset.size(); ++i) { + position[rowset[i]] = 0; + } } else { - for (size_t i = 0; i < root_index.size(); ++i) { - position[i] = root_index[i]; - utils::Assert(root_index[i] < (unsigned)tree.param.num_roots, "root index exceed setting"); + for (size_t i = 0; i < rowset.size(); ++i) { + const bst_uint ridx = rowset[i]; + position[ridx] = root_index[ridx]; + utils::Assert(root_index[ridx] < (unsigned)tree.param.num_roots, "root index exceed setting"); } } // mark delete for the deleted datas - for (size_t i = 0; i < gpair.size(); ++i) { - if (gpair[i].hess < 0.0f) position[i] = -1; + for (size_t i = 0; i < rowset.size(); ++i) { + const bst_uint ridx = rowset[i]; + if (gpair[ridx].hess < 0.0f) position[ridx] = -1; } // mark subsample if (param.subsample < 1.0f) { - for (size_t i = 0; i < gpair.size(); ++i) { - if (gpair[i].hess < 0.0f) continue; - if (random::SampleBinary(param.subsample) == 0) position[i] = -1; + for (size_t i = 0; i < rowset.size(); ++i) { + const bst_uint ridx = rowset[i]; + if (gpair[ridx].hess < 0.0f) continue; + if (random::SampleBinary(param.subsample) == 0) position[ridx] = -1; } } } @@ -168,6 +174,7 @@ class ColMaker: public IUpdater { /*! \brief initialize the base_weight, root_gain, and NodeEntry for all the new nodes in qexpand */ inline void InitNewNode(const std::vector &qexpand, const std::vector &gpair, + const FMatrix &fmat, const RegTree &tree) { {// setup statistics space for each tree node for (size_t i = 0; i < stemp.size(); ++i) { @@ -175,13 +182,15 @@ class ColMaker: public IUpdater { } snode.resize(tree.param.num_nodes, NodeEntry()); } + const std::vector &rowset = fmat.buffered_rowset(); // setup position - const unsigned ndata = static_cast(position.size()); + const unsigned ndata = static_cast(rowset.size()); #pragma omp parallel for schedule(static) for (unsigned i = 0; i < ndata; ++i) { + const bst_uint ridx = rowset[i]; const int tid = omp_get_thread_num(); - if (position[i] < 0) continue; - stemp[tid][position[i]].stats.Add(gpair[i]); + if (position[ridx] < 0) continue; + stemp[tid][position[ridx]].stats.Add(gpair[ridx]); } // sum the per thread statistics together for (size_t j = 0; j < qexpand.size(); ++j) { @@ -271,7 +280,9 @@ class ColMaker: public IUpdater { } // start enumeration const unsigned nsize = static_cast(feat_set.size()); + #if defined(_OPENMP) const int batch_size = std::max(static_cast(nsize / this->nthread / 32), 1); + #endif #pragma omp parallel for schedule(dynamic, batch_size) for (unsigned i = 0; i < nsize; ++i) { const unsigned fid = feat_set[i]; @@ -301,17 +312,19 @@ class ColMaker: public IUpdater { } // reset position of each data points after split is created in the tree inline void ResetPosition(const std::vector &qexpand, const FMatrix &fmat, const RegTree &tree) { + const std::vector &rowset = fmat.buffered_rowset(); // step 1, set default direct nodes to default, and leaf nodes to -1 - const unsigned ndata = static_cast(position.size()); + const unsigned ndata = static_cast(rowset.size()); #pragma omp parallel for schedule(static) - for (unsigned i = 0; i < ndata; ++i) { - const int nid = position[i]; + for (unsigned i = 0; i < ndata; ++i) { + const bst_uint ridx = rowset[i]; + const int nid = position[ridx]; if (nid >= 0) { if (tree[nid].is_leaf()) { - position[i] = -1; + position[ridx] = -1; } else 
diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp
index 69f099e1d..12bbcf864 100644
--- a/src/tree/updater_refresh-inl.hpp
+++ b/src/tree/updater_refresh-inl.hpp
@@ -20,7 +20,6 @@ class TreeRefresher: public IUpdater {
   // set training parameter
   virtual void SetParam(const char *name, const char *val) {
     param.SetParam(name, val);
-    if (!strcmp(name, "silent")) silent = atoi(val);
   }
   // update the tree, do pruning
   virtual void Update(const std::vector<bst_gpair> &gpair,
@@ -127,8 +126,6 @@ class TreeRefresher: public IUpdater {
   }
   // number of thread in the data
   int nthread;
-  // shutup
-  int silent;
   // training parameter
   TrainParam param;
 };