add add and remove
This commit is contained in:
@@ -79,7 +79,7 @@ namespace xgboost{
|
||||
if( interact_type != 0 ){
|
||||
switch( interact_type ){
|
||||
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
|
||||
case 2:
|
||||
case 2: this->CollapseNode( grad, hess, smat, root_index, interact_node ); return;
|
||||
default: utils::Error("unknown interact type");
|
||||
}
|
||||
}
|
||||
@@ -108,7 +108,7 @@ namespace xgboost{
|
||||
}
|
||||
if( !silent ){
|
||||
printf( "tree train end, %d roots, %d extra nodes, %d pruned nodes ,max_depth=%d\n",
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.param.max_depth );
|
||||
tree.param.num_roots, tree.num_extra_nodes(), num_pruned, tree.MaxDepth() );
|
||||
}
|
||||
}
|
||||
virtual float Predict( const FMatrix &fmat, bst_uint ridx, unsigned gid = 0 ){
|
||||
@@ -158,6 +158,32 @@ namespace xgboost{
|
||||
tree.DumpModel( fo, fmap, with_stats );
|
||||
}
|
||||
private:
|
||||
inline void CollapseNode( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &fmat,
|
||||
const std::vector<unsigned> &root_index,
|
||||
int nid ){
|
||||
std::vector<bst_uint> valid_index;
|
||||
for( size_t i = 0; i < grad.size(); i ++ ){
|
||||
ThreadEntry &e = this->InitTmp();
|
||||
this->PrepareTmp( fmat.GetRow(i), e );
|
||||
int pid = root_index.size() == 0 ? 0 : (int)root_index[i];
|
||||
// tranverse tree
|
||||
while( !tree[ pid ].is_leaf() ){
|
||||
unsigned split_index = tree[ pid ].split_index();
|
||||
pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
|
||||
if( pid == nid ){
|
||||
valid_index.push_back( static_cast<bst_uint>(i) ); break;
|
||||
}
|
||||
}
|
||||
this->DropTmp( fmat.GetRow(i), e );
|
||||
}
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
maker.Collapse( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree collapse end, max_depth=%d\n", tree.param.max_depth );
|
||||
}
|
||||
}
|
||||
inline void ExpandNode( std::vector<float> &grad,
|
||||
std::vector<float> &hess,
|
||||
const FMatrix &fmat,
|
||||
@@ -175,7 +201,7 @@ namespace xgboost{
|
||||
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
|
||||
bool success = maker.Expand( valid_index, nid );
|
||||
if( !silent ){
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
|
||||
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.MaxDepth() );
|
||||
}
|
||||
}
|
||||
private:
|
||||
|
||||
Reference in New Issue
Block a user