diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index cad3ec811..271aa4ae2 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -206,6 +206,16 @@ class BaseMaker: public TreeUpdater { const RegTree &tree) { // set the positions in the nondefault this->SetNonDefaultPositionCol(nodes, p_fmat, tree); + this->SetDefaultPostion(p_fmat, tree); + } + /*! + * \brief helper function to set the non-leaf positions to default direction. + * This function can be applied multiple times and will get the same result. + * \param p_fmat feature matrix needed for tree construction + * \param tree the regression tree structure + */ + inline void SetDefaultPostion(DMatrix *p_fmat, + const RegTree &tree) { // set rest of instances to default position const RowSet &rowset = p_fmat->buffered_rowset(); // set default direct nodes to default @@ -222,7 +232,7 @@ class BaseMaker: public TreeUpdater { if (tree[nid].cright() == -1) { position[ridx] = ~nid; } - } else { + } else { // push to default branch if (tree[nid].default_left()) { this->SetEncodePosition(ridx, tree[nid].cleft()); @@ -234,16 +244,16 @@ class BaseMaker: public TreeUpdater { } /*! * \brief this is helper function uses column based data structure, - * update all positions into nondefault branch, if any, ignore the default branch * \param nodes the set of nodes that contains the split to be used - * \param p_fmat feature matrix needed for tree construction * \param tree the regression tree structure + * \param out_split_set The split index set */ - virtual void SetNonDefaultPositionCol(const std::vector &nodes, - DMatrix *p_fmat, - const RegTree &tree) { + inline void GetSplitSet(const std::vector &nodes, + const RegTree &tree, + std::vector* out_split_set) { + std::vector& fsplits = *out_split_set; + fsplits.clear(); // step 1, classify the non-default data into right places - std::vector fsplits; for (size_t i = 0; i < nodes.size(); ++i) { const int nid = nodes[i]; if (!tree[nid].is_leaf()) { @@ -252,7 +262,19 @@ class BaseMaker: public TreeUpdater { } std::sort(fsplits.begin(), fsplits.end()); fsplits.resize(std::unique(fsplits.begin(), fsplits.end()) - fsplits.begin()); - + } + /*! + * \brief this is helper function uses column based data structure, + * update all positions into nondefault branch, if any, ignore the default branch + * \param nodes the set of nodes that contains the split to be used + * \param p_fmat feature matrix needed for tree construction + * \param tree the regression tree structure + */ + virtual void SetNonDefaultPositionCol(const std::vector &nodes, + DMatrix *p_fmat, + const RegTree &tree) { + std::vector fsplits; + this->GetSplitSet(nodes, tree, &fsplits); dmlc::DataIter *iter = p_fmat->ColIterator(fsplits); while (iter->Next()) { const ColBatch &batch = iter->Value(); diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index c6d53b270..e7254307b 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -355,8 +355,9 @@ class CQHistMaker: public HistMaker { #endif } void ResetPositionAfterSplit(DMatrix *p_fmat, - const RegTree &tree) override { + const RegTree &tree) override { this->ResetPositionCol(this->qexpand, p_fmat, tree); + this->GetSplitSet(this->qexpand, tree, &fsplit_set); } void ResetPosAndPropose(const std::vector &gpair, DMatrix *p_fmat, @@ -388,14 +389,10 @@ class CQHistMaker: public HistMaker { for (size_t i = 0; i < sketchs.size(); ++i) { summary_array[i].Reserve(max_size); } - // if it is C++11, use lazy evaluation for Allreduce -#if __cplusplus >= 201103L - auto lazy_get_summary = [&]() -#endif - { + { // get smmary thread_sketch.resize(this->get_nthread()); - // number of rows in + // number of rows in data const size_t nrows = p_fmat->buffered_rowset().size(); // start accumulating statistics dmlc::DataIter *iter = p_fmat->ColIterator(freal_set); @@ -422,15 +419,10 @@ class CQHistMaker: public HistMaker { summary_array[i].SetPrune(out, max_size); } CHECK_EQ(summary_array.size(), sketchs.size()); - }; + } if (summary_array.size() != 0) { size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); -#if __cplusplus >= 201103L - sreducer.Allreduce(dmlc::BeginPtr(summary_array), - nbytes, summary_array.size(), lazy_get_summary); -#else sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); -#endif } // now we get the final result of sketch, setup the cut this->wspace.cut.clear(); @@ -617,6 +609,8 @@ class CQHistMaker: public HistMaker { std::vector feat2workindex; // set of index from fset that are real std::vector freal_set; + // set of index from that are split candidates. + std::vector fsplit_set; // thread temp data std::vector > thread_sketch; // used to hold statistics @@ -633,6 +627,7 @@ class CQHistMaker: public HistMaker { std::vector > sketchs; }; + template class QuantileHistMaker: public HistMaker { protected: