diff --git a/booster/tree/xgboost_row_treemaker.hpp b/booster/tree/xgboost_row_treemaker.hpp index fdc46eb03..4f124b7c4 100644 --- a/booster/tree/xgboost_row_treemaker.hpp +++ b/booster/tree/xgboost_row_treemaker.hpp @@ -83,6 +83,7 @@ namespace xgboost{ // set all the rest expanding nodes to leaf for( size_t i = 0; i < qexpand.size(); ++ i ){ const int nid = qexpand[i]; + tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate ); tree.stat( nid ).loss_chg = 0.0f; tree.stat( nid ).sum_hess = static_cast( snode[ nid ].sum_hess ); @@ -94,6 +95,15 @@ namespace xgboost{ return false; } } + // collapse specific node + inline void Collapse( const std::vector &valid_index, int nid ){ + if( valid_index.size() == 0 ) return; + this->InitDataExpand( valid_index, nid ); + this->InitNewNode( this->qexpand ); + tree.stat( nid ).loss_chg = 0.0f; + tree.stat( nid ).sum_hess = static_cast( snode[ nid ].sum_hess ); + tree.CollapseToLeaf( nid, snode[nid].weight * param.learning_rate ); + } private: // make leaf nodes for all qexpand, update node statistics, mark leaf value inline void InitNewNode( const std::vector &qexpand ){ diff --git a/booster/tree/xgboost_tree.hpp b/booster/tree/xgboost_tree.hpp index fab2071bc..bb2d89acb 100644 --- a/booster/tree/xgboost_tree.hpp +++ b/booster/tree/xgboost_tree.hpp @@ -79,7 +79,7 @@ namespace xgboost{ if( interact_type != 0 ){ switch( interact_type ){ case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return; - case 2: + case 2: this->CollapseNode( grad, hess, smat, root_index, interact_node ); return; default: utils::Error("unknown interact type"); } } @@ -108,7 +108,7 @@ namespace xgboost{ } if( !silent ){ printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n", - tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth ); + tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.MaxDepth() ); } } virtual float Predict( const FMatrix &fmat, bst_uint ridx, unsigned gid = 0 ){ @@ -158,6 +158,32 @@ namespace xgboost{ tree.DumpModel( fo, fmap, with_stats ); } private: + inline void CollapseNode( std::vector &grad, + std::vector &hess, + const FMatrix &fmat, + const std::vector &root_index, + int nid ){ + std::vector valid_index; + for( size_t i = 0; i < grad.size(); i ++ ){ + ThreadEntry &e = this->InitTmp(); + this->PrepareTmp( fmat.GetRow(i), e ); + int pid = root_index.size() == 0 ? 0 : (int)root_index[i]; + // tranverse tree + while( !tree[ pid ].is_leaf() ){ + unsigned split_index = tree[ pid ].split_index(); + pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] ); + if( pid == nid ){ + valid_index.push_back( static_cast(i) ); break; + } + } + this->DropTmp( fmat.GetRow(i), e ); + } + RowTreeMaker maker( tree, param, grad, hess, fmat, root_index ); + maker.Collapse( valid_index, nid ); + if( !silent ){ + printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth ); + } + } inline void ExpandNode( std::vector &grad, std::vector &hess, const FMatrix &fmat, @@ -175,7 +201,7 @@ namespace xgboost{ RowTreeMaker maker( tree, param, grad, hess, fmat, root_index ); bool success = maker.Expand( valid_index, nid ); if( !silent ){ - printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth ); + printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() ); } } private: diff --git a/booster/tree/xgboost_tree_model.h b/booster/tree/xgboost_tree_model.h index 262616461..fe3a62f15 100644 --- a/booster/tree/xgboost_tree_model.h +++ b/booster/tree/xgboost_tree_model.h @@ -197,6 +197,21 @@ namespace xgboost{ this->DeleteNode( nodes[ rid ].cright() ); nodes[ rid ].set_leaf( value ); } + /*! + * \brief collapse a non leaf node to a leaf node, delete its children + * \param rid node id of the node + * \param new leaf value + */ + inline void CollapseToLeaf( int rid, float value ){ + if( nodes[rid].is_leaf() ) return; + if( !nodes[ nodes[rid].cleft() ].is_leaf() ){ + CollapseToLeaf( nodes[rid].cleft(), 0.0f ); + } + if( !nodes[ nodes[rid].cright() ].is_leaf() ){ + CollapseToLeaf( nodes[rid].cright(), 0.0f ); + } + this->ChangeToLeaf( rid, value ); + } public: /*! \brief model parameter */ Param param; @@ -287,6 +302,25 @@ namespace xgboost{ } return depth; } + /*! + * \brief get maximum depth + * \param nid node id + */ + inline int MaxDepth( int nid ) const{ + if( nodes[nid].is_leaf() ) return 0; + return std::max( MaxDepth( nodes[nid].cleft() )+1, + MaxDepth( nodes[nid].cright() )+1 ); + } + /*! + * \brief get maximum depth + */ + inline int MaxDepth( void ){ + int maxd = 0; + for( int i = 0; i < param.num_roots; ++ i ){ + maxd = std::max( maxd, MaxDepth( i ) ); + } + return maxd; + } /*! \brief number of extra nodes besides the root */ inline int num_extra_nodes( void ) const { return param.num_nodes - param.num_roots - param.num_deleted; diff --git a/demo/test/runexp.sh b/demo/test/runexp.sh index 0c273ff7f..626df2fc8 100755 --- a/demo/test/runexp.sh +++ b/demo/test/runexp.sh @@ -16,10 +16,14 @@ python mknfold.py agaricus.txt 1 # interaction ../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1 ../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2 +../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m3v.model interact:booster_index=0 bst:interact:remove=2 +../../xgboost mushroom.conf task=interact model_in=m3v.model model_out=m3p.model interact:booster_index=0 bst:interact:expand=2 # this is what dump will looklike with feature map ../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt ../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt +../../xgboost mushroom.conf task=dump model_in=m3v.model fmap=featmap.txt name_dump=dump.m3v.txt +../../xgboost mushroom.conf task=dump model_in=m3p.model fmap=featmap.txt name_dump=dump.m3p.txt echo "========m1=======" cat dump.m1.txt @@ -30,5 +34,11 @@ cat dump.m2.txt echo "========m3========" cat dump.m3.txt +echo "========m3v========" +cat dump.m3v.txt + +echo "========m3p========" +cat dump.m3p.txt + echo "========full=======" cat dump.full.txt