diff --git a/src/tree/param.h b/src/tree/param.h index a19eb2d82..5a73c4287 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -11,45 +11,6 @@ namespace xgboost { namespace tree { -/*! \brief core statistics used for tree construction */ -struct GradStats { - /*! \brief sum gradient statistics */ - double sum_grad; - /*! \brief sum hessian statistics */ - double sum_hess; - /*! \brief constructor */ - GradStats(void) { - this->Clear(); - } - /*! \brief clear the statistics */ - inline void Clear(void) { - sum_grad = sum_hess = 0.0f; - } - /*! \brief add statistics to the data */ - inline void Add(double grad, double hess) { - sum_grad += grad; sum_hess += hess; - } - /*! \brief add statistics to the data */ - inline void Add(const bst_gpair& b) { - this->Add(b.grad, b.hess); - } - /*! \brief add statistics to the data */ - inline void Add(const GradStats &b) { - this->Add(b.sum_grad, b.sum_hess); - } - /*! \brief substract the statistics by b */ - inline GradStats Substract(const GradStats &b) const { - GradStats res; - res.sum_grad = this->sum_grad - b.sum_grad; - res.sum_hess = this->sum_hess - b.sum_hess; - return res; - } - /*! \return whether the statistics is not used yet */ - inline bool Empty(void) const { - return sum_hess == 0.0; - } -}; - /*! \brief training parameters for regression tree */ struct TrainParam{ // learning step size for a time @@ -165,13 +126,6 @@ struct TrainParam{ inline bool cannot_split(double sum_hess, int depth) const { return sum_hess < this->min_child_weight * 2.0; } - // code support for template data - inline double CalcWeight(const GradStats &d) const { - return this->CalcWeight(d.sum_grad, d.sum_hess); - } - inline double CalcGain(const GradStats &d) const { - return this->CalcGain(d.sum_grad, d.sum_hess); - } protected: // functions for L1 cost @@ -185,6 +139,61 @@ struct TrainParam{ } }; +/*! \brief core statistics used for tree construction */ +struct GradStats { + /*! \brief sum gradient statistics */ + double sum_grad; + /*! \brief sum hessian statistics */ + double sum_hess; + /*! \brief constructor */ + GradStats(void) { + this->Clear(); + } + /*! \brief clear the statistics */ + inline void Clear(void) { + sum_grad = sum_hess = 0.0f; + } + /*! + * \brief accumulate statistics, + * \param gpair the vector storing the gradient statistics + * \param info the additional information + * \param ridx instance index of this instance + */ + inline void Add(const std::vector &gpair, + const BoosterInfo &info, + bst_uint ridx) { + const bst_gpair &b = gpair[ridx]; + this->Add(b.grad, b.hess); + } + /*! \brief caculate leaf weight */ + inline double CalcWeight(const TrainParam ¶m) const { + return param.CalcWeight(sum_grad, sum_hess); + } + /*!\brief calculate gain of the solution */ + inline double CalcGain(const TrainParam ¶m) const { + return param.CalcGain(sum_grad, sum_hess); + } + /*! \brief add statistics to the data */ + inline void Add(double grad, double hess) { + sum_grad += grad; sum_hess += hess; + } + /*! \brief add statistics to the data */ + inline void Add(const GradStats &b) { + this->Add(b.sum_grad, b.sum_hess); + } + /*! \brief substract the statistics by b */ + inline GradStats Substract(const GradStats &b) const { + GradStats res; + res.sum_grad = this->sum_grad - b.sum_grad; + res.sum_hess = this->sum_hess - b.sum_hess; + return res; + } + /*! \return whether the statistics is not used yet */ + inline bool Empty(void) const { + return sum_hess == 0.0; + } +}; + /*! * \brief statistics that is helpful to store * and represent a split solution for the tree diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index afeccb206..fa13e607c 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -80,13 +80,13 @@ class ColMaker: public IUpdater { const BoosterInfo &info, RegTree *p_tree) { this->InitData(gpair, fmat, info.root_index, *p_tree); - this->InitNewNode(qexpand, gpair, fmat, *p_tree); + this->InitNewNode(qexpand, gpair, fmat, info, *p_tree); for (int depth = 0; depth < param.max_depth; ++depth) { - this->FindSplit(depth, this->qexpand, gpair, fmat, p_tree); + this->FindSplit(depth, this->qexpand, gpair, fmat, info, p_tree); this->ResetPosition(this->qexpand, fmat, *p_tree); this->UpdateQueueExpand(*p_tree, &this->qexpand); - this->InitNewNode(qexpand, gpair, fmat, *p_tree); + this->InitNewNode(qexpand, gpair, fmat, info, *p_tree); // if nothing left to be expand, break if (qexpand.size() == 0) break; } @@ -175,6 +175,7 @@ class ColMaker: public IUpdater { inline void InitNewNode(const std::vector &qexpand, const std::vector &gpair, const FMatrix &fmat, + const BoosterInfo &info, const RegTree &tree) { {// setup statistics space for each tree node for (size_t i = 0; i < stemp.size(); ++i) { @@ -190,7 +191,7 @@ class ColMaker: public IUpdater { const bst_uint ridx = rowset[i]; const int tid = omp_get_thread_num(); if (position[ridx] < 0) continue; - stemp[tid][position[ridx]].stats.Add(gpair[ridx]); + stemp[tid][position[ridx]].stats.Add(gpair, info, ridx); } // sum the per thread statistics together for (size_t j = 0; j < qexpand.size(); ++j) { @@ -201,8 +202,8 @@ class ColMaker: public IUpdater { } // update node statistics snode[nid].stats = stats; - snode[nid].root_gain = param.CalcGain(stats); - snode[nid].weight = param.CalcWeight(stats); + snode[nid].root_gain = stats.CalcGain(param); + snode[nid].weight = stats.CalcWeight(param); } } /*! \brief update queue expand add in new leaves */ @@ -223,6 +224,7 @@ class ColMaker: public IUpdater { template inline void EnumerateSplit(Iter it, unsigned fid, const std::vector &gpair, + const BoosterInfo &info, std::vector &temp, bool is_forward_search) { // clear all the temp statistics @@ -239,19 +241,19 @@ class ColMaker: public IUpdater { ThreadEntry &e = temp[nid]; // test if first hit, this is fine, because we set 0 during init if (e.stats.Empty()) { - e.stats.Add(gpair[ridx]); + e.stats.Add(gpair, info, ridx); e.last_fvalue = fvalue; } else { // try to find a split if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) { TStats c = snode[nid].stats.Substract(e.stats); if (c.sum_hess >= param.min_child_weight) { - double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain; + double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain; e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search); } } // update the statistics - e.stats.Add(gpair[ridx]); + e.stats.Add(gpair, info, ridx); e.last_fvalue = fvalue; } } @@ -261,7 +263,7 @@ class ColMaker: public IUpdater { ThreadEntry &e = temp[nid]; TStats c = snode[nid].stats.Substract(e.stats); if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) { - const double loss_chg = param.CalcGain(e.stats) + param.CalcGain(c) - snode[nid].root_gain; + const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain; const float delta = is_forward_search ? rt_eps : -rt_eps; e.best.Update(loss_chg, fid, e.last_fvalue + delta, !is_forward_search); } @@ -269,7 +271,9 @@ class ColMaker: public IUpdater { } // find splits at current level, do split per level inline void FindSplit(int depth, const std::vector &qexpand, - const std::vector &gpair, const FMatrix &fmat, + const std::vector &gpair, + const FMatrix &fmat, + const BoosterInfo &info, RegTree *p_tree) { std::vector feat_set = feat_index; if (param.colsample_bylevel != 1.0f) { @@ -288,10 +292,10 @@ class ColMaker: public IUpdater { const unsigned fid = feat_set[i]; const int tid = omp_get_thread_num(); if (param.need_forward_search(fmat.GetColDensity(fid))) { - this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, stemp[tid], true); + this->EnumerateSplit(fmat.GetSortedCol(fid), fid, gpair, info, stemp[tid], true); } if (param.need_backward_search(fmat.GetColDensity(fid))) { - this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, stemp[tid], false); + this->EnumerateSplit(fmat.GetReverseSortedCol(fid), fid, gpair, info, stemp[tid], false); } } // after this each thread's stemp will get the best candidates, aggregate results diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp index e23174e51..3ccf217f6 100644 --- a/src/tree/updater_refresh-inl.hpp +++ b/src/tree/updater_refresh-inl.hpp @@ -65,8 +65,7 @@ class TreeRefresher: public IUpdater { RegTree::FVec &feats = fvec_temp[tid]; feats.Fill(inst); for (size_t j = 0; j < trees.size(); ++j) { - AddStats(*trees[j], feats, gpair[ridx], - info.GetRoot(j), + AddStats(*trees[j], feats, gpair, info, ridx, &stemp[tid * trees.size() + j]); } feats.Drop(inst); @@ -95,31 +94,33 @@ class TreeRefresher: public IUpdater { private: inline static void AddStats(const RegTree &tree, const RegTree::FVec &feat, - const bst_gpair &gpair, unsigned root_id, + const std::vector &gpair, + const BoosterInfo &info, + const bst_uint ridx, std::vector *p_gstats) { std::vector &gstats = *p_gstats; // start from groups that belongs to current data - int pid = static_cast(root_id); - gstats[pid].Add(gpair); + int pid = static_cast(info.GetRoot(ridx)); + gstats[pid].Add(gpair, info, ridx); // tranverse tree while (!tree[pid].is_leaf()) { unsigned split_index = tree[pid].split_index(); pid = tree.GetNext(pid, feat.fvalue(split_index), feat.is_missing(split_index)); - gstats[pid].Add(gpair); + gstats[pid].Add(gpair, info, ridx); } } inline void Refresh(const std::vector &gstats, int nid, RegTree *p_tree) { RegTree &tree = *p_tree; - tree.stat(nid).base_weight = param.CalcWeight(gstats[nid]); + tree.stat(nid).base_weight = gstats[nid].CalcWeight(param); tree.stat(nid).sum_hess = static_cast(gstats[nid].sum_hess); if (tree[nid].is_leaf()) { tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate); } else { tree.stat(nid).loss_chg = - param.CalcGain(gstats[tree[nid].cleft()]) + - param.CalcGain(gstats[tree[nid].cright()]) - - param.CalcGain(gstats[nid]); + gstats[tree[nid].cleft()].CalcGain(param) + + gstats[tree[nid].cright()].CalcGain(param) - + gstats[nid].CalcGain(param); this->Refresh(gstats, tree[nid].cleft(), p_tree); this->Refresh(gstats, tree[nid].cright(), p_tree); }