diff --git a/booster/tree/xgboost_col_treemaker.hpp b/booster/tree/xgboost_col_treemaker.hpp index 25ebb8513..865439b57 100644 --- a/booster/tree/xgboost_col_treemaker.hpp +++ b/booster/tree/xgboost_col_treemaker.hpp @@ -11,6 +11,7 @@ #include "xgboost_tree_model.h" #include "../../utils/xgboost_omp.h" #include "../../utils/xgboost_random.h" +#include "../../utils/xgboost_fmap.h" #include "xgboost_base_treemaker.hpp" namespace xgboost{ @@ -23,10 +24,11 @@ namespace xgboost{ const std::vector &grad, const std::vector &hess, const FMatrix &smat, - const std::vector &root_index ) + const std::vector &root_index, + const utils::FeatConstrain &constrain ) : BaseTreeMaker( tree, param ), grad(grad), hess(hess), - smat(smat), root_index(root_index) { + smat(smat), root_index(root_index), constrain(constrain) { utils::Assert( grad.size() == hess.size(), "booster:invalid input" ); utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" ); utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" ); @@ -282,10 +284,10 @@ namespace xgboost{ {// initialize feature index int ncol = static_cast( smat.NumCol() ); for( int i = 0; i < ncol; i ++ ){ - if( smat.GetSortedCol(i).Next() ){ + if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){ feat_index.push_back( i ); } - } + } random::Shuffle( feat_index ); } {// setup temp space for each thread @@ -326,6 +328,7 @@ namespace xgboost{ const std::vector &hess; const FMatrix &smat; const std::vector &root_index; + const utils::FeatConstrain &constrain; }; }; }; diff --git a/booster/tree/xgboost_row_treemaker.hpp b/booster/tree/xgboost_row_treemaker.hpp index 4f124b7c4..e9b005b79 100644 --- a/booster/tree/xgboost_row_treemaker.hpp +++ b/booster/tree/xgboost_row_treemaker.hpp @@ -11,6 +11,7 @@ #include "xgboost_tree_model.h" #include "../../utils/xgboost_omp.h" #include "../../utils/xgboost_random.h" +#include "../../utils/xgboost_fmap.h" #include "xgboost_base_treemaker.hpp" namespace xgboost{ @@ -23,10 +24,11 @@ namespace xgboost{ const std::vector &grad, const std::vector &hess, const FMatrix &smat, - const std::vector &root_index ) + const std::vector &root_index, + const utils::FeatConstrain &constrain ) : BaseTreeMaker( tree, param ), grad(grad), hess(hess), - smat(smat), root_index(root_index) { + smat(smat), root_index(root_index), constrain(constrain) { utils::Assert( grad.size() == hess.size(), "booster:invalid input" ); utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" ); utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" ); @@ -254,14 +256,18 @@ namespace xgboost{ for( bst_uint i = begin; i < end; ++i ){ const bst_uint ridx = row_index_set[i]; for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){ - builder.AddBudget( it.findex() ); + const bst_uint findex = it.findex(); + if( constrain.NotBanned( findex ) ) builder.AddBudget( findex ); } } builder.InitStorage(); for( bst_uint i = begin; i < end; ++i ){ const bst_uint ridx = row_index_set[i]; for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){ - builder.PushElem( it.findex(), FMatrixS::REntry( ridx, it.fvalue() ) ); + const bst_uint findex = it.findex(); + if( constrain.NotBanned( findex ) ) { + builder.PushElem( findex, FMatrixS::REntry( ridx, it.fvalue() ) ); + } } } // --- end of building column major matrix --- @@ -373,6 +379,7 @@ namespace xgboost{ const std::vector &hess; const FMatrix &smat; const std::vector &root_index; + const utils::FeatConstrain &constrain; }; }; }; diff --git a/booster/tree/xgboost_tree.hpp b/booster/tree/xgboost_tree.hpp index bb2d89acb..7c4f740cc 100644 --- a/booster/tree/xgboost_tree.hpp +++ b/booster/tree/xgboost_tree.hpp @@ -21,7 +21,7 @@ namespace xgboost{ } }; }; - +#include "../../utils/xgboost_fmap.h" #include "xgboost_svdf_tree.hpp" #include "xgboost_col_treemaker.hpp" #include "xgboost_row_treemaker.hpp" @@ -57,6 +57,7 @@ namespace xgboost{ } } param.SetParam( name, val ); + constrain.SetParam( name, val ); tree.param.SetParam( name, val ); } virtual void LoadModel( utils::IStream &fi ){ @@ -90,17 +91,18 @@ namespace xgboost{ int num_pruned; switch( tree_maker ){ case 0: { + utils::Assert( !constrain.HasConstrain(), "tree maker 0 does not support constrain" ); RTreeUpdater updater( param, tree, grad, hess, smat, root_index ); tree.param.max_depth = updater.do_boost( num_pruned ); break; } case 1:{ - ColTreeMaker maker( tree, param, grad, hess, smat, root_index ); + ColTreeMaker maker( tree, param, grad, hess, smat, root_index, constrain ); maker.Make( tree.param.max_depth, num_pruned ); break; } case 2:{ - RowTreeMaker maker( tree, param, grad, hess, smat, root_index ); + RowTreeMaker maker( tree, param, grad, hess, smat, root_index, constrain ); maker.Make( tree.param.max_depth, num_pruned ); break; } @@ -178,7 +180,7 @@ namespace xgboost{ } this->DropTmp( fmat.GetRow(i), e ); } - RowTreeMaker maker( tree, param, grad, hess, fmat, root_index ); + RowTreeMaker maker( tree, param, grad, hess, fmat, root_index, constrain ); maker.Collapse( valid_index, nid ); if( !silent ){ printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth ); @@ -198,7 +200,7 @@ namespace xgboost{ this->DropTmp( fmat.GetRow(i), e ); if( pid == nid ) valid_index.push_back( static_cast(i) ); } - RowTreeMaker maker( tree, param, grad, hess, fmat, root_index ); + RowTreeMaker maker( tree, param, grad, hess, fmat, root_index, constrain ); bool success = maker.Expand( valid_index, nid ); if( !silent ){ printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() ); @@ -215,7 +217,9 @@ namespace xgboost{ int tree_maker; // interaction int interact_type; - int interact_node; + int interact_node; + // feature constrain + utils::FeatConstrain constrain; private: struct ThreadEntry{ std::vector feat; diff --git a/demo/test/runexp.sh b/demo/test/runexp.sh index 5a5988122..bb5fa2581 100755 --- a/demo/test/runexp.sh +++ b/demo/test/runexp.sh @@ -7,46 +7,20 @@ python mknfold.py agaricus.txt 1 ../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3 ../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt -# training -../../xgboost mushroom.conf num_round=2 model_out=m1.model bst:max_depth=1 +# constrain +../../xgboost mushroom.conf num_round=1 model_out=ban.model bst:max_depth=3 bst:fban=22-31 -# this is what dump will looklike with feature map -../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m1.txt - -# interaction -../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1 -../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 interact:action=remove -../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m4.model interact:booster_index=0 bst:interact:expand=2 - -# this is what dump will looklike with feature map -../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m2.txt -../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt -../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt -../../xgboost mushroom.conf task=dump model_in=m4.model fmap=featmap.txt name_dump=dump.m4.txt - - -echo "========m1=======" -cat dump.m1.txt - -echo "========m2========" -cat dump.m2.txt - -echo "========m3========" -cat dump.m3.txt - -# statistics are print into stderr -../../xgboost mushroom.conf model_in=m3.model task=eval 2>eval.m3.txt -cat eval.m3.txt - -echo "========m4========" -cat dump.m4.txt - -../../xgboost mushroom.conf model_in=m4.model task=eval 2>eval.m4.txt -cat eval.m4.txt +# constrain +../../xgboost mushroom.conf num_round=1 model_out=pass.model bst:max_depth=3 bst:fdefault=-1 bst:fpass=22-31 +../../xgboost mushroom.conf task=dump model_in=ban.model fmap=featmap.txt name_dump=dump.ban.txt +../../xgboost mushroom.conf task=dump model_in=pass.model fmap=featmap.txt name_dump=dump.pass.txt echo "========full=======" cat dump.full.txt -../../xgboost mushroom.conf model_in=full.model task=eval 2>eval.full.txt -cat eval.full.txt \ No newline at end of file +echo "========ban=======" +cat dump.ban.txt + +echo "========pass=======" +cat dump.pass.txt diff --git a/utils/xgboost_fmap.h b/utils/xgboost_fmap.h index f288d027a..4ab7e3909 100644 --- a/utils/xgboost_fmap.h +++ b/utils/xgboost_fmap.h @@ -68,5 +68,56 @@ namespace xgboost{ std::vector types_; }; }; // namespace utils + + namespace utils{ + /*! \brief feature constraint, allow or disallow some feature during training */ + class FeatConstrain{ + public: + FeatConstrain( void ){ + default_state_ = +1; + } + /*!\brief set parameters */ + inline void SetParam( const char *name, const char *val ){ + int a, b; + if( !strcmp( name, "fban") ){ + this->ParseRange( val, a, b ); + this->SetRange( a, b, -1 ); + } + if( !strcmp( name, "fpass") ){ + this->ParseRange( val, a, b ); + this->SetRange( a, b, +1 ); + } + if( !strcmp( name, "fdefault") ){ + default_state_ = atoi( val ); + } + } + /*! \brief whether constrain is specified */ + inline bool HasConstrain( void ) const { + return state_.size() != 0 && default_state_ == 1; + } + /*! \brief whether a feature index is banned or not */ + inline bool NotBanned( unsigned index ) const{ + int rt = index < state_.size() ? state_[index] : default_state_; + if( rt == 0 ) rt = default_state_; + return rt == 1; + } + private: + inline void SetRange( int a, int b, int st ){ + if( b > (int)state_.size() ) state_.resize( b, 0 ); + for( int i = a; i < b; ++ i ){ + state_[i] = st; + } + } + inline void ParseRange( const char *val, int &a, int &b ){ + if( sscanf( val, "%d-%d", &a, &b ) == 2 ) return; + utils::Assert( sscanf( val, "%d", &a ) == 1 ); + b = a + 1; + } + /*! \brief default state */ + int default_state_; + /*! \brief whether the state here is, +1:pass, -1: ban, 0:default */ + std::vector state_; + }; + }; // namespace utils }; // namespace xgboost #endif // XGBOOST_FMAP_H