try interact mode

This commit is contained in:
tqchen 2014-03-05 15:28:53 -08:00
parent 2bdcad9630
commit ef5a389ecf
7 changed files with 142 additions and 24 deletions

View File

@ -95,7 +95,7 @@ namespace xgboost{
if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
this->stat_num_pruned += 2;
// need to be pruned
tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
tree.ChangeToLeaf( pid, param.learning_rate * s.base_weight );
// tail recursion
this->TryPruneLeaf( pid, depth - 1 );
}

View File

@ -105,9 +105,11 @@ namespace xgboost{
snode[nid].sum_hess = sum_hess;
snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
if( !tree[nid].is_root() ){
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, tree.stat( tree[nid].parent() ).base_weight );
tree.stat(nid).base_weight = snode[nid].weight;
}else{
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
tree.stat(nid).base_weight = snode[nid].weight;
}
}
}

View File

@ -30,6 +30,17 @@ namespace xgboost{
utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
{// setup temp space for each thread
if( param.nthread != 0 ){
omp_set_num_threads( param.nthread );
}
#pragma omp parallel
{
this->nthread = omp_get_num_threads();
}
tmp_rptr.resize( this->nthread, std::vector<size_t>() );
snode.reserve( 256 );
}
}
inline void Make( int& stat_max_depth, int& stat_num_pruned ){
this->InitData();
@ -51,7 +62,38 @@ namespace xgboost{
}
// start prunning the tree
stat_num_pruned = this->DoPrune();
}
}
// expand a specific node
inline bool Expand( const std::vector<bst_uint> &valid_index, int nid ){
if( valid_index.size() == 0 ) return false;
this->InitDataExpand( valid_index, nid );
this->InitNewNode( this->qexpand );
this->FindSplit( nid, tmp_rptr[0] );
// update node statistics
for( size_t i = 0; i < qexpand.size(); ++ i ){
const int nid = qexpand[i];
tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
}
// change the leaf
this->UpdateQueueExpand( this->qexpand );
this->InitNewNode( this->qexpand );
// set all the rest expanding nodes to leaf
for( size_t i = 0; i < qexpand.size(); ++ i ){
const int nid = qexpand[i];
tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
tree.stat( nid ).loss_chg = 0.0f;
tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
tree.param.max_depth = std::max( tree.param.max_depth, tree.GetDepth( nid ) );
}
if( qexpand.size() != 0 ) {
return true;
}else{
return false;
}
}
private:
// make leaf nodes for all qexpand, update node statistics, mark leaf value
inline void InitNewNode( const std::vector<int> &qexpand ){
@ -71,9 +113,11 @@ namespace xgboost{
snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
if( !tree[nid].is_root() ){
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, tree.stat( tree[nid].parent() ).base_weight );
tree.stat(nid).base_weight = snode[nid].weight;
}else{
snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
tree.stat(nid).base_weight = snode[nid].weight;
}
}
}
@ -288,18 +332,7 @@ namespace xgboost{
node_bound[i-1] = std::make_pair( rptr[ i - 1 ], rptr[ i ] );
}
}
{// setup temp space for each thread
if( param.nthread != 0 ){
omp_set_num_threads( param.nthread );
}
#pragma omp parallel
{
this->nthread = omp_get_num_threads();
}
tmp_rptr.resize( this->nthread, std::vector<size_t>() );
snode.reserve( 256 );
}
{// expand query
qexpand.reserve( 256 ); qexpand.clear();
for( int i = 0; i < tree.param.num_roots; ++ i ){
@ -307,6 +340,15 @@ namespace xgboost{
}
}
}
// initialize temp data structure
inline void InitDataExpand( const std::vector<bst_uint> &valid_index, int nid ){
row_index_set = valid_index;
node_bound.resize( tree.param.num_nodes );
node_bound[ nid ] = std::make_pair( 0, (bst_uint)row_index_set.size() );
qexpand.clear(); qexpand.push_back( nid );
}
private:
// number of omp thread used during training
int nthread;

View File

@ -34,7 +34,10 @@ namespace xgboost{
class RegTreeTrainer : public InterfaceBooster<FMatrix>{
public:
RegTreeTrainer( void ){
silent = 0; tree_maker = 1;
silent = 0; tree_maker = 1;
// interact mode
interact_type = 0;
interact_node = 0;
// normally we won't have more than 64 OpenMP threads
threadtemp.resize( 64, ThreadEntry() );
}
@ -43,6 +46,16 @@ namespace xgboost{
virtual void SetParam( const char *name, const char *val ){
if( !strcmp( name, "silent") ) silent = atoi( val );
if( !strcmp( name, "tree_maker") ) tree_maker = atoi( val );
if( !strncmp( name, "interact:", 9) ){
const char *ename = name + 9;
interact_node = atoi( val );
if( !strcmp( ename, "expand") ) {
interact_type = 1;
}
if( !strcmp( ename, "remove") ) {
interact_type = 2;
}
}
param.SetParam( name, val );
tree.param.SetParam( name, val );
}
@ -61,6 +74,16 @@ namespace xgboost{
const FMatrix &smat,
const std::vector<unsigned> &root_index ){
utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
// interactive update
if( interact_type != 0 ){
switch( interact_type ){
case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
case 2:
default: utils::Error("unknown interact type");
}
}
if( !silent ){
printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
}
@ -135,10 +158,38 @@ namespace xgboost{
tree.DumpModel( fo, fmap, with_stats );
}
private:
inline void ExpandNode( std::vector<float> &grad,
std::vector<float> &hess,
const FMatrix &fmat,
const std::vector<unsigned> &root_index,
int nid ){
std::vector<bst_uint> valid_index;
for( size_t i = 0; i < grad.size(); i ++ ){
ThreadEntry &e = this->InitTmp();
this->PrepareTmp( fmat.GetRow(i), e );
unsigned rtidx = root_index.size() == 0 ? 0 : root_index[i];
int pid = this->GetLeafIndex( e.feat, e.funknown, rtidx );
this->DropTmp( fmat.GetRow(i), e );
if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
}
RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
bool success = maker.Expand( valid_index, nid );
if( !silent ){
printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
}
}
private:
// silent
int silent;
int tree_maker;
RegTree tree;
TreeParamTrain param;
private:
// some training parameters
// tree maker
int tree_maker;
// interaction
int interact_type;
int interact_node;
private:
struct ThreadEntry{
std::vector<float> feat;

1
demo/test/README Normal file
View File

@ -0,0 +1 @@
test folder to test new functions

View File

@ -4,9 +4,31 @@ python mapfeat.py
# split train and test
python mknfold.py agaricus.txt 1
# training
../../xgboost mushroom.conf
# this is what dump will looklike without feature map
../../xgboost mushroom.conf task=dump model_in=0003.model name_dump=dump.raw.txt
../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3
../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt
# training
../../xgboost mushroom.conf num_round=1 model_out=m1.model bst:max_depth=1
# this is what dump will looklike with feature map
../../xgboost mushroom.conf task=dump model_in=0003.model fmap=featmap.txt name_dump=dump.nice.txt
cat dump.nice.txt
../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m1.txt
# interaction
../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2
# this is what dump will looklike with feature map
../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
echo "========m1======="
cat dump.m1.txt
echo "========m2========"
cat dump.m2.txt
echo "========m3========"
cat dump.m3.txt
echo "========full======="
cat dump.full.txt

View File

@ -39,7 +39,7 @@ namespace xgboost{
this->TaskDump();
return 0;
}
if( task == "interactive" ){
if( task == "interact" ){
this->TaskInteractive();
return 0;
}