add feature constraint

This commit is contained in:
tqchen 2014-03-19 10:47:56 -07:00
parent 6a91438634
commit d56394d2ef
5 changed files with 90 additions and 51 deletions

View File

@ -11,6 +11,7 @@
#include "xgboost_tree_model.h"
#include "../../utils/xgboost_omp.h"
#include "../../utils/xgboost_random.h"
#include "../../utils/xgboost_fmap.h"
#include "xgboost_base_treemaker.hpp"
namespace xgboost{
@ -23,10 +24,11 @@ namespace xgboost{
const std::vector<float> &grad,
const std::vector<float> &hess,
const FMatrix &smat,
const std::vector<unsigned> &root_index )
const std::vector<unsigned> &root_index,
const utils::FeatConstrain &constrain )
: BaseTreeMaker( tree, param ),
grad(grad), hess(hess),
smat(smat), root_index(root_index) {
smat(smat), root_index(root_index), constrain(constrain) {
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
@ -282,10 +284,10 @@ namespace xgboost{
{// initialize feature index
int ncol = static_cast<int>( smat.NumCol() );
for( int i = 0; i < ncol; i ++ ){
if( smat.GetSortedCol(i).Next() ){
if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){
feat_index.push_back( i );
}
}
}
random::Shuffle( feat_index );
}
{// setup temp space for each thread
@ -326,6 +328,7 @@ namespace xgboost{
const std::vector<float> &hess;
const FMatrix &smat;
const std::vector<unsigned> &root_index;
const utils::FeatConstrain &constrain;
};
};
};

View File

@ -11,6 +11,7 @@
#include "xgboost_tree_model.h"
#include "../../utils/xgboost_omp.h"
#include "../../utils/xgboost_random.h"
#include "../../utils/xgboost_fmap.h"
#include "xgboost_base_treemaker.hpp"
namespace xgboost{
@ -23,10 +24,11 @@ namespace xgboost{
const std::vector<float> &grad,
const std::vector<float> &hess,
const FMatrix &smat,
const std::vector<unsigned> &root_index )
const std::vector<unsigned> &root_index,
const utils::FeatConstrain &constrain )
: BaseTreeMaker( tree, param ),
grad(grad), hess(hess),
smat(smat), root_index(root_index) {
smat(smat), root_index(root_index), constrain(constrain) {
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
@ -254,14 +256,18 @@ namespace xgboost{
for( bst_uint i = begin; i < end; ++i ){
const bst_uint ridx = row_index_set[i];
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
builder.AddBudget( it.findex() );
const bst_uint findex = it.findex();
if( constrain.NotBanned( findex ) ) builder.AddBudget( findex );
}
}
builder.InitStorage();
for( bst_uint i = begin; i < end; ++i ){
const bst_uint ridx = row_index_set[i];
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
builder.PushElem( it.findex(), FMatrixS::REntry( ridx, it.fvalue() ) );
const bst_uint findex = it.findex();
if( constrain.NotBanned( findex ) ) {
builder.PushElem( findex, FMatrixS::REntry( ridx, it.fvalue() ) );
}
}
}
// --- end of building column major matrix ---
@ -373,6 +379,7 @@ namespace xgboost{
const std::vector<float> &hess;
const FMatrix &smat;
const std::vector<unsigned> &root_index;
const utils::FeatConstrain &constrain;
};
};
};

View File

