add feature constraint
This commit is contained in:
parent
6a91438634
commit
d56394d2ef
@ -11,6 +11,7 @@
|
|||||||
#include "xgboost_tree_model.h"
|
#include "xgboost_tree_model.h"
|
||||||
#include "../../utils/xgboost_omp.h"
|
#include "../../utils/xgboost_omp.h"
|
||||||
#include "../../utils/xgboost_random.h"
|
#include "../../utils/xgboost_random.h"
|
||||||
|
#include "../../utils/xgboost_fmap.h"
|
||||||
#include "xgboost_base_treemaker.hpp"
|
#include "xgboost_base_treemaker.hpp"
|
||||||
|
|
||||||
namespace xgboost{
|
namespace xgboost{
|
||||||
@ -23,10 +24,11 @@ namespace xgboost{
|
|||||||
const std::vector<float> &grad,
|
const std::vector<float> &grad,
|
||||||
const std::vector<float> &hess,
|
const std::vector<float> &hess,
|
||||||
const FMatrix &smat,
|
const FMatrix &smat,
|
||||||
const std::vector<unsigned> &root_index )
|
const std::vector<unsigned> &root_index,
|
||||||
|
const utils::FeatConstrain &constrain )
|
||||||
: BaseTreeMaker( tree, param ),
|
: BaseTreeMaker( tree, param ),
|
||||||
grad(grad), hess(hess),
|
grad(grad), hess(hess),
|
||||||
smat(smat), root_index(root_index) {
|
smat(smat), root_index(root_index), constrain(constrain) {
|
||||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||||
@ -282,10 +284,10 @@ namespace xgboost{
|
|||||||
{// initialize feature index
|
{// initialize feature index
|
||||||
int ncol = static_cast<int>( smat.NumCol() );
|
int ncol = static_cast<int>( smat.NumCol() );
|
||||||
for( int i = 0; i < ncol; i ++ ){
|
for( int i = 0; i < ncol; i ++ ){
|
||||||
if( smat.GetSortedCol(i).Next() ){
|
if( smat.GetSortedCol(i).Next() && constrain.NotBanned(i) ){
|
||||||
feat_index.push_back( i );
|
feat_index.push_back( i );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
random::Shuffle( feat_index );
|
random::Shuffle( feat_index );
|
||||||
}
|
}
|
||||||
{// setup temp space for each thread
|
{// setup temp space for each thread
|
||||||
@ -326,6 +328,7 @@ namespace xgboost{
|
|||||||
const std::vector<float> &hess;
|
const std::vector<float> &hess;
|
||||||
const FMatrix &smat;
|
const FMatrix &smat;
|
||||||
const std::vector<unsigned> &root_index;
|
const std::vector<unsigned> &root_index;
|
||||||
|
const utils::FeatConstrain &constrain;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@ -11,6 +11,7 @@
|
|||||||
#include "xgboost_tree_model.h"
|
#include "xgboost_tree_model.h"
|
||||||
#include "../../utils/xgboost_omp.h"
|
#include "../../utils/xgboost_omp.h"
|
||||||
#include "../../utils/xgboost_random.h"
|
#include "../../utils/xgboost_random.h"
|
||||||
|
#include "../../utils/xgboost_fmap.h"
|
||||||
#include "xgboost_base_treemaker.hpp"
|
#include "xgboost_base_treemaker.hpp"
|
||||||
|
|
||||||
namespace xgboost{
|
namespace xgboost{
|
||||||
@ -23,10 +24,11 @@ namespace xgboost{
|
|||||||
const std::vector<float> &grad,
|
const std::vector<float> &grad,
|
||||||
const std::vector<float> &hess,
|
const std::vector<float> &hess,
|
||||||
const FMatrix &smat,
|
const FMatrix &smat,
|
||||||
const std::vector<unsigned> &root_index )
|
const std::vector<unsigned> &root_index,
|
||||||
|
const utils::FeatConstrain &constrain )
|
||||||
: BaseTreeMaker( tree, param ),
|
: BaseTreeMaker( tree, param ),
|
||||||
grad(grad), hess(hess),
|
grad(grad), hess(hess),
|
||||||
smat(smat), root_index(root_index) {
|
smat(smat), root_index(root_index), constrain(constrain) {
|
||||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||||
@ -254,14 +256,18 @@ namespace xgboost{
|
|||||||
for( bst_uint i = begin; i < end; ++i ){
|
for( bst_uint i = begin; i < end; ++i ){
|
||||||
const bst_uint ridx = row_index_set[i];
|
const bst_uint ridx = row_index_set[i];
|
||||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||||
builder.AddBudget( it.findex() );
|
const bst_uint findex = it.findex();
|
||||||
|
if( constrain.NotBanned( findex ) ) builder.AddBudget( findex );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
builder.InitStorage();
|
builder.InitStorage();
|
||||||
for( bst_uint i = begin; i < end; ++i ){
|
for( bst_uint i = begin; i < end; ++i ){
|
||||||
const bst_uint ridx = row_index_set[i];
|
const bst_uint ridx = row_index_set[i];
|
||||||
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
for( typename FMatrix::RowIter it = smat.GetRow(ridx,gid); it.Next(); ){
|
||||||
builder.PushElem( it.findex(), FMatrixS::REntry( ridx, it.fvalue() ) );
|
const bst_uint findex = it.findex();
|
||||||
|
if( constrain.NotBanned( findex ) ) {
|
||||||
|
builder.PushElem( findex, FMatrixS::REntry( ridx, it.fvalue() ) );
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// --- end of building column major matrix ---
|
// --- end of building column major matrix ---
|
||||||
@ -373,6 +379,7 @@ namespace xgboost{
|
|||||||
const std::vector<float> &hess;
|
const std::vector<float> &hess;
|
||||||
const FMatrix &smat;
|
const FMatrix &smat;
|
||||||
const std::vector<unsigned> &root_index;
|
const std::vector<unsigned> &root_index;
|
||||||
|
const utils::FeatConstrain &constrain;
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@ -21,7 +21,7 @@ namespace xgboost{
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
#include "../../utils/xgboost_fmap.h"
|
||||||
#include "xgboost_svdf_tree.hpp"
|
#include "xgboost_svdf_tree.hpp"
|
||||||
#include "xgboost_col_treemaker.hpp"
|
#include "xgboost_col_treemaker.hpp"
|
||||||
#include "xgboost_row_treemaker.hpp"
|
#include "xgboost_row_treemaker.hpp"
|
||||||
@ -57,6 +57,7 @@ namespace xgboost{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
param.SetParam( name, val );
|
param.SetParam( name, val );
|
||||||
|
constrain.SetParam( name, val );
|
||||||
tree.param.SetParam( name, val );
|
tree.param.SetParam( name, val );
|
||||||
}
|
}
|
||||||
virtual void LoadModel( utils::IStream &fi ){
|
virtual void LoadModel( utils::IStream &fi ){
|
||||||
@ -90,17 +91,18 @@ namespace xgboost{
|
|||||||
int num_pruned;
|
int num_pruned;
|
||||||
switch( tree_maker ){
|
switch( tree_maker ){
|
||||||
case 0: {
|
case 0: {
|
||||||
|
utils::Assert( !constrain.HasConstrain(), "tree maker 0 does not support constrain" );
|
||||||
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
|
RTreeUpdater<FMatrix> updater( param, tree, grad, hess, smat, root_index );
|
||||||
tree.param.max_depth = updater.do_boost( num_pruned );
|
tree.param.max_depth = updater.do_boost( num_pruned );
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case 1:{
|
case 1:{
|
||||||
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
|
ColTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
|
||||||
maker.Make( tree.param.max_depth, num_pruned );
|
maker.Make( tree.param.max_depth, num_pruned );
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case 2:{
|
case 2:{
|
||||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index );
|
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, smat, root_index, constrain );
|
||||||
maker.Make( tree.param.max_depth, num_pruned );
|
maker.Make( tree.param.max_depth, num_pruned );
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -178,7 +180,7 @@ namespace xgboost{
|
|||||||
}
|
}
|
||||||
this->DropTmp( fmat.GetRow(i), e );
|
this->DropTmp( fmat.GetRow(i), e );
|
||||||
}
|
}
|
||||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
|
||||||
maker.Collapse( valid_index, nid );
|
maker.Collapse( valid_index, nid );
|
||||||
if( !silent ){
|
if( !silent ){
|
||||||
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
||||||
@ -198,7 +200,7 @@ namespace xgboost{
|
|||||||
this->DropTmp( fmat.GetRow(i), e );
|
this->DropTmp( fmat.GetRow(i), e );
|
||||||
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
|
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
|
||||||
}
|
}
|
||||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index, constrain );
|
||||||
bool success = maker.Expand( valid_index, nid );
|
bool success = maker.Expand( valid_index, nid );
|
||||||
if( !silent ){
|
if( !silent ){
|
||||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
||||||
@ -215,7 +217,9 @@ namespace xgboost{
|
|||||||
int tree_maker;
|
int tree_maker;
|
||||||
// interaction
|
// interaction
|
||||||
int interact_type;
|
int interact_type;
|
||||||
int interact_node;
|
int interact_node;
|
||||||
|
// feature constrain
|
||||||
|
utils::FeatConstrain constrain;
|
||||||
private:
|
private:
|
||||||
struct ThreadEntry{
|
struct ThreadEntry{
|
||||||
std::vector<float> feat;
|
std::vector<float> feat;
|
||||||
|
|||||||
@ -7,46 +7,20 @@ python mknfold.py agaricus.txt 1
|
|||||||
../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3
|
../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3
|
||||||
../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt
|
../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt
|
||||||
|
|
||||||
# training
|
# constraint: train with the feature range 22-31 banned
|
||||||
../../xgboost mushroom.conf num_round=2 model_out=m1.model bst:max_depth=1
|
../../xgboost mushroom.conf num_round=1 model_out=ban.model bst:max_depth=3 bst:fban=22-31
|
||||||
|
|
||||||
# this is what the dump will look like with a feature map
|
# constraint: ban all features by default, allow only the feature range 22-31
|
||||||
../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m1.txt
|
../../xgboost mushroom.conf num_round=1 model_out=pass.model bst:max_depth=3 bst:fdefault=-1 bst:fpass=22-31
|
||||||
|
|
||||||
# interaction
|
|
||||||
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
|
|
||||||
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 interact:action=remove
|
|
||||||
../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m4.model interact:booster_index=0 bst:interact:expand=2
|
|
||||||
|
|
||||||
# this is what the dump will look like with a feature map
|
|
||||||
../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m2.txt
|
|
||||||
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
|
|
||||||
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
|
|
||||||
../../xgboost mushroom.conf task=dump model_in=m4.model fmap=featmap.txt name_dump=dump.m4.txt
|
|
||||||
|
|
||||||
|
|
||||||
echo "========m1======="
|
|
||||||
cat dump.m1.txt
|
|
||||||
|
|
||||||
echo "========m2========"
|
|
||||||
cat dump.m2.txt
|
|
||||||
|
|
||||||
echo "========m3========"
|
|
||||||
cat dump.m3.txt
|
|
||||||
|
|
||||||
# statistics are print into stderr
|
|
||||||
../../xgboost mushroom.conf model_in=m3.model task=eval 2>eval.m3.txt
|
|
||||||
cat eval.m3.txt
|
|
||||||
|
|
||||||
echo "========m4========"
|
|
||||||
cat dump.m4.txt
|
|
||||||
|
|
||||||
../../xgboost mushroom.conf model_in=m4.model task=eval 2>eval.m4.txt
|
|
||||||
cat eval.m4.txt
|
|
||||||
|
|
||||||
|
../../xgboost mushroom.conf task=dump model_in=ban.model fmap=featmap.txt name_dump=dump.ban.txt
|
||||||
|
../../xgboost mushroom.conf task=dump model_in=pass.model fmap=featmap.txt name_dump=dump.pass.txt
|
||||||
|
|
||||||
echo "========full======="
|
echo "========full======="
|
||||||
cat dump.full.txt
|
cat dump.full.txt
|
||||||
|
|
||||||
../../xgboost mushroom.conf model_in=full.model task=eval 2>eval.full.txt
|
echo "========ban======="
|
||||||
cat eval.full.txt
|
cat dump.ban.txt
|
||||||
|
|
||||||
|
echo "========pass======="
|
||||||
|
cat dump.pass.txt
|
||||||
|
|||||||
@ -68,5 +68,56 @@ namespace xgboost{
|
|||||||
std::vector<Type> types_;
|
std::vector<Type> types_;
|
||||||
};
|
};
|
||||||
}; // namespace utils
|
}; // namespace utils
|
||||||
|
|
||||||
|
namespace utils{
|
||||||
|
/*! \brief feature constraint, allow or disallow some feature during training */
|
||||||
|
class FeatConstrain{
|
||||||
|
public:
|
||||||
|
FeatConstrain( void ){
|
||||||
|
default_state_ = +1;
|
||||||
|
}
|
||||||
|
/*!\brief set parameters */
|
||||||
|
inline void SetParam( const char *name, const char *val ){
|
||||||
|
int a, b;
|
||||||
|
if( !strcmp( name, "fban") ){
|
||||||
|
this->ParseRange( val, a, b );
|
||||||
|
this->SetRange( a, b, -1 );
|
||||||
|
}
|
||||||
|
if( !strcmp( name, "fpass") ){
|
||||||
|
this->ParseRange( val, a, b );
|
||||||
|
this->SetRange( a, b, +1 );
|
||||||
|
}
|
||||||
|
if( !strcmp( name, "fdefault") ){
|
||||||
|
default_state_ = atoi( val );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
/*! \brief whether constrain is specified */
|
||||||
|
inline bool HasConstrain( void ) const {
|
||||||
|
return state_.size() != 0 && default_state_ == 1;
|
||||||
|
}
|
||||||
|
/*! \brief whether a feature index is banned or not */
|
||||||
|
inline bool NotBanned( unsigned index ) const{
|
||||||
|
int rt = index < state_.size() ? state_[index] : default_state_;
|
||||||
|
if( rt == 0 ) rt = default_state_;
|
||||||
|
return rt == 1;
|
||||||
|
}
|
||||||
|
private:
|
||||||
|
inline void SetRange( int a, int b, int st ){
|
||||||
|
if( b > (int)state_.size() ) state_.resize( b, 0 );
|
||||||
|
for( int i = a; i < b; ++ i ){
|
||||||
|
state_[i] = st;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inline void ParseRange( const char *val, int &a, int &b ){
|
||||||
|
if( sscanf( val, "%d-%d", &a, &b ) == 2 ) return;
|
||||||
|
utils::Assert( sscanf( val, "%d", &a ) == 1 );
|
||||||
|
b = a + 1;
|
||||||
|
}
|
||||||
|
/*! \brief default state */
|
||||||
|
int default_state_;
|
||||||
|
/*! \brief whether the state here is, +1:pass, -1: ban, 0:default */
|
||||||
|
std::vector<int> state_;
|
||||||
|
};
|
||||||
|
}; // namespace utils
|
||||||
}; // namespace xgboost
|
}; // namespace xgboost
|
||||||
#endif // XGBOOST_FMAP_H
|
#endif // XGBOOST_FMAP_H
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user