finish mushroom

This commit is contained in:
tqchen 2014-02-24 23:06:57 -08:00
parent 9d6ef11eb5
commit c4949c0937
8 changed files with 88 additions and 3 deletions

View File

@ -544,7 +544,11 @@ namespace xgboost{
"input data smaller than num feature" );
int pid = this->GetLeafIndex( feat, funknown, gid );
return tree[ pid ].leaf_value();
}
}
/*!
 * \brief write the trainer's tree to a text stream
 * \param fo output stream that receives the dump
 */
virtual void DumpModel( FILE *fo )
{
    tree.DumpModel( fo );
}
public:
/*! \brief construct the trainer with the `silent` field set to 0 */
RTreeTrainer( void ){ silent = 0; }
/*! \brief virtual destructor; the body itself performs no work */
virtual ~RTreeTrainer( void ){}

View File

@ -284,6 +284,26 @@ namespace xgboost{
/*!
 * \brief number of nodes left after subtracting the root count and the
 *        deleted count from the total node count
 * \return param.num_nodes - param.num_roots - param.num_deleted
 */
inline int num_extra_nodes( void ) const {
return param.num_nodes - param.num_roots - param.num_deleted;
}
/*!
 * \brief write a text rendering of the tree, starting at node 0 with no indentation
 * \param fo output stream
 */
inline void DumpModel( FILE *fo )
{
    const int root_id = 0;
    const int root_depth = 0;
    this->Dump( root_id, fo, root_depth );
}
private:
/*!
 * \brief recursively print node nid and everything below it, one node per line
 *
 *  Every line starts with `depth` tab characters. A leaf is printed as
 *  "<nid>:leaf=<value>". A split node is printed as
 *  "<nid>:[f<split_index>><cond>] yes=<cright>,no=<cleft>" and is followed by
 *  the cright subtree and then the cleft subtree, each one level deeper.
 * \param nid index into nodes of the node to print
 * \param fo output stream
 * \param depth indentation level of this node
 */
void Dump( int nid, FILE *fo, int depth )
{
    for( int pad = depth; pad > 0; -- pad ){
        fprintf( fo, "\t" );
    }
    if( nodes[ nid ].is_leaf() ){
        fprintf( fo, "%d:leaf=%f\n", nid, nodes[ nid ].leaf_value() );
        return;
    }
    // "yes" is the cright child and "no" is the cleft child; yes is visited first
    const int yes_child = nodes[ nid ].cright();
    const int no_child = nodes[ nid ].cleft();
    TSplitCond cond = nodes[ nid ].split_cond();
    const float threshold = static_cast<float>( cond );
    fprintf( fo, "%d:[f%u>%f] yes=%d,no=%d\n", nid,
             nodes[ nid ].split_index(), threshold, yes_child, no_child );
    this->Dump( yes_child, fo, depth + 1 );
    this->Dump( no_child, fo, depth + 1 );
}
};
};

View File

@ -92,6 +92,13 @@ namespace xgboost{
* \param fo output stream
*/
// default implementation has an empty body: nothing is written to fo
virtual void PrintInfo( FILE *fo ){}
/*!
 * \brief write the model to a text stream
 *
 *  This default does not write anything; it only calls utils::Error with
 *  "not implemented". The function is virtual so a concrete booster can
 *  supply its own version.
 * \param fo output stream (not used by this default)
 */
virtual void DumpModel( FILE *fo )
{
    utils::Error( "not implemented" );
}
public:
/*! \brief virtual destructor with an empty body, so an object can be destroyed through an IBooster pointer */
virtual ~IBooster( void ){}

View File

@ -169,6 +169,16 @@ namespace xgboost{
this->ConfigBooster( this->boosters[i] );
}
}
/*!
 * \brief write every booster to a text stream, in index order
 *
 *  Each booster's dump is preceded by a header line "booster[i]",
 *  where i is the booster's position in the list.
 * \param fo text file
 */
