try interact mode
commit ef5a389ecf
parent 2bdcad9630
@@ -95,7 +95,7 @@ namespace xgboost{
             if( s.leaf_child_cnt >= 2 && param.need_prune( s.loss_chg, depth - 1 ) ){
                 this->stat_num_pruned += 2;
                 // need to be pruned
-                tree.ChangeToLeaf( pid, param.learning_rate * snode[pid].weight );
+                tree.ChangeToLeaf( pid, param.learning_rate * s.base_weight );
                 // tail recursion
                 this->TryPruneLeaf( pid, depth - 1 );
             }
@@ -105,9 +105,11 @@ namespace xgboost{
                 snode[nid].sum_hess = sum_hess;
                 snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
                 if( !tree[nid].is_root() ){
-                    snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
+                    snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, tree.stat( tree[nid].parent() ).base_weight );
+                    tree.stat(nid).base_weight = snode[nid].weight;
                 }else{
                     snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
+                    tree.stat(nid).base_weight = snode[nid].weight;
                 }
             }
         }
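Note: the new tree.stat(...).base_weight writes above, together with the s.base_weight and s.leaf_child_cnt reads in the pruning hunk, imply a per-node statistics record roughly like the sketch below. This is an inference from the fields this diff touches, not code copied from the commit; the actual struct name and field order may differ.

// Inferred sketch of the per-node stat record returned by tree.stat(nid).
// The field set comes from what this diff reads/writes; everything else is an assumption.
struct NodeStat {
    float loss_chg;        // loss change of the split recorded at this node
    float sum_hess;        // sum of hessian over the rows reaching this node
    float base_weight;     // node weight before learning_rate shrinkage (newly maintained here)
    int   leaf_child_cnt;  // number of leaf children, consulted by the pruner
};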
@@ -30,6 +30,17 @@ namespace xgboost{
                 utils::Assert( grad.size() == hess.size(), "booster:invalid input" );
                 utils::Assert( smat.NumRow() == hess.size(), "booster:invalid input" );
                 utils::Assert( root_index.size() == 0 || root_index.size() == hess.size(), "booster:invalid input" );
+                {// setup temp space for each thread
+                    if( param.nthread != 0 ){
+                        omp_set_num_threads( param.nthread );
+                    }
+                    #pragma omp parallel
+                    {
+                        this->nthread = omp_get_num_threads();
+                    }
+                    tmp_rptr.resize( this->nthread, std::vector<size_t>() );
+                    snode.reserve( 256 );
+                }
             }
             inline void Make( int& stat_max_depth, int& stat_num_pruned ){
                 this->InitData();
@@ -51,7 +62,38 @@ namespace xgboost{
                 }
                 // start prunning the tree
                 stat_num_pruned = this->DoPrune();
             }
+            // expand a specific node
+            inline bool Expand( const std::vector<bst_uint> &valid_index, int nid ){
+                if( valid_index.size() == 0 ) return false;
+                this->InitDataExpand( valid_index, nid );
+                this->InitNewNode( this->qexpand );
+                this->FindSplit( nid, tmp_rptr[0] );
+
+                // update node statistics
+                for( size_t i = 0; i < qexpand.size(); ++ i ){
+                    const int nid = qexpand[i];
+                    tree.stat( nid ).loss_chg = snode[ nid ].best.loss_chg;
+                    tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
+                }
+                // change the leaf
+                this->UpdateQueueExpand( this->qexpand );
+                this->InitNewNode( this->qexpand );
+
+                // set all the rest expanding nodes to leaf
+                for( size_t i = 0; i < qexpand.size(); ++ i ){
+                    const int nid = qexpand[i];
+                    tree[ nid ].set_leaf( snode[nid].weight * param.learning_rate );
+                    tree.stat( nid ).loss_chg = 0.0f;
+                    tree.stat( nid ).sum_hess = static_cast<float>( snode[ nid ].sum_hess );
+                    tree.param.max_depth = std::max( tree.param.max_depth, tree.GetDepth( nid ) );
+                }
+                if( qexpand.size() != 0 ) {
+                    return true;
+                }else{
+                    return false;
+                }
+            }
         private:
             // make leaf nodes for all qexpand, update node statistics, mark leaf value
             inline void InitNewNode( const std::vector<int> &qexpand ){
@@ -71,9 +113,11 @@ namespace xgboost{

                 snode[nid].root_gain = param.CalcRootGain( sum_grad, sum_hess );
                 if( !tree[nid].is_root() ){
-                    snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, snode[ tree[nid].parent() ].weight );
+                    snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, tree.stat( tree[nid].parent() ).base_weight );
+                    tree.stat(nid).base_weight = snode[nid].weight;
                 }else{
                     snode[nid].weight = param.CalcWeight( sum_grad, sum_hess, 0.0f );
+                    tree.stat(nid).base_weight = snode[nid].weight;
                 }
             }
         }
@@ -288,18 +332,7 @@ namespace xgboost{
                     node_bound[i-1] = std::make_pair( rptr[ i - 1 ], rptr[ i ] );
                 }
             }

-            {// setup temp space for each thread
-                if( param.nthread != 0 ){
-                    omp_set_num_threads( param.nthread );
-                }
-                #pragma omp parallel
-                {
-                    this->nthread = omp_get_num_threads();
-                }
-                tmp_rptr.resize( this->nthread, std::vector<size_t>() );
-                snode.reserve( 256 );
-            }
             {// expand query
                 qexpand.reserve( 256 ); qexpand.clear();
                 for( int i = 0; i < tree.param.num_roots; ++ i ){
@@ -307,6 +340,15 @@ namespace xgboost{
                 }
             }
         }
+
+        // initialize temp data structure
+        inline void InitDataExpand( const std::vector<bst_uint> &valid_index, int nid ){
+            row_index_set = valid_index;
+            node_bound.resize( tree.param.num_nodes );
+            node_bound[ nid ] = std::make_pair( 0, (bst_uint)row_index_set.size() );
+
+            qexpand.clear(); qexpand.push_back( nid );
+        }
     private:
         // number of omp thread used during training
         int nthread;
@@ -34,7 +34,10 @@ namespace xgboost{
         class RegTreeTrainer : public InterfaceBooster<FMatrix>{
         public:
             RegTreeTrainer( void ){
                 silent = 0; tree_maker = 1;
+                // interact mode
+                interact_type = 0;
+                interact_node = 0;
                 // normally we won't have more than 64 OpenMP threads
                 threadtemp.resize( 64, ThreadEntry() );
             }
@@ -43,6 +46,16 @@ namespace xgboost{
             virtual void SetParam( const char *name, const char *val ){
                 if( !strcmp( name, "silent") ) silent = atoi( val );
                 if( !strcmp( name, "tree_maker") ) tree_maker = atoi( val );
+                if( !strncmp( name, "interact:", 9) ){
+                    const char *ename = name + 9;
+                    interact_node = atoi( val );
+                    if( !strcmp( ename, "expand") ) {
+                        interact_type = 1;
+                    }
+                    if( !strcmp( ename, "remove") ) {
+                        interact_type = 2;
+                    }
+                }
                 param.SetParam( name, val );
                 tree.param.SetParam( name, val );
             }
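Note: a minimal standalone sketch of how the new interact: prefix parsing behaves. It is self-contained rather than the trainer class itself; the parameter name and value in main are illustrative, and it assumes the outer layer has already stripped any bst: prefix before the booster's SetParam is called.

#include <cstdio>
#include <cstring>
#include <cstdlib>

static int interact_type = 0;  // 0 = off, 1 = expand, 2 = remove (mirrors the new fields)
static int interact_node = 0;  // node id the action applies to

static void SetParam( const char *name, const char *val ){
    if( !strncmp( name, "interact:", 9 ) ){
        const char *ename = name + 9;      // action name after the "interact:" prefix
        interact_node = atoi( val );       // the value carries the target node id
        if( !strcmp( ename, "expand" ) ) interact_type = 1;
        if( !strcmp( ename, "remove" ) ) interact_type = 2;
    }
}

int main( void ){
    SetParam( "interact:expand", "3" );    // illustrative call, not from the commit
    printf( "interact_type=%d interact_node=%d\n", interact_type, interact_node );
    return 0;                              // prints: interact_type=1 interact_node=3
}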
@@ -61,6 +74,16 @@ namespace xgboost{
                               const FMatrix &smat,
                               const std::vector<unsigned> &root_index ){
                 utils::Assert( grad.size() < UINT_MAX, "number of instance exceed what we can handle" );
+
+                // interactive update
+                if( interact_type != 0 ){
+                    switch( interact_type ){
+                    case 1: this->ExpandNode( grad, hess, smat, root_index, interact_node ); return;
+                    case 2:
+                    default: utils::Error("unknown interact type");
+                    }
+                }
+
                 if( !silent ){
                     printf( "\nbuild GBRT with %u instances\n", (unsigned)grad.size() );
                 }
@@ -135,10 +158,38 @@ namespace xgboost{
                 tree.DumpModel( fo, fmap, with_stats );
             }
         private:
+            inline void ExpandNode( std::vector<float> &grad,
+                                    std::vector<float> &hess,
+                                    const FMatrix &fmat,
+                                    const std::vector<unsigned> &root_index,
+                                    int nid ){
+                std::vector<bst_uint> valid_index;
+                for( size_t i = 0; i < grad.size(); i ++ ){
+                    ThreadEntry &e = this->InitTmp();
+                    this->PrepareTmp( fmat.GetRow(i), e );
+                    unsigned rtidx = root_index.size() == 0 ? 0 : root_index[i];
+                    int pid = this->GetLeafIndex( e.feat, e.funknown, rtidx );
+                    this->DropTmp( fmat.GetRow(i), e );
+                    if( pid == nid ) valid_index.push_back( static_cast<bst_uint>(i) );
+                }
+                RowTreeMaker<FMatrix> maker( tree, param, grad, hess, fmat, root_index );
+                bool success = maker.Expand( valid_index, nid );
+                if( !silent ){
+                    printf( "tree expand end, success=%d, max_depth=%d\n", (int)success, tree.param.max_depth );
+                }
+            }
+        private:
+            // silent
             int silent;
-            int tree_maker;
             RegTree tree;
             TreeParamTrain param;
+        private:
+            // some training parameters
+            // tree maker
+            int tree_maker;
+            // interaction
+            int interact_type;
+            int interact_node;
         private:
             struct ThreadEntry{
                 std::vector<float> feat;
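Note: ExpandNode above routes every training row through the current tree and keeps only the rows that land in the chosen leaf before handing them to RowTreeMaker::Expand. A toy, self-contained illustration of that row-selection idea follows; it is not xgboost code, and the stump and data are made up.

#include <cstdio>
#include <vector>

// toy stand-in for GetLeafIndex: rows with x < 0.5 fall into leaf 1, the rest into leaf 2
static int GetLeafIndex( float x ){ return x < 0.5f ? 1 : 2; }

int main( void ){
    std::vector<float> feature;
    feature.push_back(0.1f); feature.push_back(0.9f); feature.push_back(0.4f);
    feature.push_back(0.7f); feature.push_back(0.2f);
    const int nid = 1;                        // the leaf we ask to expand
    std::vector<unsigned> valid_index;        // plays the role of valid_index in Expand()
    for( unsigned i = 0; i < feature.size(); ++ i ){
        if( GetLeafIndex( feature[i] ) == nid ) valid_index.push_back( i );
    }
    // a grower would now search for the best split over only these rows
    printf( "rows routed to leaf %d:", nid );
    for( size_t i = 0; i < valid_index.size(); ++ i ) printf( " %u", valid_index[i] );
    printf( "\n" );                           // prints: rows routed to leaf 1: 0 2 4
    return 0;
}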
demo/test/README (new file, 1 line)
@@ -0,0 +1 @@
+test folder to test new functions
@@ -4,9 +4,31 @@ python mapfeat.py
 # split train and test
 python mknfold.py agaricus.txt 1
 # training
-../../xgboost mushroom.conf
-# this is what dump will looklike without feature map
-../../xgboost mushroom.conf task=dump model_in=0003.model name_dump=dump.raw.txt
+../../xgboost mushroom.conf num_round=1 model_out=full.model bst:max_depth=3
+../../xgboost mushroom.conf task=dump model_in=full.model fmap=featmap.txt name_dump=dump.full.txt
+
+# training
+../../xgboost mushroom.conf num_round=1 model_out=m1.model bst:max_depth=1
+
 # this is what dump will looklike with feature map
-../../xgboost mushroom.conf task=dump model_in=0003.model fmap=featmap.txt name_dump=dump.nice.txt
-cat dump.nice.txt
+../../xgboost mushroom.conf task=dump model_in=m1.model fmap=featmap.txt name_dump=dump.m1.txt
+
+# interaction
+../../xgboost mushroom.conf task=interact model_in=m1.model model_out=m2.model interact:booster_index=0 bst:interact:expand=1
+../../xgboost mushroom.conf task=interact model_in=m2.model model_out=m3.model interact:booster_index=0 bst:interact:expand=2
+
+# this is what dump will looklike with feature map
+../../xgboost mushroom.conf task=dump model_in=m2.model fmap=featmap.txt name_dump=dump.m2.txt
+../../xgboost mushroom.conf task=dump model_in=m3.model fmap=featmap.txt name_dump=dump.m3.txt
+
+echo "========m1======="
+cat dump.m1.txt
+
+echo "========m2========"
+cat dump.m2.txt
+
+echo "========m3========"
+cat dump.m3.txt
+
+echo "========full======="
+cat dump.full.txt
@@ -39,7 +39,7 @@ namespace xgboost{
                 this->TaskDump();
                 return 0;
             }
-            if( task == "interactive" ){
+            if( task == "interact" ){
                 this->TaskInteractive();
                 return 0;
             }