add add and remove
This commit is contained in:
parent
ef5a389ecf
commit
d960550933
@ -83,6 +83,7 @@ namespace xgboost{
|
||||
// set all the rest expanding nodes to leaf
|
||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||
const int nid = qexpand[i];
|
||||
|
||||
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
|
||||
tree.stat( nid ).loss_chg = 0.0f;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
@ -94,6 +95,15 @@ namespace xgboost{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// collapse specific node
|
||||
inline void Collapse( const std::vector<bst_uint> &valid_index, int nid ){
|
||||
if( valid_index.size() == 0 ) return;
|
||||
this->InitDataExpand( valid_index, nid );
|
||||
this->InitNewNode( this->qexpand );
|
||||
tree.stat( nid ).loss_chg = 0.0f;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
tree.CollapseToLeaf( nid, snode[nid].weight * param.learning_rate );
|
||||
}
|
||||
private:
|
||||
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
||||
inline void InitNewNode( const std::vector<int> &qexpand ){
|
||||
|
||||
@ -79,7 +79,7 @@ namespace xgboost{
|
||||
if( interact_type != 0 ){
|
||||
switch( interact_type ){
|
||||
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
|
||||
case 2:
|
||||
case 2: this->CollapseNode( grad, hess, smat, root_index, interact_node ); return;
|
||||
default: utils::Error("unknown interact type");
|
||||
}
|
||||
}
|
||||
@ -108,7 +108,7 @@ namespace xgboost{
|
||||
}
|
||||
if( !silent ){
|
||||
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.MaxDepth() );
|
||||
}
|
||||
}
|
||||
virtual float Predict( const FMatrix &fmat, bst_uint ridx, unsigned gid = 0 ){
|
||||
@ -158,6 +158,32 @@ namespace xgboost{
|
||||
tree.DumpModel( fo, fmap, with_stats );
|
||||
}
|
||||
private:
|
||||
inline void CollapseNode( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
int nid ){
|
||||
std::vector<bst_uint> valid_index;
|
||||
for( size_t i = 0; i < grad.size(); i ++ ){
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( fmat.GetRow(i), e );
|
||||
int pid = root_index.size() == 0 ? 0 : (int)root_index[i];
|
||||
// tranverse tree
|
||||
while( !tree[ pid ].is_leaf() ){
|
||||
unsigned split_index = tree[ pid ].split_index();
|
||||
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
|
||||
if( pid == nid ){
|
||||
valid_index.push_back( static_cast<bst_uint>(i) ); break;
|
||||
}
|
||||
}
|
||||
this->DropTmp( fmat.GetRow(i), e );
|
||||
}
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
maker.Collapse( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
||||
}
|
||||
}
|
||||
inline void ExpandNode( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &fmat,
|
||||
@ -175,7 +201,7 @@ namespace xgboost{
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
bool success = maker.Expand( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
||||
}
|
||||
}
|
||||
private:
|
||||
|
||||
@ -197,6 +197,21 @@ namespace xgboost{
|
||||
this->DeleteNode( nodes[ rid ].cright() );
|
||||
nodes[ rid ].set_leaf( value );
|
||||
}
|
||||
/*!
|
||||
* \brief collapse a non leaf node to a leaf node, delete its children
|
||||
* \param rid node id of the node
|
||||
* \param new leaf value
|
||||
*/
|
||||
inline void CollapseToLeaf( int rid, float value ){
|
||||
if( nodes[rid].is_leaf() ) return;
|
||||
if( !nodes[ nodes[rid].cleft() ].is_leaf() ){
|
||||
CollapseToLeaf( nodes[rid].cleft(), 0.0f );
|
||||
}
|
||||
if( !nodes[ nodes[rid].cright() ].is_leaf() ){
|
||||
CollapseToLeaf( nodes[rid].cright(), 0.0f );
|
||||
}
|
||||
this->ChangeToLeaf( rid, value );
|
||||
}
|
||||
public:
|
||||
/*! \brief model parameter */
|
||||
Param param;
|
||||
@ -287,6 +302,25 @@ namespace xgboost{
|
||||
}
|
||||
return depth;
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
* \param nid node id
|
||||
*/
|
||||
inline int MaxDepth( int nid ) const{
|
||||
if( nodes[nid].is_leaf() ) return 0;
|
||||
return std::max( MaxDepth( nodes[nid].cleft() )+1,
|
||||
MaxDepth( nodes[nid].cright() )+1 );
|
||||
}
|
||||
/*!
|
||||
* \brief get maximum depth
|
||||
*/
|
||||
inline int MaxDepth( void ){
|
||||
int maxd = 0;
|
||||
for( int i = 0; i < param.num_roots; ++ i ){
|
||||
maxd = std::max( maxd, MaxDepth( i ) );
|
||||
}
|
||||
return maxd;
|
||||
}
|
||||
/*! \brief number of extra nodes besides the root */
|
||||
inline int num_extra_nodes( void ) const {
|
||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||
|
||||
@ -16,10 +16,14 @@ python mknfold.py agaricus.txt 1
|
||||
# interaction
|
||||
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
|
||||
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2
|
||||
../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m3v.model interact:booster_index=0 bst:interact:remove=2
|
||||
../../xgboost mushroom.conf task=interact model_in=m3v.model model_out=m3p.model interact:booster_index=0 bst:interact:expand=2
|
||||
|
||||
# this is what the dump will look like with a feature map
|
||||
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
|
||||
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
|
||||
../../xgboost mushroom.conf task=dump model_in=m3v.model fmap=featmap.txt name_dump=dump.m3v.txt
|
||||
../../xgboost mushroom.conf task=dump model_in=m3p.model fmap=featmap.txt name_dump=dump.m3p.txt
|
||||
|
||||
echo "========m1======="
|
||||
cat dump.m1.txt
|
||||
@ -30,5 +34,11 @@ cat dump.m2.txt
|
||||
echo "========m3========"
|
||||
cat dump.m3.txt
|
||||
|
||||
echo "========m3v========"
|
||||
cat dump.m3v.txt
|
||||
|
||||
echo "========m3p========"
|
||||
cat dump.m3p.txt
|
||||
|
||||
echo "========full======="
|
||||
cat dump.full.txt
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user