add pathdump

This commit is contained in:
tqchen 2014-02-26 17:08:23 -08:00
parent 4a612eb3ba
commit 733f8ae393
7 changed files with 91 additions and 19 deletions

View File

@ -57,7 +57,7 @@ namespace xgboost{
virtual int GetLeafIndex( const std::vector<float> &feat, virtual int GetLeafIndex( const std::vector<float> &feat,
const std::vector<bool> &funknown, const std::vector<bool> &funknown,
unsigned gid = 0 ){ unsigned gid = 0 ){
// start from groups that belongs to current data // start from groups that belongs to current data
int pid = (int)gid; int pid = (int)gid;
// tranverse tree // tranverse tree
@ -67,18 +67,28 @@ namespace xgboost{
} }
return pid; return pid;
} }
/*!
 * \brief record the ids of all nodes visited while routing a sparse instance through the tree
 *  NOTE: works through the member buffers tmp_feat / tmp_funknown
 * \param path output; cleared first, then filled with the visited node ids, beginning with gid
 * \param feat sparse feature vector of the instance
 * \param gid id of the node where the traversal starts, default = 0
 */
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
    path.clear();
    // copy the sparse entries into the temporary buffers
    this->InitTmp();
    this->PrepareTmp( feat );
    // descend until a leaf is hit, logging every node on the way (the leaf included)
    for( int nid = (int)gid; ; ){
        path.push_back( nid );
        if( tree[ nid ].is_leaf() ) break;
        const unsigned fid = tree[ nid ].split_index();
        nid = this->GetNext( nid, tmp_feat[ fid ], tmp_funknown[ fid ] );
    }
    // flag the entries touched by this instance as unknown again
    this->DropTmp( feat );
}
virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){ virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
this->InitTmp(); this->InitTmp();
for( unsigned i = 0; i < feat.len; i ++ ){ this->PrepareTmp( feat );
utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature execeed bound" );
tmp_funknown[ feat[i].findex ] = false;
tmp_feat[ feat[i].findex ] = feat[i].fvalue;
}
int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid ); int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
// set back this->DropTmp( feat );
for( unsigned i = 0; i < feat.len; i ++ ){
tmp_funknown[ feat[i].findex ] = true;
}
return tree[ pid ].leaf_value(); return tree[ pid ].leaf_value();
} }
virtual float Predict( const std::vector<float> &feat, virtual float Predict( const std::vector<float> &feat,
@ -127,6 +137,18 @@ namespace xgboost{
std::fill( tmp_funknown.begin(), tmp_funknown.end(), true ); std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
} }
} }
/*!
 * \brief copy every entry of a sparse line into the temporary buffers and flag it as known
 * \param feat sparse feature vector; each feature index must be smaller than tmp_funknown.size()
 */
inline void PrepareTmp( const FMatrixS::Line &feat ){
    const unsigned num_slot = (unsigned)tmp_funknown.size();
    for( unsigned k = 0; k < feat.len; ++ k ){
        const unsigned fid = feat[k].findex;
        utils::Assert( fid < num_slot , "input feature execeed bound" );
        tmp_feat[ fid ] = feat[k].fvalue;
        tmp_funknown[ fid ] = false;
    }
}
/*!
 * \brief flag every feature present in the sparse line as unknown in the temporary buffer
 * \param feat sparse feature vector whose feature indices are reset
 */
inline void DropTmp( const FMatrixS::Line &feat ){
    unsigned k = 0;
    while( k < feat.len ){
        tmp_funknown[ feat[k].findex ] = true;
        ++ k;
    }
}
inline int GetNext( int pid, float fvalue, bool is_unknown ){ inline int GetNext( int pid, float fvalue, bool is_unknown ){
float split_value = tree[ pid ].split_cond(); float split_value = tree[ pid ].split_cond();

View File

@ -65,6 +65,14 @@ namespace xgboost{
std::vector<float> &hess, std::vector<float> &hess,
const FMatrixS &feats, const FMatrixS &feats,
const std::vector<unsigned> &root_index ) = 0; const std::vector<unsigned> &root_index ) = 0;
/*!
 * \brief predict the ids of the nodes visited along a tree for a given sparse feature vector;
 *        intended for the case when the booster is a tree
 *  This default body does not touch any of its arguments; it only reports
 *  "not implemented" through utils::Error
 * \param path the resulting sequence of node ids
 * \param feat sparse feature vector of the instance
 * \param rid root id of current instance, default = 0
 */
virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned rid = 0 ){
utils::Error( "not implemented" );
}
/*! /*!
* \brief predict values for given sparse feature vector * \brief predict values for given sparse feature vector
* NOTE: in tree implementation, this is not threadsafe, used dense version to ensure threadsafety * NOTE: in tree implementation, this is not threadsafe, used dense version to ensure threadsafety

View File

@ -179,6 +179,25 @@ namespace xgboost{
boosters[i]->DumpModel( fo ); boosters[i]->DumpModel( fo );
} }
} }
/*!
 * \brief Dump path of all trees
 *  Writes one text line per row of data. Within a line the paths of the
 *  individual boosters are separated by tabs, and the node ids inside one
 *  path are separated by commas. A booster that yields an empty path
 *  contributes an empty field.
 * \param fo text file
 * \param data input data
 */
