From ce4d59ed69a57ee4a0c93f9a80b5a985595e3e06 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 28 Jan 2016 17:16:38 -0800 Subject: [PATCH] [TREE] Enable global proposal for faster speed --- src/data/data.cc | 28 +++---- src/tree/updater_histmaker.cc | 135 ++++++++++++++++++++++++++++++++-- 2 files changed, 141 insertions(+), 22 deletions(-) diff --git a/src/data/data.cc b/src/data/data.cc index 11a36ffdf..39fa260d1 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -138,23 +138,23 @@ DMatrix* DMatrix::Load(const std::string& uri, cache_file = uri.substr(dlm_pos + 1, uri.length()); fname = uri.substr(0, dlm_pos); CHECK_EQ(cache_file.find('#'), std::string::npos) - << "Only one `#` is allowed in file path for cache file specification."; + << "Only one `#` is allowed in file path for cache file specification."; if (load_row_split) { std::ostringstream os; std::vector cache_shards = common::Split(cache_file, ':'); for (size_t i = 0; i < cache_shards.size(); ++i) { - size_t pos = cache_shards[i].rfind('.'); - if (pos == std::string::npos) { - os << cache_shards[i] - << ".r" << rabit::GetRank() - << "-" << rabit::GetWorldSize(); - } else { - os << cache_shards[i].substr(0, pos) - << ".r" << rabit::GetRank() - << "-" << rabit::GetWorldSize() - << cache_shards[i].substr(pos, cache_shards[i].length()); - } - if (i + 1 != cache_shards.size()) os << ':'; + size_t pos = cache_shards[i].rfind('.'); + if (pos == std::string::npos) { + os << cache_shards[i] + << ".r" << rabit::GetRank() + << "-" << rabit::GetWorldSize(); + } else { + os << cache_shards[i].substr(0, pos) + << ".r" << rabit::GetRank() + << "-" << rabit::GetWorldSize() + << cache_shards[i].substr(pos, cache_shards[i].length()); + } + if (i + 1 != cache_shards.size()) os << ':'; } cache_file = os.str(); } @@ -172,7 +172,7 @@ DMatrix* DMatrix::Load(const std::string& uri, if (npart != 1) { LOG(CONSOLE) << "Load part of data " << partid - << " of " << npart << " parts"; + << " of " << npart << " parts"; } // legacy handling of binary data loading if (file_format == "auto" && !load_row_split) { diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 0f4b93c3b..6af7c8117 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -268,6 +268,10 @@ class HistMaker: public BaseMaker { template class CQHistMaker: public HistMaker { + public: + CQHistMaker() : cache_dmatrix_(nullptr) { + } + protected: struct HistEntry { typename HistMaker::HistUnit hist; @@ -290,9 +294,13 @@ class CQHistMaker: public HistMaker { */ inline void Add(bst_float fv, bst_gpair gstats) { - while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; - CHECK_NE(istart, hist.size); - hist.data[istart].Add(gstats); + if (fv < hist.cut[istart]) { + hist.data[istart].Add(gstats); + } else { + while (istart < hist.size && !(fv < hist.cut[istart])) ++istart; + CHECK_NE(istart, hist.size); + hist.data[istart].Add(gstats); + } } }; // sketch type used for this @@ -301,7 +309,10 @@ class CQHistMaker: public HistMaker { void InitWorkSet(DMatrix *p_fmat, const RegTree &tree, std::vector *p_fset) override { - feat_helper.InitByCol(p_fmat, tree); + if (p_fmat != cache_dmatrix_) { + feat_helper.InitByCol(p_fmat, tree); + cache_dmatrix_ = p_fmat; + } feat_helper.SampleCol(this->param.colsample_bytree, p_fset); } // code to create histogram @@ -342,6 +353,9 @@ class CQHistMaker: public HistMaker { } } } + // update node statistics. + this->GetNodeStats(gpair, *p_fmat, tree, + &thread_stats, &node_stats); for (size_t i = 0; i < this->qexpand.size(); ++i) { const int nid = this->qexpand[i]; const int wid = this->node2workindex[nid]; @@ -434,9 +448,6 @@ class CQHistMaker: public HistMaker { size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); } - // update node statistics. - this->GetNodeStats(gpair, *p_fmat, tree, - &thread_stats, &node_stats); // now we get the final result of sketch, setup the cut this->wspace.cut.clear(); this->wspace.rptr.clear(); @@ -475,7 +486,6 @@ class CQHistMaker: public HistMaker { (fset.size() + 1) * this->qexpand.size() + 1); } - private: inline void UpdateHistCol(const std::vector &gpair, const ColBatch::Inst &c, const MetaInfo &info, @@ -607,6 +617,8 @@ class CQHistMaker: public HistMaker { sbuilder[nid].Finalize(max_size); } } + // cached dmatrix where we initialized the feature on. + const DMatrix* cache_dmatrix_; // feature helper BaseMaker::FMetaHelper feat_helper; // temp space to map feature id to working index @@ -631,6 +643,107 @@ class CQHistMaker: public HistMaker { std::vector > sketchs; }; +// global proposal +template +class GlobalProposalHistMaker: public CQHistMaker { + protected: + void ResetPosAndPropose(const std::vector &gpair, + DMatrix *p_fmat, + const std::vector &fset, + const RegTree &tree) override { + if (this->qexpand.size() == 1 && !this->param.cache_global_proposal) { + cached_rptr_.clear(); + cached_cut_.clear(); + } + if (cached_rptr_.size() == 0) { + CHECK_EQ(this->qexpand.size(), 1); + CQHistMaker::ResetPosAndPropose(gpair, p_fmat, fset, tree); + cached_rptr_ = this->wspace.rptr; + cached_cut_ = this->wspace.cut; + } else { + this->wspace.cut.clear(); + this->wspace.rptr.clear(); + this->wspace.rptr.push_back(0); + for (size_t i = 0; i < this->qexpand.size(); ++i) { + for (size_t j = 0; j < cached_rptr_.size() - 1; ++j) { + this->wspace.rptr.push_back( + this->wspace.rptr.back() + cached_rptr_[j + 1] - cached_rptr_[j]); + } + this->wspace.cut.insert(this->wspace.cut.end(), cached_cut_.begin(), cached_cut_.end()); + } + CHECK_EQ(this->wspace.rptr.size(), + (fset.size() + 1) * this->qexpand.size() + 1); + CHECK_EQ(this->wspace.rptr.back(), this->wspace.cut.size()); + } + } + + // code to create histogram + void CreateHist(const std::vector &gpair, + DMatrix *p_fmat, + const std::vector &fset, + const RegTree &tree) override { + const MetaInfo &info = p_fmat->info(); + // fill in reverse map + this->feat2workindex.resize(tree.param.num_feature); + this->work_set = fset; + std::fill(this->feat2workindex.begin(), this->feat2workindex.end(), -1); + for (size_t i = 0; i < fset.size(); ++i) { + this->feat2workindex[fset[i]] = static_cast(i); + } + // start to work + this->wspace.Init(this->param, 1); + // to gain speedup in recovery + { + this->thread_hist.resize(this->get_nthread()); + + // TWOPASS: use the real set + split set in the column iteration. + this->SetDefaultPostion(p_fmat, tree); + this->work_set.insert(this->work_set.end(), this->fsplit_set.begin(), this->fsplit_set.end()); + std::sort(this->work_set.begin(), this->work_set.end()); + this->work_set.resize( + std::unique(this->work_set.begin(), this->work_set.end()) - this->work_set.begin()); + + // start accumulating statistics + dmlc::DataIter *iter = p_fmat->ColIterator(this->work_set); + iter->BeforeFirst(); + while (iter->Next()) { + const ColBatch &batch = iter->Value(); + // TWOPASS: use the real set + split set in the column iteration. + this->CorrectNonDefaultPositionByBatch(batch, this->fsplit_set, tree); + + // start enumeration + const bst_omp_uint nsize = static_cast(batch.size); + #pragma omp parallel for schedule(dynamic, 1) + for (bst_omp_uint i = 0; i < nsize; ++i) { + int offset = this->feat2workindex[batch.col_index[i]]; + if (offset >= 0) { + this->UpdateHistCol(gpair, batch[i], info, tree, + fset, offset, + &this->thread_hist[omp_get_thread_num()]); + } + } + } + + // update node statistics. + this->GetNodeStats(gpair, *p_fmat, tree, + &(this->thread_stats), &(this->node_stats)); + for (size_t i = 0; i < this->qexpand.size(); ++i) { + const int nid = this->qexpand[i]; + const int wid = this->node2workindex[nid]; + this->wspace.hset[0][fset.size() + wid * (fset.size()+1)] + .data[0] = this->node_stats[nid]; + } + } + this->histred.Allreduce(dmlc::BeginPtr(this->wspace.hset[0].data), + this->wspace.hset[0].data.size()); + } + + // cached unit pointer + std::vector cached_rptr_; + // cached cut value. + std::vector cached_cut_; +}; + template class QuantileHistMaker: public HistMaker { @@ -763,5 +876,11 @@ XGBOOST_REGISTER_TREE_UPDATER(HistMaker, "grow_histmaker") .set_body([]() { return new CQHistMaker(); }); + +XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_global_histmaker") +.describe("Tree constructor that uses approximate global proposal of histogram construction.") +.set_body([]() { + return new GlobalProposalHistMaker(); + }); } // namespace tree } // namespace xgboost