add add and remove
This commit is contained in:
parent
f62c5dc3c1
commit
70f3f31206
@ -83,6 +83,7 @@ namespace xgboost{
|
|||||||
// set all the rest expanding nodes to leaf
|
// set all the rest expanding nodes to leaf
|
||||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||||
const int nid = qexpand[i];
|
const int nid = qexpand[i];
|
||||||
|
|
||||||
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
|
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
|
||||||
tree.stat( nid ).loss_chg = 0.0f;
|
tree.stat( nid ).loss_chg = 0.0f;
|
||||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||||
@ -94,6 +95,15 @@ namespace xgboost{
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// collapse specific node
|
||||||
|
inline void Collapse( const std::vector<bst_uint> &valid_index, int nid ){
|
||||||
|
if( valid_index.size() == 0 ) return;
|
||||||
|
this->InitDataExpand( valid_index, nid );
|
||||||
|
this->InitNewNode( this->qexpand );
|
||||||
|
tree.stat( nid ).loss_chg = 0.0f;
|
||||||
|
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||||
|
tree.CollapseToLeaf( nid, snode[nid].weight * param.learning_rate );
|
||||||
|
}
|
||||||
private:
|
private:
|
||||||
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
||||||
inline void InitNewNode( const std::vector<int> &qexpand ){
|
inline void InitNewNode( const std::vector<int> &qexpand ){
|
||||||
|
|||||||
@ -79,7 +79,7 @@ namespace xgboost{
|
|||||||
if( interact_type != 0 ){
|
if( interact_type != 0 ){
|
||||||
switch( interact_type ){
|
switch( interact_type ){
|
||||||
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
|
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
|
||||||
case 2:
|
case 2: this->CollapseNode( grad, hess, smat, root_index, interact_node ); return;
|
||||||
default: utils::Error("unknown interact type");
|
default: utils::Error("unknown interact type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -108,7 +108,7 @@ namespace xgboost{
|
|||||||
}
|
}
|
||||||
if( !silent ){
|
if( !silent ){
|
||||||
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
|
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.MaxDepth() );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
virtual float Predict( const FMatrix &fmat, bst_uint ridx, unsigned gid = 0 ){
|
virtual float Predict( const FMatrix &fmat, bst_uint ridx, unsigned gid = 0 ){
|
||||||
@ -158,6 +158,32 @@ namespace xgboost{
|
|||||||
tree.DumpModel( fo, fmap, with_stats );
|
tree.DumpModel( fo, fmap, with_stats );
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
|
inline void CollapseNode( std::vector<float> &grad,
|
||||||
|
std::vector<float> &hess,
|
||||||
|
const FMatrix &fmat,
|
||||||
|
const std::vector<unsigned> &root_index,
|
||||||
|
int nid ){
|
||||||
|
std::vector<bst_uint> valid_index;
|
||||||
|
for( size_t i = 0; i < grad.size(); i ++ ){
|
||||||
|
ThreadEntry &e = this->InitTmp();
|
||||||
|
this->PrepareTmp( fmat.GetRow(i), e );
|
||||||
|
int pid = root_index.size() == 0 ? 0 : (int)root_index[i];
|
||||||
|
// tranverse tree
|
||||||
|
while( !tree[ pid ].is_leaf() ){
|
||||||
|
unsigned split_index = tree[ pid ].split_index();
|
||||||
|
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
|
||||||
|
if( pid == nid ){
|
||||||
|
valid_index.push_back( static_cast<bst_uint>(i) ); break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this->DropTmp( fmat.GetRow(i), e );
|
||||||
|
}
|
||||||
|
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||||
|
maker.Collapse( valid_index, nid );
|
||||||
|
if( !silent ){
|
||||||
|
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
||||||
|
}
|
||||||
|
}
|
||||||
inline void ExpandNode( std::vector<float> &grad,
|
inline void ExpandNode( std::vector<float> &grad,
|
||||||
std::vector<float> &hess,
|
std::vector<float> &hess,
|
||||||
const FMatrix &fmat,
|
const FMatrix &fmat,
|
||||||
@ -175,7 +201,7 @@ namespace xgboost{
|
|||||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||||
bool success = maker.Expand( valid_index, nid );
|
bool success = maker.Expand( valid_index, nid );
|
||||||
if( !silent ){
|
if( !silent ){
|
||||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
|
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
|
|||||||
@ -197,6 +197,21 @@ namespace xgboost{
|
|||||||
this->DeleteNode( nodes[ rid ].cright() );
|
this->DeleteNode( nodes[ rid ].cright() );
|
||||||
nodes[ rid ].set_leaf( value );
|
nodes[ rid ].set_leaf( value );
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief collapse a non leaf node to a leaf node, delete its children
|
||||||
|
* \param rid node id of the node
|
||||||
|
* \param new leaf value
|
||||||
|
*/
|
||||||
|
inline void CollapseToLeaf( int rid, float value ){
|
||||||
|
if( nodes[rid].is_leaf() ) return;
|
||||||
|
if( !nodes[ nodes[rid].cleft() ].is_leaf() ){
|
||||||
|
CollapseToLeaf( nodes[rid].cleft(), 0.0f );
|
||||||
|
}
|
||||||
|
if( !nodes[ nodes[rid].cright() ].is_leaf() ){
|
||||||
|
CollapseToLeaf( nodes[rid].cright(), 0.0f );
|
||||||
|
}
|
||||||
|
this->ChangeToLeaf( rid, value );
|
||||||
|
}
|
||||||
public:
|
public:
|
||||||
/*! \brief model parameter */
|
/*! \brief model parameter */
|
||||||
Param param;
|
Param param;
|
||||||
@ -287,6 +302,25 @@ namespace xgboost{
|
|||||||
}
|
}
|
||||||
return depth;
|
return depth;
|
||||||
}
|
}
|
||||||
|
/*!
|
||||||
|
* \brief get maximum depth
|
||||||
|
* \param nid node id
|
||||||
|
*/
|
||||||
|
inline int MaxDepth( int nid ) const{
|
||||||
|
if( nodes[nid].is_leaf() ) return 0;
|
||||||
|
return std::max( MaxDepth( nodes[nid].cleft() )+1,
|
||||||
|
MaxDepth( nodes[nid].cright() )+1 );
|
||||||
|
}
|
||||||
|
/*!
|
||||||
|
* \brief get maximum depth
|
||||||
|
*/
|
||||||
|
inline int MaxDepth( void ){
|
||||||
|
int maxd = 0;
|
||||||
|
for( int i = 0; i < param.num_roots; ++ i ){
|
||||||
|
maxd = std::max( maxd, MaxDepth( i ) );
|
||||||
|
}
|
||||||
|
return maxd;
|
||||||
|
}
|
||||||
/*! \brief number of extra nodes besides the root */
|
/*! \brief number of extra nodes besides the root */
|
||||||
inline int num_extra_nodes( void ) const {
|
inline int num_extra_nodes( void ) const {
|
||||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||||
|
|||||||
@ -16,10 +16,14 @@ python mknfold.py agaricus.txt 1
|
|||||||
# interaction
|
# interaction
|
||||||
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
|
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
|
||||||
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2
|
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2
|
||||||
|
../../xgboost mushroom.conf task=interact model_in=m3.model model_out=m3v.model interact:booster_index=0 bst:interact:remove=2
|
||||||
|
../../xgboost mushroom.conf task=interact model_in=m3v.model model_out=m3p.model interact:booster_index=0 bst:interact:expand=2
|
||||||
|
|
||||||
# this is what dump will looklike with feature map
|
# this is what dump will looklike with feature map
|
||||||
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
|
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
|
||||||
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
|
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
|
||||||
|
../../xgboost mushroom.conf task=dump model_in=m3v.model fmap=featmap.txt name_dump=dump.m3v.txt
|
||||||
|
../../xgboost mushroom.conf task=dump model_in=m3p.model fmap=featmap.txt name_dump=dump.m3p.txt
|
||||||
|
|
||||||
echo "========m1======="
|
echo "========m1======="
|
||||||
cat dump.m1.txt
|
cat dump.m1.txt
|
||||||
@ -30,5 +34,11 @@ cat dump.m2.txt
|
|||||||
echo "========m3========"
|
echo "========m3========"
|
||||||
cat dump.m3.txt
|
cat dump.m3.txt
|
||||||
|
|
||||||
|
echo "========m3v========"
|
||||||
|
cat dump.m3v.txt
|
||||||
|
|
||||||
|
echo "========m3p========"
|
||||||
|
cat dump.m3p.txt
|
||||||
|
|
||||||
echo "========full======="
|
echo "========full======="
|
||||||
cat dump.full.txt
|
cat dump.full.txt
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user