inline void DumpPath( FILE *fo, const FMatrixS &data ){
    // one buffer shared by all (row, booster) pairs, so its storage is reused
    std::vector<int> path;
    for( size_t i = 0; i < data.NumRow(); ++ i ){
        for( size_t j = 0; j < boosters.size(); ++ j ){
            if( j != 0 ) fprintf( fo, "\t" );
            // start every prediction from an empty path, as a fresh vector would
            path.clear();
            boosters[j]->PredPath( path, data[i] );
            // iterate by index instead of reading path[0] unconditionally,
            // which would be undefined behaviour for an empty path
            for( size_t k = 0; k < path.size(); ++ k ){
                if( k != 0 ) fprintf( fo, "," );
                fprintf( fo, "%d", path[k] );
            }
        }
        fprintf( fo, "\n" );
    }
}
public: public:
/*! /*!
* \brief do gradient boost training for one step, using the information given * \brief do gradient boost training for one step, using the information given

View File

@ -4,6 +4,8 @@ save_period=0
data = "agaricus.txt.train" data = "agaricus.txt.train"
eval[test] = "agaricus.txt.test" eval[test] = "agaricus.txt.test"
test:data = "agaricus.txt.test"
booster_type = 0 booster_type = 0
loss_type = 2 loss_type = 2

View File

@ -3,4 +3,5 @@ python mapfeat.py
python mknfold.py agaricus.txt 1 python mknfold.py agaricus.txt 1
../../xgboost mushroom.conf ../../xgboost mushroom.conf
../../xgboost mushroom.conf task=dump model_in=0003.model ../../xgboost mushroom.conf task=dump model_in=0003.model
../../xgboost mushroom.conf task=dumppath model_in=0003.model
python maptree.py python maptree.py

View File

@ -99,6 +99,14 @@ namespace xgboost{
inline void DumpModel( FILE *fo ){ inline void DumpModel( FILE *fo ){
base_model.DumpModel( fo ); base_model.DumpModel( fo );
} }
/*!
 * \brief Dump path of all trees
 *  Forwards to base_model.DumpPath, handing over the data member of the given DMatrix
 * \param fo text file
 * \param data input data; only data.data is used here
 */
inline void DumpPath( FILE *fo, const DMatrix &data ){
base_model.DumpPath( fo, data.data );
}
/*! /*!
* \brief save model to stream * \brief save model to stream
* \param fo output stream * \param fo output stream

View File

@ -34,11 +34,15 @@ namespace xgboost{
} }
this->InitData(); this->InitData();
this->InitLearner(); this->InitLearner();
if( !strcmp( task.c_str(), "dump") ){ if( task == "dump" ){
this->TaskDump(); this->TaskDump();
return 0; return 0;
} }
if( !strcmp( task.c_str(), "test") ){ if( task == "dumppath" ){
this->TaskDumpPath();
return 0;
}
if( task == "test" ){
this->TaskTest(); this->TaskTest();
}else{ }else{
this->TaskTrain(); this->TaskTrain();
@ -73,6 +77,7 @@ namespace xgboost{
model_in = "NULL"; model_in = "NULL";
name_pred = "pred.txt"; name_pred = "pred.txt";
name_dump = "dump.txt"; name_dump = "dump.txt";
name_dumppath = "dump.path.txt";
model_dir_path = "./"; model_dir_path = "./";
} }
~RegBoostTask( void ){ ~RegBoostTask( void ){
@ -82,8 +87,8 @@ namespace xgboost{
} }
private: private:
inline void InitData( void ){ inline void InitData( void ){
if( !strcmp( task.c_str(), "dump") ) return; if( task == "dump") return;
if( !strcmp( task.c_str(), "test") ){ if( task == "test" || task == "dumppath" ){
data.CacheLoad( test_path.c_str() ); data.CacheLoad( test_path.c_str() );
}else{ }else{
// training // training
@ -101,12 +106,12 @@ namespace xgboost{
while( cfg.Next() ){ while( cfg.Next() ){
learner.SetParam( cfg.name(), cfg.val() ); learner.SetParam( cfg.name(), cfg.val() );
} }
if( strcmp( model_in.c_str(), "NULL" ) != 0 ){ if( model_in != "NULL" ){
utils::FileStream fi( utils::FopenCheck( model_in.c_str(), "rb") ); utils::FileStream fi( utils::FopenCheck( model_in.c_str(), "rb") );
learner.LoadModel( fi ); learner.LoadModel( fi );
fi.Close(); fi.Close();
}else{ }else{
utils::Assert( !strcmp( task.c_str(), "train"), "model_in not specified" ); utils::Assert( task == "train", "model_in not specified" );
learner.InitModel(); learner.InitModel();
} }
learner.InitTrainer(); learner.InitTrainer();
@ -138,6 +143,11 @@ namespace xgboost{
learner.DumpModel( fo ); learner.DumpModel( fo );
fclose( fo ); fclose( fo );
} }
/*! \brief dumppath task: let the learner write the paths for data into the file named by name_dumppath */
inline void TaskDumpPath( void ){
    const char *out_name = name_dumppath.c_str();
    FILE *fout = utils::FopenCheck( out_name, "w" );
    learner.DumpPath( fout, data );
    fclose( fout );
}
inline void SaveModel( int i ) const{ inline void SaveModel( int i ) const{
char fname[256]; char fname[256];
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 ); sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
@ -175,6 +185,8 @@ namespace xgboost{
std::string name_pred; std::string name_pred;
/* \brief name of dump file */ /* \brief name of dump file */
std::string name_dump; std::string name_dump;
/* \brief name of dump path file */
std::string name_dumppath;
/* \brief the paths of validation data sets */ /* \brief the paths of validation data sets */
std::vector<std::string> eval_data_paths; std::vector<std::string> eval_data_paths;
/* \brief the names of the evaluation data used in output log */ /* \brief the names of the evaluation data used in output log */