finish mushroom

This commit is contained in:
tqchen
2014-02-24 23:06:57 -08:00
parent 9d6ef11eb5
commit c4949c0937
8 changed files with 88 additions and 3 deletions

View File

@@ -89,16 +89,23 @@ namespace xgboost{
* \param fi input stream
*/
inline void LoadModel( utils::IStream &fi ){
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
base_model.LoadModel( fi );
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
}
/*!
* \brief DumpModel
* \param fo text file
*/
inline void DumpModel( FILE *fo ){
base_model.DumpModel( fo );
}
/*!
* \brief save model to stream
* \param fo output stream
*/
inline void SaveModel( utils::IStream &fo ) const{
fo.Write( &mparam, sizeof(ModelParam) );
base_model.SaveModel( fo );
fo.Write( &mparam, sizeof(ModelParam) );
}
/*!
* \brief update the model for one iteration

View File

@@ -34,6 +34,10 @@ namespace xgboost{
}
this->InitData();
this->InitLearner();
if( !strcmp( task.c_str(), "dump") ){
this->TaskDump();
return 0;
}
if( !strcmp( task.c_str(), "test") ){
this->TaskTest();
}else{
@@ -68,6 +72,7 @@ namespace xgboost{
task = "train";
model_in = "NULL";
name_pred = "pred.txt";
name_dump = "dump.txt";
model_dir_path = "./";
}
~RegBoostTask( void ){
@@ -77,6 +82,7 @@ namespace xgboost{
}
private:
inline void InitData( void ){
if( !strcmp( task.c_str(), "dump") ) return;
if( !strcmp( task.c_str(), "test") ){
data.CacheLoad( test_path.c_str() );
}else{
@@ -126,6 +132,12 @@ namespace xgboost{
printf("\nupdating end, %lu sec in all\n", elapsed );
}
}
inline void TaskDump( void ){
FILE *fo = utils::FopenCheck( name_dump.c_str(), "w" );
learner.DumpModel( fo );
fclose( fo );
}
inline void SaveModel( int i ) const{
char fname[256];
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
@@ -161,6 +173,8 @@ namespace xgboost{
std::string task;
/* \brief name of predict file */
std::string name_pred;
/* \brief name of dump file */
std::string name_dump;
/* \brief the paths of validation data sets */
std::vector<std::string> eval_data_paths;
/* \brief the names of the evaluation data used in output log */