finish mushroom
This commit is contained in:
parent
9d6ef11eb5
commit
c4949c0937
@ -544,7 +544,11 @@ namespace xgboost{
|
||||
"input data smaller than num feature" );
|
||||
int pid = this->GetLeafIndex( feat, funknown, gid );
|
||||
return tree[ pid ].leaf_value();
|
||||
}
|
||||
}
|
||||
|
||||
virtual void DumpModel( FILE *fo ){
|
||||
tree.DumpModel( fo );
|
||||
}
|
||||
public:
|
||||
RTreeTrainer( void ){ silent = 0; }
|
||||
virtual ~RTreeTrainer( void ){}
|
||||
|
||||
@ -284,6 +284,26 @@ namespace xgboost{
|
||||
inline int num_extra_nodes( void ) const {
|
||||
return param.num_nodes - param.num_roots - param.num_deleted;
|
||||
}
|
||||
/*! \brief dump model to text file */
|
||||
inline void DumpModel( FILE *fo ){
|
||||
this->Dump( 0, fo, 0 );
|
||||
}
|
||||
private:
|
||||
/*!
 * \brief recursively print the subtree rooted at nid, one node per line
 * \param nid index into nodes of the subtree root
 * \param fo output stream
 * \param depth number of tab characters written before the node text
 */
void Dump( int nid, FILE *fo, int depth ){
    // indent: one tab per level
    int pending = depth;
    while( pending > 0 ){
        fprintf( fo, "\t" );
        -- pending;
    }
    // a leaf prints its value and ends the recursion
    if( nodes[ nid ].is_leaf() ){
        fprintf( fo, "%d:leaf=%f\n", nid, nodes[ nid ].leaf_value() );
        return;
    }
    // internal node: print the split test, labelling the right child "yes" and the left child "no"
    TSplitCond cond = nodes[ nid ].split_cond();
    const int child_depth = depth + 1;
    fprintf( fo, "%d:[f%u>%f] yes=%d,no=%d\n",
             nid, nodes[ nid ].split_index(), float(cond),
             nodes[ nid ].cright(), nodes[ nid ].cleft() );
    // right then left, matching the yes/no order printed above
    this->Dump( nodes[ nid ].cright(), fo, child_depth );
    this->Dump( nodes[ nid ].cleft(), fo, child_depth );
}
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@ -92,6 +92,13 @@ namespace xgboost{
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void PrintInfo( FILE *fo ){
    // default implementation: writes nothing to fo
}
|
||||
/*!
|
||||
* \brief dump model into text file
|
||||
* \param fo output stream
|
||||
*/
|
||||
virtual void DumpModel( FILE *fo ){
|
||||
utils::Error( "not implemented" );
|
||||
}
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~IBooster( void ){}
|
||||
|
||||
@ -169,6 +169,16 @@ namespace xgboost{
|
||||
this->ConfigBooster( this->boosters[i] );
|
||||
}
|
||||
}
|
||||
/*!
|
||||
* \brief DumpModel
|
||||
* \param fo text file
|
||||
*/
|
||||
inline void DumpModel( FILE *fo ){
|
||||
for( size_t i = 0; i < boosters.size(); i ++ ){
|
||||
fprintf( fo, "booster[%d]\n", (int)i );
|
||||
boosters[i]->DumpModel( fo );
|
||||
}
|
||||
}
|
||||
public:
|
||||
/*!
|
||||
* \brief do gradient boost training for one step, using the information given
|
||||
|
||||
21
demo/mushroom/maptree.py
Executable file
21
demo/mushroom/maptree.py
Executable file
@ -0,0 +1,21 @@
|
||||
#!/usr/bin/python
|
||||
import sys
|
||||
|
||||
def loadnmap( fname ):
    """Load a feature-id -> feature-name map from a whitespace-separated text file.

    Each non-blank line is split on whitespace; the first token is parsed as the
    integer feature id and the second token is stored as its name. Any further
    tokens on the line are ignored.

    fname: path of the file to read.
    Returns a dict mapping int id to name string.
    Raises ValueError if a first token is not an integer, and IndexError if a
    non-blank line has fewer than two tokens.
    """
    nmap = {}
    # 'with' closes the file even if a line fails to parse
    with open( fname ) as fin:
        for line in fin:
            arr = line.split()
            if not arr:
                # blank line (e.g. trailing newline at end of file): nothing to map
                continue
            # split() already removed surrounding whitespace from each token
            nmap[ int( arr[0] ) ] = arr[1]
    return nmap
|
||||
|
||||
# Copy dump.txt to stdout, replacing each "[f<id>>..." split test with the
# feature name looked up in featname.txt.
fo = sys.stdout
nmap = loadnmap( 'featname.txt' )
# 'with' makes sure the dump file is closed once it has been processed
with open( 'dump.txt' ) as fdump:
    for l in fdump:
        idx = l.find('[f')
        if idx == -1:
            # no split test on this line: copy it through unchanged
            fo.write(l)
        else:
            # the feature id is the integer between "[f" and ">"
            fid = int( l[idx+2:].split('>')[0] )
            # keep the text before "[f", substitute the feature name, then
            # append the line's second whitespace-separated token
            rl = l[0:idx]+'['+nmap[fid]+']' + l.split()[1].strip()+'\n'
            fo.write(rl)
|
||||
|
||||
@ -2,3 +2,5 @@
|
||||
python mapfeat.py
|
||||
python mknfold.py agaricus.txt 1
|
||||
../../xgboost mushroom.conf
|
||||
../../xgboost mushroom.conf task=dump model_in=0003.model
|
||||
python maptree.py
|
||||
|
||||
@ -89,16 +89,23 @@ namespace xgboost{
|
||||
* \param fi input stream
|
||||
*/
|
||||
inline void LoadModel( utils::IStream &fi ){
    // NOTE(review): mparam is read both before and after base_model here, mirroring
    // the two writes in SaveModel as shown; one of the two may be a leftover from
    // the diff this text was taken from -- confirm against the real file.
    utils::Assert( 0 != fi.Read( &mparam, sizeof(ModelParam) ) );
    base_model.LoadModel( fi );
    utils::Assert( 0 != fi.Read( &mparam, sizeof(ModelParam) ) );
}
|
||||
/*!
|
||||
* \brief DumpModel
|
||||
* \param fo text file
|
||||
*/
|
||||
inline void DumpModel( FILE *fo ){
|
||||
base_model.DumpModel( fo );
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void SaveModel( utils::IStream &fo ) const{
|
||||
fo.Write( &mparam, sizeof(ModelParam) );
|
||||
base_model.SaveModel( fo );
|
||||
fo.Write( &mparam, sizeof(ModelParam) );
|
||||
}
|
||||
/*!
|
||||
* \brief update the model for one iteration
|
||||
|
||||
@ -34,6 +34,10 @@ namespace xgboost{
|
||||
}
|
||||
this->InitData();
|
||||
this->InitLearner();
|
||||
if( !strcmp( task.c_str(), "dump") ){
|
||||
this->TaskDump();
|
||||
return 0;
|
||||
}
|
||||
if( !strcmp( task.c_str(), "test") ){
|
||||
this->TaskTest();
|
||||
}else{
|
||||
@ -68,6 +72,7 @@ namespace xgboost{
|
||||
task = "train";
|
||||
model_in = "NULL";
|
||||
name_pred = "pred.txt";
|
||||
name_dump = "dump.txt";
|
||||
model_dir_path = "./";
|
||||
}
|
||||
~RegBoostTask( void ){
|
||||
@ -77,6 +82,7 @@ namespace xgboost{
|
||||
}
|
||||
private:
|
||||
inline void InitData( void ){
|
||||
if( !strcmp( task.c_str(), "dump") ) return;
|
||||
if( !strcmp( task.c_str(), "test") ){
|
||||
data.CacheLoad( test_path.c_str() );
|
||||
}else{
|
||||
@ -126,6 +132,12 @@ namespace xgboost{
|
||||
printf("\nupdating end, %lu sec in all\n", elapsed );
|
||||
}
|
||||
}
|
||||
|
||||
inline void TaskDump( void ){
|
||||
FILE *fo = utils::FopenCheck( name_dump.c_str(), "w" );
|
||||
learner.DumpModel( fo );
|
||||
fclose( fo );
|
||||
}
|
||||
inline void SaveModel( int i ) const{
|
||||
char fname[256];
|
||||
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
|
||||
@ -161,6 +173,8 @@ namespace xgboost{
|
||||
std::string task;
|
||||
/* \brief name of predict file */
|
||||
std::string name_pred;
|
||||
/* \brief name of dump file */
|
||||
std::string name_dump;
|
||||
/* \brief the paths of validation data sets */
|
||||
std::vector<std::string> eval_data_paths;
|
||||
/* \brief the names of the evaluation data used in output log */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user