From 02c2278f9629020e153ba76040f9e44fdefa863b Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 15 Nov 2014 21:18:15 -0800 Subject: [PATCH] ok --- src/tree/updater_histmaker-inl.hpp | 42 +++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/tree/updater_histmaker-inl.hpp b/src/tree/updater_histmaker-inl.hpp index 72033613d..68ceca371 100644 --- a/src/tree/updater_histmaker-inl.hpp +++ b/src/tree/updater_histmaker-inl.hpp @@ -161,7 +161,11 @@ class HistMaker: public IUpdater { this->UpdateNode2WorkIndex(*p_tree); // if nothing left to be expand, break if (qexpand.size() == 0) break; - } + } + for (size_t i = 0; i < qexpand.size(); ++i) { + const int nid = qexpand[i]; + (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); + } } // initialize temp data structure inline void InitData(const std::vector &gpair, @@ -271,7 +275,8 @@ class HistMaker: public IUpdater { inline void EnumerateSplit(const HistUnit &hist, const TStats &node_sum, bst_uint fid, - SplitEntry *best) { + SplitEntry *best, + TStats *left_sum) { if (hist.size == 0) return; double root_gain = node_sum.CalcGain(param); TStats s(param), c(param); @@ -281,7 +286,9 @@ class HistMaker: public IUpdater { c.SetSubstract(node_sum, s); if (c.sum_hess >= param.min_child_weight) { double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain; - best->Update(loss_chg, fid, hist.cut[i], false); + if (best->Update(loss_chg, fid, hist.cut[i], false)) { + *left_sum = s; + } } } } @@ -292,7 +299,9 @@ class HistMaker: public IUpdater { c.SetSubstract(node_sum, s); if (c.sum_hess >= param.min_child_weight) { double loss_chg = s.CalcGain(param) + c.CalcGain(param) - root_gain; - best->Update(loss_chg, fid, hist.cut[i-1], true); + if (best->Update(loss_chg, fid, hist.cut[i-1], true)) { + *left_sum = c; + } } } } @@ -309,17 +318,18 @@ class HistMaker: public IUpdater { this->CreateHist(gpair, p_fmat, info, *p_tree); // get the best split condition for each node std::vector sol(qexpand.size()); + std::vector left_sum(qexpand.size()); bst_omp_uint nexpand = static_cast(qexpand.size()); #pragma omp parallel for schedule(dynamic, 1) for (bst_omp_uint wid = 0; wid < nexpand; ++ wid) { const int nid = qexpand[wid]; utils::Assert(node2workindex[nid] == static_cast(wid), "node2workindex inconsistent"); - SplitEntry &best = sol[wid]; + SplitEntry &best = sol[wid]; TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0]; for (bst_uint fid = 0; fid < num_feature; ++ fid) { EnumerateSplit(wspace.hset[0][fid + wid * (num_feature+1)], - node_sum, fid, &best); + node_sum, fid, &best, &left_sum[wid]); } } // get the best result, we can synchronize the solution @@ -327,25 +337,33 @@ class HistMaker: public IUpdater { const int nid = qexpand[wid]; const SplitEntry &best = sol[wid]; const TStats &node_sum = wspace.hset[0][num_feature + wid * (num_feature + 1)].data[0]; - bst_float weight = node_sum.CalcWeight(param); + this->SetStats(p_tree, nid, node_sum); // set up the values p_tree->stat(nid).loss_chg = best.loss_chg; - p_tree->stat(nid).base_weight = weight; - p_tree->stat(nid).sum_hess = static_cast(node_sum.sum_hess); - node_sum.SetLeafVec(param, p_tree->leafvec(nid)); // now we know the solution in snode[nid], set split if (best.loss_chg > rt_eps) { p_tree->AddChilds(nid); (*p_tree)[nid].set_split(best.split_index(), best.split_value, best.default_left()); // mark right child as 0, to indicate fresh leaf - (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0); + (*p_tree)[(*p_tree)[nid].cleft()].set_leaf(0.0f, 0); (*p_tree)[(*p_tree)[nid].cright()].set_leaf(0.0f, 0); + // right side sum + TStats right_sum; + right_sum.SetSubstract(node_sum, left_sum[wid]); + this->SetStats(p_tree, (*p_tree)[nid].cleft(), left_sum[wid]); + this->SetStats(p_tree, (*p_tree)[nid].cright(), right_sum); } else { - (*p_tree)[nid].set_leaf(weight * param.learning_rate); + (*p_tree)[nid].set_leaf(p_tree->stat(nid).base_weight * param.learning_rate); } } } + + inline void SetStats(RegTree *p_tree, int nid, const TStats &node_sum) { + p_tree->stat(nid).base_weight = node_sum.CalcWeight(param); + p_tree->stat(nid).sum_hess = static_cast(node_sum.sum_hess); + node_sum.SetLeafVec(param, p_tree->leafvec(nid)); + } }; // hist maker that propose using quantile sketch