From 2910bdedf47125b2d23c8f7a10e032dc2db68739 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Wed, 5 Mar 2014 10:20:36 -0800
Subject: [PATCH] split new base treemaker, not very good abstraction, but ok

---
 booster/tree/xgboost_base_treemaker.hpp | 161 ++++++++++++++++++++++++
 booster/tree/xgboost_col_treemaker.hpp  | 148 ++--------------------
 2 files changed, 168 insertions(+), 141 deletions(-)
 create mode 100644 booster/tree/xgboost_base_treemaker.hpp

diff --git a/booster/tree/xgboost_base_treemaker.hpp b/booster/tree/xgboost_base_treemaker.hpp
new file mode 100644
index 000000000..dcc8c06fd
--- /dev/null
+++ b/booster/tree/xgboost_base_treemaker.hpp
@@ -0,0 +1,161 @@
+#ifndef XGBOOST_BASE_TREEMAKER_HPP
+#define XGBOOST_BASE_TREEMAKER_HPP
+/*!
+ * \file xgboost_base_treemaker.hpp
+ * \brief base data structures for regression tree makers;
+ *        provides the common operations of the tree construction steps
+ * \author Tianqi Chen: tianqi.tchen@gmail.com
+ */
+#include <vector>
+#include "xgboost_tree_model.h"
+
+namespace xgboost{
+    namespace booster{
+        class BaseTreeMaker{
+        protected:
+            BaseTreeMaker( RegTree &tree,
+                           const TreeParamTrain &param )
+                : tree( tree ), param( param ){}
+        protected:
+            // statistics that are helpful to decide a split
+            struct SplitEntry{
+                /*! \brief loss change after splitting this node */
+                float loss_chg;
+                /*! \brief split index */
+                unsigned sindex;
+                /*! \brief split value */
+                float split_value;
+                /*! \brief constructor */
+                SplitEntry( void ){
+                    loss_chg = 0.0f;
+                    split_value = 0.0f; sindex = 0;
+                }
+                // This function gives priority to the lower feature index when loss_chg is equal;
+                // not the best way, but it helps give consistent results during multi-threaded execution
+                inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
+                    if( this->split_index() <= split_index ){
+                        return loss_chg > this->loss_chg;
+                    }else{
+                        return !(this->loss_chg > loss_chg);
+                    }
+                }
+                inline void Update( const SplitEntry &e ){
+                    if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
+                        this->loss_chg = e.loss_chg;
+                        this->sindex = e.sindex;
+                        this->split_value = e.split_value;
+                    }
+                }
+                inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
+                    if( this->NeedReplace( loss_chg, split_index ) ){
+                        this->loss_chg = loss_chg;
+                        if( default_left ) split_index |= (1U << 31);
+                        this->sindex = split_index;
+                        this->split_value = split_value;
+                    }
+                }
+                inline unsigned split_index( void ) const{
+                    return sindex & ( (1U<<31) - 1U );
+                }
+                inline bool default_left( void ) const{
+                    return (sindex >> 31) != 0;
+                }
+            };
+            struct NodeEntry{
+                /*! \brief sum gradient statistics */
+                double sum_grad;
+                /*! \brief sum hessian statistics */
+                double sum_hess;
+                /*! \brief loss of this node, without split */
+                float root_gain;
+                /*! \brief weight calculated from the current data */
+                float weight;
+                /*! \brief current best solution */
+                SplitEntry best;
+                NodeEntry( void ){
+                    sum_grad = sum_hess = 0.0;
+                    weight = root_gain = 0.0f;
+                }
+            };
+
+            /*! \brief per thread x per node entry to store tmp data */
+            struct ThreadEntry{
+                /*! \brief sum gradient statistics */
+                double sum_grad;
+                /*! \brief sum hessian statistics */
+                double sum_hess;
+                /*! \brief last feature value scanned */
+                float last_fvalue;
+                /*! \brief current best solution */
+                SplitEntry best;
+                /*! \brief constructor */
+                ThreadEntry( void ){
+                    this->ClearStats();
+                }
+                /*! \brief clear statistics */
+                inline void ClearStats( void ){
+                    sum_grad = sum_hess = 0.0;
+                }
+            };
+        private:
+            // try to prune off the current leaf; recurses upward when the parent gets pruned
+            inline void TryPruneLeaf( int nid, int depth ){
+                if( tree[ nid ].is_root() ) return;
+                int pid = tree[ nid ].parent();
+                RegTree::NodeStat &s = tree.stat( pid );
+                ++ s.leaf_child_cnt;
+
+                if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
+                    this->stat_num_pruned += 2;
+                    // need to be pruned
+                    tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
+                    // tail recursion
+                    this->TryPruneLeaf( pid, depth - 1 );
+                }
+            }
+        protected:
+            /*! \brief do pruning of the tree */
+            inline int DoPrune( void ){
+                this->stat_num_pruned = 0;
+                // initialize auxiliary statistics
+                for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
+                    tree.stat( nid ).leaf_child_cnt = 0;
+                    tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
+                    tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
+                }
+                for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
+                    if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
+                }
+                return this->stat_num_pruned;
+            }
+        protected:
+            /*! \brief update the expansion queue, adding in the newly created leaves */
+            inline void UpdateQueueExpand( std::vector<int> &qexpand ){
+                std::vector<int> newnodes;
+                for( size_t i = 0; i < qexpand.size(); ++ i ){
+                    const int nid = qexpand[i];
+                    if( !tree[ nid ].is_leaf() ){
+                        newnodes.push_back( tree[nid].cleft() );
+                        newnodes.push_back( tree[nid].cright() );
+                    }
+                }
+                // use new nodes for qexpand
+                qexpand = newnodes;
+            }
+        protected:
+            // local helper tmp data structures
+            // statistics
+            int stat_num_pruned;
+            /*! \brief queue of nodes to be expanded */
+            std::vector<int> qexpand;
+            /*! \brief TreeNode Data: statistics for each constructed node; the derived class must maintain this */
+            std::vector<NodeEntry> snode;
+        protected:
+            // original data that supports tree construction
+            RegTree &tree;
+            const TreeParamTrain &param;
+        };
+    }; // namespace booster
+}; // namespace xgboost
+#endif // XGBOOST_BASE_TREEMAKER_HPP
diff --git a/booster/tree/xgboost_col_treemaker.hpp b/booster/tree/xgboost_col_treemaker.hpp
index 7bd2a3c77..2782c67a9 100644
--- a/booster/tree/xgboost_col_treemaker.hpp
+++ b/booster/tree/xgboost_col_treemaker.hpp
@@ -11,23 +11,25 @@
 #include "xgboost_tree_model.h"
 #include "../../utils/xgboost_omp.h"
 #include "../../utils/xgboost_random.h"
+#include "xgboost_base_treemaker.hpp"
 
 namespace xgboost{
     namespace booster{
         template<typename FMatrix>
-        class ColTreeMaker{
+        class ColTreeMaker : public BaseTreeMaker{
        public:
             ColTreeMaker( RegTree &tree,
                           const TreeParamTrain &param,
                           const std::vector<float> &grad,
                           const std::vector<float> &hess,
                           const FMatrix &smat,
-                          const std::vector<unsigned> &root_index ):
-                tree( tree ), param( param ), grad( grad ), hess( hess ),
-                smat( smat ), root_index( root_index ){
+                          const std::vector<unsigned> &root_index )
+                : BaseTreeMaker( tree, param ),
+                  grad(grad), hess(hess),
+                  smat(smat), root_index(root_index) {
                 utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
                 utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
-                utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
+                utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
                 utils::Assert( smat.HaveColAccess(), "ColTreeMaker: need column access matrix" );
             }
             inline void Make( int& stat_max_depth, int& stat_num_pruned ){
@@ -52,131 +54,6 @@ namespace xgboost{
                 stat_num_pruned = this->DoPrune();
             }
         private:
-            // statistics that are helpful to decide a split
-            struct SplitEntry{
-                /*! \brief loss change after splitting this node */
-                float loss_chg;
-                /*! \brief split index */
-                unsigned sindex;
-                /*! \brief split value */
-                float split_value;
-                /*! \brief constructor */
-                SplitEntry( void ){
-                    loss_chg = 0.0f;
-                    split_value = 0.0f; sindex = 0;
-                }
-                // This function gives priority to the lower feature index when loss_chg is equal;
-                // not the best way, but it helps give consistent results during multi-threaded execution
-                inline bool NeedReplace( float loss_chg, unsigned split_index ) const{
-                    if( this->split_index() <= split_index ){
-                        return loss_chg > this->loss_chg;
-                    }else{
-                        return !(this->loss_chg > loss_chg);
-                    }
-                }
-                inline void Update( const SplitEntry &e ){
-                    if( this->NeedReplace( e.loss_chg, e.split_index() ) ){
-                        this->loss_chg = e.loss_chg;
-                        this->sindex = e.sindex;
-                        this->split_value = e.split_value;
-                    }
-                }
-                inline void Update( float loss_chg, unsigned split_index, float split_value, bool default_left ){
-                    if( this->NeedReplace( loss_chg, split_index ) ){
-                        this->loss_chg = loss_chg;
-                        if( default_left ) split_index |= (1U << 31);
-                        this->sindex = split_index;
-                        this->split_value = split_value;
-                    }
-                }
-                inline unsigned split_index( void ) const{
-                    return sindex & ( (1U<<31) - 1U );
-                }
-                inline bool default_left( void ) const{
-                    return (sindex >> 31) != 0;
-                }
-            };
-            struct NodeEntry{
-                /*! \brief sum gradient statistics */
-                double sum_grad;
-                /*! \brief sum hessian statistics */
-                double sum_hess;
-                /*! \brief loss of this node, without split */
-                float root_gain;
-                /*! \brief weight calculated from the current data */
-                float weight;
-                /*! \brief current best solution */
-                SplitEntry best;
-                NodeEntry( void ){
-                    sum_grad = sum_hess = 0.0;
-                    weight = root_gain = 0.0f;
-                }
-            };
-
-            /*! \brief per thread x per node entry to store tmp data */
-            struct ThreadEntry{
-                /*! \brief sum gradient statistics */
-                double sum_grad;
-                /*! \brief sum hessian statistics */
-                double sum_hess;
-                /*! \brief last feature value scanned */
-                float last_fvalue;
-                /*! \brief current best solution */
-                SplitEntry best;
-                /*! \brief constructor */
-                ThreadEntry( void ){
-                    this->ClearStats();
-                }
-                /*! \brief clear statistics */
-                inline void ClearStats( void ){
-                    sum_grad = sum_hess = 0.0;
-                }
-            };
-        private:
-            // try to prune off the current leaf; recurses upward when the parent gets pruned
-            inline void TryPruneLeaf( int nid, int depth ){
-                if( tree[ nid ].is_root() ) return;
-                int pid = tree[ nid ].parent();
-                RegTree::NodeStat &s = tree.stat( pid );
-                ++ s.leaf_child_cnt;
-
-                if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
-                    this->stat_num_pruned += 2;
-                    // need to be pruned
-                    tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
-                    // tail recursion
-                    this->TryPruneLeaf( pid, depth - 1 );
-                }
-            }
-            // prune the tree
-            inline int DoPrune( void ){
-                this->stat_num_pruned = 0;
-                // initialize auxiliary statistics
-                for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
-                    tree.stat( nid ).leaf_child_cnt = 0;
-                    tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
-                    tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
-                }
-                for( int nid = 0; nid < tree.param.num_nodes; ++ nid ){
-                    if( tree[ nid ].is_leaf() ) this->TryPruneLeaf( nid, tree.GetDepth(nid) );
-                }
-                return this->stat_num_pruned;
-            }
-        private:
-            // update the expansion queue
-            inline void UpdateQueueExpand( std::vector<int> &qexpand ){
-                std::vector<int> newnodes;
-                for( size_t i = 0; i < qexpand.size(); ++ i ){
-                    const int nid = qexpand[i];
-                    if( !tree[ nid ].is_leaf() ){
-                        newnodes.push_back( tree[nid].cleft() );
-                        newnodes.push_back( tree[nid].cright() );
-                    }
-                }
-                // use new nodes for qexpand
-                qexpand = newnodes;
-            }
             // make leaf nodes for all qexpand, update node statistics, mark leaf value
             inline void InitNewNode( const std::vector<int> &qexpand ){
                 {// setup statistics space for each tree node
@@ -414,30 +291,19 @@ namespace xgboost{
             }
         }
     private:
-        // local helper tmp data structures
-        // statistics
-        int stat_num_pruned;
         // number of omp threads used during training
         int nthread;
-        // queue of nodes to be expanded
-        std::vector<int> qexpand;
         // Per feature: shuffled index of each feature
         std::vector<unsigned> feat_index;
         // Instance Data: current node position in the tree of each instance
         std::vector<int> position;
-        // TreeNode Data: statistics for each constructed node
-        std::vector<NodeEntry> snode;
        // PerThread x PerTreeNode: statistics for per-thread construction
         std::vector< std::vector<ThreadEntry> > stemp;
     private:
-        // original data that supports tree construction
-        RegTree &tree;
-        const TreeParamTrain &param;
         const std::vector<float> &grad;
         const std::vector<float> &hess;
         const FMatrix &smat;
         const std::vector<unsigned> &root_index;
-
         };
     };
 };
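
A note on SplitEntry, the core structure moved into the base class: the default direction of a split is packed into bit 31 of sindex, and NeedReplace breaks ties on loss_chg in favor of the lower feature index so that multi-threaded construction stays deterministic. The following standalone sketch (not part of the patch; it re-declares a minimal SplitEntry with the same logic) demonstrates both behaviours:

// standalone demo of SplitEntry's tie-breaking and sindex bit-packing
#include <cassert>
#include <cstdio>

struct SplitEntry {
    float loss_chg;
    unsigned sindex;     // bit 31: default direction; bits 0-30: feature index
    float split_value;
    SplitEntry() : loss_chg(0.0f), sindex(0), split_value(0.0f) {}
    unsigned split_index() const { return sindex & ((1U << 31) - 1U); }
    bool default_left() const { return (sindex >> 31) != 0; }
    bool NeedReplace(float new_loss_chg, unsigned new_split_index) const {
        // prefer the lower feature index on equal gain: deterministic result
        if (this->split_index() <= new_split_index) {
            return new_loss_chg > this->loss_chg;
        } else {
            return !(this->loss_chg > new_loss_chg);
        }
    }
    void Update(float new_loss_chg, unsigned new_split_index,
                float new_split_value, bool new_default_left) {
        if (this->NeedReplace(new_loss_chg, new_split_index)) {
            this->loss_chg = new_loss_chg;
            if (new_default_left) new_split_index |= (1U << 31);
            this->sindex = new_split_index;
            this->split_value = new_split_value;
        }
    }
};

int main() {
    SplitEntry best;
    best.Update(0.5f, 7, 1.0f, true);   // feature 7, missing values go left
    best.Update(0.5f, 3, 2.0f, false);  // equal gain: lower index 3 wins
    assert(best.split_index() == 3);
    assert(!best.default_left());
    best.Update(0.4f, 1, 0.0f, true);   // lower gain never replaces
    assert(best.split_index() == 3);
    std::printf("best: feature=%u value=%g default_left=%d\n",
                best.split_index(), best.split_value, best.default_left());
    return 0;
}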
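
UpdateQueueExpand replaces the expansion frontier wholesale: every node that was actually split contributes its two children, while nodes that stayed leaves simply drop out of the queue. A minimal sketch of the same loop over a hypothetical array-based toy tree (ToyNode is an assumption for illustration, not the RegTree API):

// standalone demo of level-by-level frontier expansion
#include <cstdio>
#include <vector>

struct ToyNode {
    int cleft, cright;                    // -1 marks a leaf
    bool is_leaf() const { return cleft == -1; }
};

void UpdateQueueExpand(const std::vector<ToyNode> &tree, std::vector<int> &qexpand) {
    std::vector<int> newnodes;
    for (size_t i = 0; i < qexpand.size(); ++i) {
        const int nid = qexpand[i];
        if (!tree[nid].is_leaf()) {       // node was split: enqueue both children
            newnodes.push_back(tree[nid].cleft);
            newnodes.push_back(tree[nid].cright);
        }
    }
    qexpand = newnodes;                   // unsplit leaves drop out of the queue
}

int main() {
    // root (0) split into 1 and 2; node 1 split into 3 and 4; node 2 stayed a leaf
    std::vector<ToyNode> tree = {{1, 2}, {3, 4}, {-1, -1}, {-1, -1}, {-1, -1}};
    std::vector<int> qexpand = {0};
    UpdateQueueExpand(tree, qexpand);     // -> {1, 2}
    UpdateQueueExpand(tree, qexpand);     // -> {3, 4}
    for (size_t i = 0; i < qexpand.size(); ++i) std::printf("%d ", qexpand[i]);
    std::printf("\n");
    return 0;
}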
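
DoPrune relies on the leaf_child_cnt counter: a parent is collapsed only after both of its children have been visited as leaves and its recorded loss change fails the pruning test; the tail recursion in TryPruneLeaf then retries one level up, so an entire weak subtree can fold into a single leaf in one pass. A rough standalone sketch of that recursion, assuming a toy node layout and a fixed kMinLossChg threshold standing in for param.need_prune:

// standalone demo of the bottom-up recursive pruning pass
#include <cstdio>
#include <vector>

struct PruneNode {
    int parent;          // -1 for the root
    int leaf_child_cnt;  // children seen as leaves so far
    float loss_chg;      // gain this node's split achieved
    bool is_leaf;
};

const float kMinLossChg = 0.1f;  // hypothetical pruning threshold

void TryPruneLeaf(std::vector<PruneNode> &tree, int nid, int &num_pruned) {
    if (tree[nid].parent == -1) return;          // root: nothing above to prune
    int pid = tree[nid].parent;
    ++tree[pid].leaf_child_cnt;
    if (tree[pid].leaf_child_cnt >= 2 && tree[pid].loss_chg < kMinLossChg) {
        num_pruned += 2;                          // both children disappear
        tree[pid].is_leaf = true;                 // collapse the parent to a leaf
        TryPruneLeaf(tree, pid, num_pruned);      // tail recursion: retry upward
    }
}

int main() {
    // node 0 = root (strong split), node 1 = weak split, nodes 2..4 = leaves
    std::vector<PruneNode> tree = {
        {-1, 0, 0.9f, false}, {0, 0, 0.05f, false},
        {0, 0, 0.0f, true}, {1, 0, 0.0f, true}, {1, 0, 0.0f, true}};
    int num_pruned = 0;
    for (int nid = 0; nid < (int)tree.size(); ++nid)
        if (tree[nid].is_leaf) TryPruneLeaf(tree, nid, num_pruned);
    std::printf("pruned %d nodes; node 1 is_leaf=%d\n", num_pruned, tree[1].is_leaf);
    return 0;
}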