inline void DumpModel( FILE *fo )
{
    for( size_t bid = 0; bid < boosters.size(); ++ bid ){
        fprintf( fo, "booster[%d]\n", static_cast<int>( bid ) );
        boosters[ bid ]->DumpModel( fo );
    }
}
public:
/*!
* \brief do gradient boost training for one step, using the information given

21
demo/mushroom/maptree.py Executable file
View File

@ -0,0 +1,21 @@
#!/usr/bin/python
"""Replace feature indices in a text tree dump with feature names.

Reads the id-to-name table from 'featname.txt' and the dump from 'dump.txt'
(both in the current directory) and writes the rewritten dump to stdout.
"""
import sys


def loadnmap(fname):
    """Load a feature-id to feature-name table.

    Each non-blank line of `fname` is split on whitespace: the first token is
    the integer feature id, the second is the feature name, and any further
    tokens are ignored. Blank lines are skipped.

    Returns a dict mapping int id -> name.
    """
    nmap = {}
    # 'with' closes the file even if a line fails to parse
    with open(fname) as fi:
        for l in fi:
            arr = l.split()
            if not arr:
                continue
            nmap[int(arr[0])] = arr[1].strip()
    return nmap


def mapline(l, nmap):
    """Rewrite one line of the dump.

    A line without a '[f' marker is returned unchanged. Otherwise the
    '[f<id>><cond>]' part is replaced by '[<name>]' and the second
    whitespace-separated token of the line is appended directly after it,
    followed by a newline.

    Raises KeyError when the feature id is missing from `nmap`.
    """
    idx = l.find('[f')
    if idx == -1:
        return l
    # the feature id is the text between '[f' and the first '>' after it
    fid = int(l[idx + 2:].split('>')[0])
    return l[0:idx] + '[' + nmap[fid] + ']' + l.split()[1].strip() + '\n'


def main():
    """Map every line of 'dump.txt' through 'featname.txt' and print it."""
    fo = sys.stdout
    nmap = loadnmap('featname.txt')
    with open('dump.txt') as fi:
        for l in fi:
            fo.write(mapline(l, nmap))


if __name__ == '__main__':
    main()

View File

@ -2,3 +2,5 @@
# NOTE(review): mapfeat.py and mknfold.py are not visible here; presumably they
# prepare the feature map and the data split used below — confirm
python mapfeat.py
python mknfold.py agaricus.txt 1
# run xgboost with mushroom.conf and no task override
../../xgboost mushroom.conf
# run the dump task on the saved model 0003.model; the dump file name defaults
# to dump.txt unless name_dump is configured otherwise
../../xgboost mushroom.conf task=dump model_in=0003.model
# replace feature indices in dump.txt with names from featname.txt, printing to stdout
python maptree.py

View File

@ -89,16 +89,23 @@ namespace xgboost{
* \param fi input stream
*/
inline void LoadModel( utils::IStream &fi ){
// Reads mparam as sizeof(ModelParam) raw bytes and lets base_model load itself
// from the same stream; a Read that returns 0 fails the assertion.
// NOTE(review): this hunk is shown without +/- markers, so the two identical
// mparam reads are most likely the removed and the added line of a single
// reordering — confirm against the real file, and keep the order in sync with SaveModel.
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
base_model.LoadModel( fi );
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
}
/*!
 * \brief write a text dump of the model by forwarding to base_model
 * \param fo text file
 */
inline void DumpModel( FILE *fo )
{
    base_model.DumpModel( fo );
}
/*!
 * \brief save model to stream: writes mparam as sizeof(ModelParam) raw bytes
 *        and lets base_model save itself to the same stream
 * \param fo output stream
 */
inline void SaveModel( utils::IStream &fo ) const{
// NOTE(review): this hunk is shown without +/- markers, so the two identical
// mparam writes are most likely the removed and the added line of a single
// reordering — confirm against the real file, and keep the order in sync with LoadModel.
fo.Write( &mparam, sizeof(ModelParam) );
base_model.SaveModel( fo );
fo.Write( &mparam, sizeof(ModelParam) );
}
/*!
* \brief update the model for one iteration

View File

@ -34,6 +34,10 @@ namespace xgboost{
}
this->InitData();
this->InitLearner();
if( !strcmp( task.c_str(), "dump") ){
this->TaskDump();
return 0;
}
if( !strcmp( task.c_str(), "test") ){
this->TaskTest();
}else{
@ -68,6 +72,7 @@ namespace xgboost{
task = "train";
model_in = "NULL";
name_pred = "pred.txt";
name_dump = "dump.txt";
model_dir_path = "./";
}
~RegBoostTask( void ){
@ -77,6 +82,7 @@ namespace xgboost{
}
private:
inline void InitData( void ){
if( !strcmp( task.c_str(), "dump") ) return;
if( !strcmp( task.c_str(), "test") ){
data.CacheLoad( test_path.c_str() );
}else{
@ -126,6 +132,12 @@ namespace xgboost{
printf("\nupdating end, %lu sec in all\n", elapsed );
}
}
/*!
 * \brief write a text dump of the learner's model to the file named by name_dump
 *
 *  The file is opened for writing through utils::FopenCheck and closed
 *  once the learner has written its dump.
 */
inline void TaskDump( void )
{
    FILE *dump_file = utils::FopenCheck( name_dump.c_str(), "w" );
    learner.DumpModel( dump_file );
    fclose( dump_file );
}
inline void SaveModel( int i ) const{
char fname[256];
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
@ -161,6 +173,8 @@ namespace xgboost{
std::string task;
/* \brief name of predict file */
std::string name_pred;
/* \brief name of dump file */
std::string name_dump;
/* \brief the paths of validation data sets */
std::vector<std::string> eval_data_paths;
/* \brief the names of the evaluation data used in output log */