diff --git a/src/tree/model.h b/src/tree/model.h
index 081f71ffb..af99a5145 100644
--- a/src/tree/model.h
+++ b/src/tree/model.h
@@ -42,11 +42,17 @@ class TreeModel {
     int max_depth;
     /*! \brief number of features used for tree construction */
     int num_feature;
+    /*!
+     * \brief leaf vector size, used for vector tree
+     * used to store more than one dimensional information in tree
+     */
+    int size_leaf_vector;
     /*! \brief reserved part */
-    int reserved[32];
+    int reserved[31];
     /*! \brief constructor */
     Param(void) {
       max_depth = 0;
+      size_leaf_vector = 0;
       memset(reserved, 0, sizeof(reserved));
     }
     /*!
@@ -57,6 +63,7 @@ class TreeModel {
     inline void SetParam(const char *name, const char *val) {
       if (!strcmp("num_roots", name)) num_roots = atoi(val);
       if (!strcmp("num_feature", name)) num_feature = atoi(val);
+      if (!strcmp("size_leaf_vector", name)) size_leaf_vector = atoi(val);
     }
   };
   /*! \brief tree node */
@@ -166,10 +173,12 @@ class TreeModel {
  protected:
   // vector of nodes
   std::vector<Node> nodes;
-  // stats of nodes
-  std::vector<NodeStat> stats;
   // free node space, used during training process
   std::vector<int> deleted_nodes;
+  // stats of nodes
+  std::vector<NodeStat> stats;
+  // leaf vector, that is used to store additional information
+  std::vector<bst_float> leaf_vector;
   // allocate a new node,
   // !!!!!! NOTE: may cause BUG here, nodes.resize
   inline int AllocNode(void) {
@@ -184,6 +193,7 @@ class TreeModel {
                  "number of nodes in the tree exceed 2^31");
     nodes.resize(param.num_nodes);
     stats.resize(param.num_nodes);
+    leaf_vector.resize(param.num_nodes * param.size_leaf_vector);
     return nd;
   }
   // delete a tree node
@@ -247,6 +257,14 @@ class TreeModel {
   inline NodeStat &stat(int nid) {
     return stats[nid];
   }
+  /*! \brief get leaf vector given nid */
+  inline bst_float* leafvec(int nid) {
+    return &leaf_vector[nid * param.size_leaf_vector];
+  }
+  /*! \brief get leaf vector given nid */
+  inline const bst_float* leafvec(int nid) const {
+    return &leaf_vector[nid * param.size_leaf_vector];
+  }
   /*! \brief initialize the model */
   inline void InitModel(void) {
     param.num_nodes = param.num_roots;
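Not part of the patch, just for orientation: a minimal sketch of how the new per-node leaf vector could be read back from a trained tree. It only relies on members introduced or shown above (`param.size_leaf_vector`, `param.num_nodes`, `is_leaf()`, `leafvec()`), and assumes the tree model header is included; the helper name `DumpLeafVectors` is made up for illustration.

```cpp
#include <cstdio>

// Hypothetical helper: dump every leaf's vector of a tree built with
// size_leaf_vector = K, so leafvec(nid) points at K consecutive bst_float
// values belonging to node nid.
inline void DumpLeafVectors(RegTree &tree) {
  const int K = tree.param.size_leaf_vector;
  for (int nid = 0; nid < tree.param.num_nodes; ++nid) {
    if (!tree[nid].is_leaf()) continue;
    const bst_float *vec = tree.leafvec(nid);
    for (int k = 0; k < K; ++k) {
      std::printf("leaf %d dim %d: %g\n", nid, k, vec[k]);
    }
  }
}
```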
diff --git a/src/tree/param.h b/src/tree/param.h
index 5a73c4287..6e02db00a 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -145,8 +145,8 @@ struct GradStats {
   double sum_grad;
   /*! \brief sum hessian statistics */
   double sum_hess;
-  /*! \brief constructor */
-  GradStats(void) {
+  /*! \brief constructor, the object must be cleared during construction */
+  explicit GradStats(const TrainParam &param) {
     this->Clear();
   }
   /*! \brief clear the statistics */
@@ -169,29 +169,31 @@ struct GradStats {
   inline double CalcWeight(const TrainParam &param) const {
     return param.CalcWeight(sum_grad, sum_hess);
   }
-  /*!\brief calculate gain of the solution */
+  /*! \brief calculate gain of the solution */
   inline double CalcGain(const TrainParam &param) const {
     return param.CalcGain(sum_grad, sum_hess);
   }
   /*! \brief add statistics to the data */
-  inline void Add(double grad, double hess) {
-    sum_grad += grad; sum_hess += hess;
-  }
-  /*! \brief add statistics to the data */
   inline void Add(const GradStats &b) {
     this->Add(b.sum_grad, b.sum_hess);
   }
-  /*! \brief substract the statistics by b */
-  inline GradStats Substract(const GradStats &b) const {
-    GradStats res;
-    res.sum_grad = this->sum_grad - b.sum_grad;
-    res.sum_hess = this->sum_hess - b.sum_hess;
-    return res;
+  /*! \brief set current value to a - b */
+  inline void SetSubstract(const GradStats &a, const GradStats &b) {
+    sum_grad = a.sum_grad - b.sum_grad;
+    sum_hess = a.sum_hess - b.sum_hess;
   }
   /*! \return whether the statistics is not used yet */
   inline bool Empty(void) const {
     return sum_hess == 0.0;
   }
+  /*! \brief set leaf vector value based on statistics */
+  inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
+  }
+ protected:
+  /*! \brief add statistics to the data */
+  inline void Add(double grad, double hess) {
+    sum_grad += grad; sum_hess += hess;
+  }
 };
 
 /*!
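A small usage sketch of the reworked `GradStats` interface (illustration only, assuming the definitions from `param.h` are in scope): statistics must now be constructed from a `TrainParam`, the raw `(grad, hess)` overload of `Add` is protected, and the temporary-returning `Substract` is replaced by the in-place `SetSubstract`. The helper `SplitGainSketch` and its signature are invented for this example.

```cpp
// Illustration only: gain of a candidate split, computed from the node total
// and the statistics accumulated on one side of the split.
inline double SplitGainSketch(const TrainParam &param,
                              const GradStats &total,
                              const GradStats &left,
                              double root_gain) {
  GradStats right(param);           // cleared on construction
  right.SetSubstract(total, left);  // right = total - left, no temporary object
  if (right.Empty()) return 0.0;    // nothing falls on the other side
  return left.CalcGain(param) + right.CalcGain(param) - root_gain;
}
```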
diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp
index fa13e607c..0a494fb5c 100644
--- a/src/tree/updater_colmaker-inl.hpp
+++ b/src/tree/updater_colmaker-inl.hpp
@@ -51,8 +51,8 @@ class ColMaker: public IUpdater {
     /*! \brief current best solution */
     SplitEntry best;
     // constructor
-    ThreadEntry(void) {
-      stats.Clear();
+    explicit ThreadEntry(const TrainParam &param)
+        : stats(param) {
     }
   };
   struct NodeEntry {
@@ -65,8 +65,8 @@ class ColMaker: public IUpdater {
     /*! \brief current best solution */
     SplitEntry best;
     // constructor
-    NodeEntry(void) : root_gain(0.0f), weight(0.0f){
-      stats.Clear();
+    explicit NodeEntry(const TrainParam &param)
+        : stats(param), root_gain(0.0f), weight(0.0f) {
     }
   };
   // actual builder that runs the algorithm
@@ -100,6 +100,7 @@ class ColMaker: public IUpdater {
         p_tree->stat(nid).loss_chg = snode[nid].best.loss_chg;
         p_tree->stat(nid).base_weight = snode[nid].weight;
         p_tree->stat(nid).sum_hess = static_cast<float>(snode[nid].stats.sum_hess);
+        snode[nid].stats.SetLeafVec(param, p_tree->leafvec(nid));
       }
     }
 
@@ -179,9 +180,9 @@ class ColMaker: public IUpdater {
                            const RegTree &tree) {
       {// setup statistics space for each tree node
         for (size_t i = 0; i < stemp.size(); ++i) {
-          stemp[i].resize(tree.param.num_nodes, ThreadEntry());
+          stemp[i].resize(tree.param.num_nodes, ThreadEntry(param));
         }
-        snode.resize(tree.param.num_nodes, NodeEntry());
+        snode.resize(tree.param.num_nodes, NodeEntry(param));
       }
       const std::vector<bst_uint> &rowset = fmat.buffered_rowset();
       // setup position
@@ -196,7 +197,7 @@ class ColMaker: public IUpdater {
       // sum the per thread statistics together
       for (size_t j = 0; j < qexpand.size(); ++j) {
         const int nid = qexpand[j];
-        TStats stats; stats.Clear();
+        TStats stats(param);
         for (size_t tid = 0; tid < stemp.size(); ++tid) {
           stats.Add(stemp[tid][nid].stats);
         }
@@ -231,6 +232,8 @@ class ColMaker: public IUpdater {
      for (size_t j = 0; j < qexpand.size(); ++j) {
        temp[qexpand[j]].stats.Clear();
      }
+      // left statistics
+      TStats c(param);
      while (it.Next()) {
        const bst_uint ridx = it.rindex();
        const int nid = position[ridx];
@@ -246,7 +249,7 @@ class ColMaker: public IUpdater {
          } else {
            // try to find a split
            if (fabsf(fvalue - e.last_fvalue) > rt_2eps && e.stats.sum_hess >= param.min_child_weight) {
-              TStats c = snode[nid].stats.Substract(e.stats);
+              c.SetSubstract(snode[nid].stats, e.stats);
              if (c.sum_hess >= param.min_child_weight) {
                double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
                e.best.Update(loss_chg, fid, (fvalue + e.last_fvalue) * 0.5f, !is_forward_search);
@@ -261,7 +264,7 @@ class ColMaker: public IUpdater {
      for (size_t i = 0; i < qexpand.size(); ++i) {
        const int nid = qexpand[i];
        ThreadEntry &e = temp[nid];
-        TStats c = snode[nid].stats.Substract(e.stats);
+        c.SetSubstract(snode[nid].stats, e.stats);
        if (e.stats.sum_hess >= param.min_child_weight && c.sum_hess >= param.min_child_weight) {
          const double loss_chg = e.stats.CalcGain(param) + c.CalcGain(param) - snode[nid].root_gain;
          const float delta = is_forward_search ? rt_eps : -rt_eps;
diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp
index e0e7ab520..d76936791 100644
--- a/src/tree/updater_refresh-inl.hpp
+++ b/src/tree/updater_refresh-inl.hpp
@@ -44,8 +44,8 @@ class TreeRefresher: public IUpdater {
       int tid = omp_get_thread_num();
       for (size_t i = 0; i < trees.size(); ++i) {
         std::vector<TStats> &vec = stemp[tid * trees.size() + i];
-        vec.resize(trees[i]->param.num_nodes);
-        std::fill(vec.begin(), vec.end(), TStats());
+        vec.resize(trees[i]->param.num_nodes, TStats(param));
+        std::fill(vec.begin(), vec.end(), TStats(param));
       }
       fvec_temp[tid].Init(trees[0]->param.num_feature);
     }
@@ -114,6 +114,7 @@ class TreeRefresher: public IUpdater {
     RegTree &tree = *p_tree;
     tree.stat(nid).base_weight = gstats[nid].CalcWeight(param);
     tree.stat(nid).sum_hess = static_cast<float>(gstats[nid].sum_hess);
+    gstats[nid].SetLeafVec(param, tree.leafvec(nid));
     if (tree[nid].is_leaf()) {
       tree[nid].set_leaf(tree.stat(nid).base_weight * param.learning_rate);
     } else {
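Finally, a hypothetical statistics type (not part of this patch; the name `WeightedGradStats` and its `sum_extra` field are invented) sketching what the empty `GradStats::SetLeafVec` hook and the `TStats(param)` constructors enable: the updaters above refer to their statistics type only as `TStats`, so a drop-in type with the same interface can accumulate extra per-node information and write it into the tree's leaf vector. Only the leaf-vector-related parts are shown; the per-instance `Add` overload used during the column scan is omitted.

```cpp
// Hypothetical TStats implementation: carries one extra accumulator and
// writes it into the tree's leaf vector, which plain GradStats leaves empty.
struct WeightedGradStats : public GradStats {
  // extra per-node quantity to be stored alongside the leaf weight
  double sum_extra;

  explicit WeightedGradStats(const TrainParam &param)
      : GradStats(param), sum_extra(0.0) {}
  inline void Clear(void) {
    GradStats::Clear(); sum_extra = 0.0;
  }
  inline void Add(const WeightedGradStats &b) {
    GradStats::Add(b); sum_extra += b.sum_extra;
  }
  inline void SetSubstract(const WeightedGradStats &a, const WeightedGradStats &b) {
    GradStats::SetSubstract(a, b);
    sum_extra = a.sum_extra - b.sum_extra;
  }
  // store the extra dimension into the leaf vector;
  // requires the tree to be built with size_leaf_vector >= 1
  inline void SetLeafVec(const TrainParam &param, bst_float *vec) const {
    vec[0] = static_cast<bst_float>(sum_extra);
  }
};
```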