From 523afcbcd25b02437c4b77f36cd503dfe8c0b3ae Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 19 Jan 2016 21:53:52 -0800 Subject: [PATCH] [TREE] Cleanup some functions, add utility function for two pass --- src/tree/updater_basemaker-inl.h | 39 +++++++++++++++++++++ src/tree/updater_histmaker.cc | 58 +++++++++++++++----------------- 2 files changed, 66 insertions(+), 31 deletions(-) diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 271aa4ae2..b6dbacd6c 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -242,6 +242,45 @@ class BaseMaker: public TreeUpdater { } } } + /*! + * \brief this is a helper function that uses the column-based data structure + * to CORRECT the positions of non-default directions that WERE set to default + * before calling this function. + * \param batch The column batch + * \param sorted_split_set The sorted set of feature indices that contain split solutions. + * \param tree the regression tree structure + */ + inline void CorrectNonDefaultPositionByBatch( + const ColBatch& batch, + const std::vector &sorted_split_set, + const RegTree &tree) { + for (size_t i = 0; i < batch.size; ++i) { + ColBatch::Inst col = batch[i]; + const bst_uint fid = batch.col_index[i]; + auto it = std::lower_bound(sorted_split_set.begin(), sorted_split_set.end(), fid); + + if (it != sorted_split_set.end() && *it == fid) { + const bst_omp_uint ndata = static_cast(col.length); + #pragma omp parallel for schedule(static) + for (bst_omp_uint j = 0; j < ndata; ++j) { + const bst_uint ridx = col[j].index; + const float fvalue = col[j].fvalue; + const int nid = this->DecodePosition(ridx); + CHECK(tree[nid].is_leaf()); + int pid = tree[nid].parent(); + + // go back to parent, correct those who are not default + if (!tree[nid].is_root() && tree[pid].split_index() == fid) { + if (fvalue < tree[pid].split_cond()) { + this->SetEncodePosition(ridx, tree[pid].cleft()); + } else { + this->SetEncodePosition(ridx, tree[pid].cright()); + 
} + } + } + } + } + } /*! * \brief this is helper function uses column based data structure, * \param nodes the set of nodes that contains the split to be used diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index e7254307b..40089c26d 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -127,6 +127,11 @@ class HistMaker: public BaseMaker { RegTree *p_tree) { this->InitData(gpair, *p_fmat, *p_tree); this->InitWorkSet(p_fmat, *p_tree, &fwork_set); + // mark root node as fresh. + for (int i = 0; i < p_tree->param.num_roots; ++i) { + (*p_tree)[i].set_leaf(0.0f, 0); + } + for (int depth = 0; depth < param.max_depth; ++depth) { // reset and propose candidate split this->ResetPosAndPropose(gpair, p_fmat, fwork_set, *p_tree); @@ -356,8 +361,8 @@ class CQHistMaker: public HistMaker { } void ResetPositionAfterSplit(DMatrix *p_fmat, const RegTree &tree) override { + // remove this reset and do two pass reset on ResetPosAndPropose this->ResetPositionCol(this->qexpand, p_fmat, tree); - this->GetSplitSet(this->qexpand, tree, &fsplit_set); } void ResetPosAndPropose(const std::vector &gpair, DMatrix *p_fmat, @@ -367,18 +372,18 @@ class CQHistMaker: public HistMaker { // fill in reverse map feat2workindex.resize(tree.param.num_feature); std::fill(feat2workindex.begin(), feat2workindex.end(), -1); - freal_set.clear(); + work_set.clear(); for (size_t i = 0; i < fset.size(); ++i) { if (feat_helper.Type(fset[i]) == 2) { - feat2workindex[fset[i]] = static_cast(freal_set.size()); - freal_set.push_back(fset[i]); + feat2workindex[fset[i]] = static_cast(work_set.size()); + work_set.push_back(fset[i]); } else { feat2workindex[fset[i]] = -2; } } - this->GetNodeStats(gpair, *p_fmat, tree, - &thread_stats, &node_stats); - sketchs.resize(this->qexpand.size() * freal_set.size()); + const size_t work_set_size = work_set.size(); + + sketchs.resize(this->qexpand.size() * work_set_size); for (size_t i = 0; i < sketchs.size(); ++i) { 
sketchs[i].Init(info.num_row, this->param.sketch_eps); } @@ -392,10 +397,9 @@ class CQHistMaker: public HistMaker { { // get smmary thread_sketch.resize(this->get_nthread()); - // number of rows in data - const size_t nrows = p_fmat->buffered_rowset().size(); + // start accumulating statistics - dmlc::DataIter *iter = p_fmat->ColIterator(freal_set); + dmlc::DataIter *iter = p_fmat->ColIterator(work_set); iter->BeforeFirst(); while (iter->Next()) { const ColBatch &batch = iter->Value(); @@ -406,9 +410,7 @@ class CQHistMaker: public HistMaker { int offset = feat2workindex[batch.col_index[i]]; if (offset >= 0) { this->UpdateSketchCol(gpair, batch[i], tree, - node_stats, - freal_set, offset, - batch[i].length == nrows, + work_set_size, offset, &thread_sketch[omp_get_thread_num()]); } } @@ -419,11 +421,14 @@ class CQHistMaker: public HistMaker { summary_array[i].SetPrune(out, max_size); } CHECK_EQ(summary_array.size(), sketchs.size()); - } + } if (summary_array.size() != 0) { size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); } + // update node statistics. 
+ this->GetNodeStats(gpair, *p_fmat, tree, + &thread_stats, &node_stats); // now we get the final result of sketch, setup the cut this->wspace.cut.clear(); this->wspace.rptr.clear(); @@ -432,7 +437,7 @@ class CQHistMaker: public HistMaker { for (size_t i = 0; i < fset.size(); ++i) { int offset = feat2workindex[fset[i]]; if (offset >= 0) { - const WXQSketch::Summary &a = summary_array[wid * freal_set.size() + offset]; + const WXQSketch::Summary &a = summary_array[wid * work_set_size + offset]; for (size_t i = 1; i < a.size; ++i) { bst_float cpt = a.data[i].value - rt_eps; if (i == 1 || cpt > this->wspace.cut.back()) { @@ -518,10 +523,8 @@ class CQHistMaker: public HistMaker { inline void UpdateSketchCol(const std::vector &gpair, const ColBatch::Inst &c, const RegTree &tree, - const std::vector &nstats, - const std::vector &frealset, + size_t work_set_size, bst_uint offset, - bool col_full, std::vector *p_temp) { if (c.length == 0) return; // initialize sbuilder for use @@ -531,22 +534,15 @@ class CQHistMaker: public HistMaker { const unsigned nid = this->qexpand[i]; const unsigned wid = this->node2workindex[nid]; sbuilder[nid].sum_total = 0.0f; - sbuilder[nid].sketch = &sketchs[wid * frealset.size() + offset]; + sbuilder[nid].sketch = &sketchs[wid * work_set_size + offset]; } - if (!col_full) { - // first pass, get sum of weight, TODO, optimization to skip first pass - for (bst_uint j = 0; j < c.length; ++j) { + // first pass, get sum of weight, TODO, optimization to skip first pass + for (bst_uint j = 0; j < c.length; ++j) { const bst_uint ridx = c[j].index; const int nid = this->position[ridx]; if (nid >= 0) { - sbuilder[nid].sum_total += gpair[ridx].hess; - } - } - } else { - for (size_t i = 0; i < this->qexpand.size(); ++i) { - const unsigned nid = this->qexpand[i]; - sbuilder[nid].sum_total = static_cast(nstats[nid].sum_hess); + sbuilder[nid].sum_total += gpair[ridx].hess; } } // if only one value, no need to do second pass @@ -607,8 +603,8 @@ class 
CQHistMaker: public HistMaker { BaseMaker::FMetaHelper feat_helper; // temp space to map feature id to working index std::vector feat2workindex; - // set of index from fset that are real - std::vector freal_set; + // set of indices from fset that form the current work set + std::vector work_set; // set of index from that are split candidates. std::vector fsplit_set; // thread temp data