@ -21,7 +21,7 @@ namespace xgboost{
}
};
};
#include "../../utils/xgboost_fmap.h"
#include "xgboost_svdf_tree.hpp"
#include "xgboost_col_treemaker.hpp"
#include "xgboost_row_treemaker.hpp"
@ -57,6 +57,7 @@ namespace xgboost{
}
}
param.SetParam( name, val );
constrain.SetParam( name, val );
tree.param.SetParam( name, val );
}
virtual void LoadModel( utils::IStream &fi ){
@ -90,17 +91,18 @@ namespace xgboost{
int num_pruned;
switch( tree_maker ){
case 0: {
utils::Assert( !constrain.HasConstrain(), "tree maker 0 does not support constrain" );
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
tree.param.max_depth = updater.do_boost( num_pruned );
break;
}
case 1:{
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
case 2:{
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
maker.Make( tree.param.max_depth, num_pruned );
break;
}
@ -178,7 +180,7 @@ namespace xgboost{
}
this->DropTmp( fmat.GetRow(i), e );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
maker.Collapse( valid_index, nid );
if( !silent ){
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
@ -198,7 +200,7 @@ namespace xgboost{
this->DropTmp( fmat.GetRow(i), e );
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
bool success = maker.Expand( valid_index, nid );
if( !silent ){
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
@ -215,7 +217,9 @@ namespace xgboost{
int tree_maker;
// interaction
int interact_type;
int interact_node;
int interact_node;
// feature constrain
utils::FeatConstrain constrain;
private:
struct ThreadEntry{
std::vector<float> feat;

View File

@ -7,46 +7,20 @@ python mknfold.py agaricus.txt 1
../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3
../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt
# training
../../xgboost mushroom.conf num_round=2 model_out=m1.model bst:max_depth=1
# constrain
../../xgboost mushroom.conf num_round=1 model_out=ban.model bst:max_depth=3 bst:fban=22-31
# this is what dump will looklike with feature map
../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m1.txt
# interaction
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 interact:action=remove
../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m4.model interact:booster_index=0 bst:interact:expand=2
# this is what dump will looklike with feature map
../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m2.txt
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
../../xgboost mushroom.conf task=dump model_in=m4.model fmap=featmap.txt name_dump=dump.m4.txt
echo "========m1======="
cat dump.m1.txt
echo "========m2========"
cat dump.m2.txt
echo "========m3========"
cat dump.m3.txt
# statistics are print into stderr
../../xgboost mushroom.conf model_in=m3.model task=eval 2>eval.m3.txt
cat eval.m3.txt
echo "========m4========"
cat dump.m4.txt
../../xgboost mushroom.conf model_in=m4.model task=eval 2>eval.m4.txt
cat eval.m4.txt
# constrain
../../xgboost mushroom.conf num_round=1 model_out=pass.model bst:max_depth=3 bst:fdefault=-1 bst:fpass=22-31
../../xgboost mushroom.conf task=dump model_in=ban.model fmap=featmap.txt name_dump=dump.ban.txt
../../xgboost mushroom.conf task=dump model_in=pass.model fmap=featmap.txt name_dump=dump.pass.txt
echo "========full======="
cat dump.full.txt
../../xgboost mushroom.conf model_in=full.model task=eval 2>eval.full.txt
cat eval.full.txt
echo "========ban======="
cat dump.ban.txt
echo "========pass======="
cat dump.pass.txt

View File

@ -68,5 +68,56 @@ namespace xgboost{
std::vector<Type> types_;
};
}; // namespace utils
namespace utils{
/*! \brief feature constraint, allow or disallow some feature during training */
class FeatConstrain{
public:
FeatConstrain( void ){
default_state_ = +1;
}
/*!\brief set parameters */
inline void SetParam( const char *name, const char *val ){
int a, b;
if( !strcmp( name, "fban") ){
this->ParseRange( val, a, b );
this->SetRange( a, b, -1 );
}
if( !strcmp( name, "fpass") ){
this->ParseRange( val, a, b );
this->SetRange( a, b, +1 );
}
if( !strcmp( name, "fdefault") ){
default_state_ = atoi( val );
}
}
/*! \brief whether constrain is specified */
inline bool HasConstrain( void ) const {
return state_.size() != 0 && default_state_ == 1;
}
/*! \brief whether a feature index is banned or not */
inline bool NotBanned( unsigned index ) const{
int rt = index < state_.size() ? state_[index] : default_state_;
if( rt == 0 ) rt = default_state_;
return rt == 1;
}
private:
inline void SetRange( int a, int b, int st ){
if( b > (int)state_.size() ) state_.resize( b, 0 );
for( int i = a; i < b; ++ i ){
state_[i] = st;
}
}
inline void ParseRange( const char *val, int &a, int &b ){
if( sscanf( val, "%d-%d", &a, &b ) == 2 ) return;
utils::Assert( sscanf( val, "%d", &a ) == 1 );
b = a + 1;
}
/*! \brief default state */
int default_state_;
/*! \brief whether the state here is, +1:pass, -1: ban, 0:default */
std::vector<int> state_;
};
}; // namespace utils
}; // namespace xgboost
#endif // XGBOOST_FMAP_H