try interact mode
This commit is contained in:
@@ -30,6 +30,17 @@ namespace xgboost{
|
||||
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
|
||||
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
|
||||
{// setup temp space for each thread
|
||||
if( param.nthread != 0 ){
|
||||
omp_set_num_threads( param.nthread );
|
||||
}
|
||||
#pragma omp parallel
|
||||
{
|
||||
this->nthread = omp_get_num_threads();
|
||||
}
|
||||
tmp_rptr.resize( this->nthread, std::vector<size_t>() );
|
||||
snode.reserve( 256 );
|
||||
}
|
||||
}
|
||||
inline void Make( int& stat_max_depth, int& stat_num_pruned ){
|
||||
this->InitData();
|
||||
@@ -51,7 +62,38 @@ namespace xgboost{
|
||||
}
|
||||
// start prunning the tree
|
||||
stat_num_pruned = this->DoPrune();
|
||||
}
|
||||
}
|
||||
// expand a specific node
|
||||
inline bool Expand( const std::vector<bst_uint> &valid_index, int nid ){
|
||||
if( valid_index.size() == 0 ) return false;
|
||||
this->InitDataExpand( valid_index, nid );
|
||||
this->InitNewNode( this->qexpand );
|
||||
this->FindSplit( nid, tmp_rptr[0] );
|
||||
|
||||
// update node statistics
|
||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||
const int nid = qexpand[i];
|
||||
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
}
|
||||
// change the leaf
|
||||
this->UpdateQueueExpand( this->qexpand );
|
||||
this->InitNewNode( this->qexpand );
|
||||
|
||||
// set all the rest expanding nodes to leaf
|
||||
for( size_t i = 0; i < qexpand.size(); ++ i ){
|
||||
const int nid = qexpand[i];
|
||||
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
|
||||
tree.stat( nid ).loss_chg = 0.0f;
|
||||
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
|
||||
tree.param.max_depth = std::max( tree.param.max_depth, tree.GetDepth( nid ) );
|
||||
}
|
||||
if( qexpand.size() != 0 ) {
|
||||
return true;
|
||||
}else{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
private:
|
||||
// make leaf nodes for all qexpand, update node statistics, mark leaf value
|
||||
inline void InitNewNode( const std::vector<int> &qexpand ){
|
||||
@@ -71,9 +113,11 @@ namespace xgboost{
|
||||
|
||||
snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
|
||||
if( !tree[nid].is_root() ){
|
||||
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
|
||||
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, tree.stat( tree[nid].parent() ).base_weight );
|
||||
tree.stat(nid).base_weight = snode[nid].weight;
|
||||
}else{
|
||||
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
|
||||
tree.stat(nid).base_weight = snode[nid].weight;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,18 +332,7 @@ namespace xgboost{
|
||||
node_bound[i-1] = std::make_pair( rptr[ i - 1 ], rptr[ i ] );
|
||||
}
|
||||
}
|
||||
|
||||
{// setup temp space for each thread
|
||||
if( param.nthread != 0 ){
|
||||
omp_set_num_threads( param.nthread );
|
||||
}
|
||||
#pragma omp parallel
|
||||
{
|
||||
this->nthread = omp_get_num_threads();
|
||||
}
|
||||
tmp_rptr.resize( this->nthread, std::vector<size_t>() );
|
||||
snode.reserve( 256 );
|
||||
}
|
||||
|
||||
{// expand query
|
||||
qexpand.reserve( 256 ); qexpand.clear();
|
||||
for( int i = 0; i < tree.param.num_roots; ++ i ){
|
||||
@@ -307,6 +340,15 @@ namespace xgboost{
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initialize temp data structure
|
||||
inline void InitDataExpand( const std::vector<bst_uint> &valid_index, int nid ){
|
||||
row_index_set = valid_index;
|
||||
node_bound.resize( tree.param.num_nodes );
|
||||
node_bound[ nid ] = std::make_pair( 0, (bst_uint)row_index_set.size() );
|
||||
|
||||
qexpand.clear(); qexpand.push_back( nid );
|
||||
}
|
||||
private:
|
||||
// number of omp thread used during training
|
||||
int nthread;
|
||||
|
||||
Reference in New Issue
Block a user