diff --git a/src/tree/updater_histmaker-inl.hpp b/src/tree/updater_histmaker-inl.hpp index 97e4d0aea..72033613d 100644 --- a/src/tree/updater_histmaker-inl.hpp +++ b/src/tree/updater_histmaker-inl.hpp @@ -166,7 +166,8 @@ class HistMaker: public IUpdater { // initialize temp data structure inline void InitData(const std::vector &gpair, const IFMatrix &fmat, - const std::vector &root_index, const RegTree &tree) { + const std::vector &root_index, + const RegTree &tree) { utils::Assert(tree.param.num_nodes == tree.param.num_roots, "HistMaker: can only grow new tree"); {// setup position @@ -271,6 +272,7 @@ class HistMaker: public IUpdater { const TStats &node_sum, bst_uint fid, SplitEntry *best) { + if (hist.size == 0) return; double root_gain = node_sum.CalcGain(param); TStats s(param), c(param); for (bst_uint i = 0; i < hist.size; ++i) { @@ -319,7 +321,7 @@ class HistMaker: public IUpdater { EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)], node_sum, fid, &best); } - } + } // get the best result, we can synchronize the solution for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) { const int nid = qexpand[wid]; @@ -334,7 +336,8 @@ class HistMaker: public IUpdater { // now we know the solution in snode[nid], set split if (best.loss_chg > rt_eps) { p_tree->AddChilds(nid); - (*p_tree)[nid].set_split(best.split_index(), best.split_value, best.default_left()); + (*p_tree)[nid].set_split(best.split_index(), + best.split_value, best.default_left()); // mark right child as 0, to indicate fresh leaf (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0); (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0); @@ -379,10 +382,12 @@ class QuantileHistMaker: public HistMaker { const bst_uint ridx = static_cast(batch.base_rowid + i); int nid = this->position[ridx]; if (nid >= 0) { - if (tree[nid].is_leaf()) { - this->position[ridx] = ~nid; - } else { + if (!tree[nid].is_leaf()) { this->position[ridx] = nid = HistMaker::NextLevel(inst, tree, nid); + } + if (this->node2workindex[nid] < 0) { + this->position[ridx] = ~nid; + } else{ for (bst_uint j = 0; j < inst.length; ++j) { builder.AddBudget(inst[j].index, omp_get_thread_num()); } @@ -404,7 +409,7 @@ class QuantileHistMaker: public HistMaker { } } // start putting things into sketch - const bst_omp_uint nfeat = tree.param.num_feature; + const bst_omp_uint nfeat = col_ptr.size() - 1; #pragma omp parallel for schedule(dynamic, 1) for (bst_omp_uint k = 0; k < nfeat; ++k) { for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) { @@ -418,15 +423,23 @@ class QuantileHistMaker: public HistMaker { size_t max_size = static_cast(this->param.sketch_ratio / this->param.sketch_eps); // synchronize sketch summary_array.Init(sketchs.size(), max_size); + for (size_t i = 0; i < sketchs.size(); ++i) { + utils::WQuantileSketch::SummaryContainer out; + sketchs[i].GetSummary(&out); + summary_array.Set(i, out); + } size_t n4bytes = (summary_array.MemSize() + 3) / 4; - sreducer.AllReduce(&summary_array, n4bytes); + sreducer.AllReduce(&summary_array, n4bytes); // now we get the final result of sketch, setup the cut - for (size_t wid = 0; wid < this->qexpand.size(); ++wid) { + this->wspace.cut.clear(); + this->wspace.rptr.clear(); + this->wspace.rptr.push_back(0); + for (size_t wid = 0; wid < this->qexpand.size(); ++wid) { for (size_t fid = 0; fid < tree.param.num_feature; ++fid) { - const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid]; + const WXQSketch::Summary a = summary_array[wid * tree.param.num_feature + fid]; for (size_t i = 0; i < a.size; ++i) { bst_float cpt = a.data[i].value + rt_eps; - if (i == 0 || cpt > this->wspace.cut.back()){ + if (i == 0 || cpt > this->wspace.cut.back()) { this->wspace.cut.push_back(cpt); } } @@ -437,7 +450,8 @@ class QuantileHistMaker: public HistMaker { this->wspace.rptr.push_back(this->wspace.cut.size()); } utils::Assert(this->wspace.rptr.size() == - (tree.param.num_feature + 1) * this->qexpand.size(), "cut space inconsistent"); + (tree.param.num_feature + 1) * this->qexpand.size() + 1, + "cut space inconsistent"); } private: diff --git a/src/utils/quantile.h b/src/utils/quantile.h index c27fa9bfe..a3b8c18dd 100644 --- a/src/utils/quantile.h +++ b/src/utils/quantile.h @@ -258,7 +258,7 @@ struct WXQSummary : public WQSummary { return e.rmin_next() > e.rmax_prev() + chunk; } // set prune - inline void SetPrune(const WXQSummary &src, RType maxsize) { + inline void SetPrune(const WQSummary &src, RType maxsize) { if (src.size <= maxsize) { this->CopyFrom(src); return; }