finish mushroom
This commit is contained in:
@@ -89,16 +89,23 @@ namespace xgboost{
|
||||
* \param fi input stream
|
||||
*/
|
||||
inline void LoadModel( utils::IStream &fi ){
|
||||
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
|
||||
base_model.LoadModel( fi );
|
||||
utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
|
||||
}
|
||||
/*!
|
||||
* \brief DumpModel
|
||||
* \param fo text file
|
||||
*/
|
||||
inline void DumpModel( FILE *fo ){
|
||||
base_model.DumpModel( fo );
|
||||
}
|
||||
/*!
|
||||
* \brief save model to stream
|
||||
* \param fo output stream
|
||||
*/
|
||||
inline void SaveModel( utils::IStream &fo ) const{
|
||||
fo.Write( &mparam, sizeof(ModelParam) );
|
||||
base_model.SaveModel( fo );
|
||||
fo.Write( &mparam, sizeof(ModelParam) );
|
||||
}
|
||||
/*!
|
||||
* \brief update the model for one iteration
|
||||
|
||||
@@ -34,6 +34,10 @@ namespace xgboost{
|
||||
}
|
||||
this->InitData();
|
||||
this->InitLearner();
|
||||
if( !strcmp( task.c_str(), "dump") ){
|
||||
this->TaskDump();
|
||||
return 0;
|
||||
}
|
||||
if( !strcmp( task.c_str(), "test") ){
|
||||
this->TaskTest();
|
||||
}else{
|
||||
@@ -68,6 +72,7 @@ namespace xgboost{
|
||||
task = "train";
|
||||
model_in = "NULL";
|
||||
name_pred = "pred.txt";
|
||||
name_dump = "dump.txt";
|
||||
model_dir_path = "./";
|
||||
}
|
||||
~RegBoostTask( void ){
|
||||
@@ -77,6 +82,7 @@ namespace xgboost{
|
||||
}
|
||||
private:
|
||||
inline void InitData( void ){
|
||||
if( !strcmp( task.c_str(), "dump") ) return;
|
||||
if( !strcmp( task.c_str(), "test") ){
|
||||
data.CacheLoad( test_path.c_str() );
|
||||
}else{
|
||||
@@ -126,6 +132,12 @@ namespace xgboost{
|
||||
printf("\nupdating end, %lu sec in all\n", elapsed );
|
||||
}
|
||||
}
|
||||
|
||||
inline void TaskDump( void ){
|
||||
FILE *fo = utils::FopenCheck( name_dump.c_str(), "w" );
|
||||
learner.DumpModel( fo );
|
||||
fclose( fo );
|
||||
}
|
||||
inline void SaveModel( int i ) const{
|
||||
char fname[256];
|
||||
sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 );
|
||||
@@ -161,6 +173,8 @@ namespace xgboost{
|
||||
std::string task;
|
||||
/* \brief name of predict file */
|
||||
std::string name_pred;
|
||||
/* \brief name of dump file */
|
||||
std::string name_dump;
|
||||
/* \brief the paths of validation data sets */
|
||||
std::vector<std::string> eval_data_paths;
|
||||
/* \brief the names of the evaluation data used in output log */
|
||||
|
||||
Reference in New Issue
Block a user