From 0fdda294707b977bbaff0859e9ea7145e2b015a4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 4 Mar 2014 22:47:39 -0800 Subject: [PATCH] reupdate data --- booster/xgboost_gbmbase.h | 257 +++++++++++++++++--------------- regression/xgboost_reg.h | 84 +++++++++-- regression/xgboost_reg_main.cpp | 35 ++++- 3 files changed, 239 insertions(+), 137 deletions(-) diff --git a/booster/xgboost_gbmbase.h b/booster/xgboost_gbmbase.h index 50746e59e..7c8906edf 100644 --- a/booster/xgboost_gbmbase.h +++ b/booster/xgboost_gbmbase.h @@ -40,69 +40,18 @@ namespace xgboost{ * (4) model.SaveModel to save learned results * * Bufferring: each instance comes with a buffer_index in Predict. - * when param.num_pbuffer != 0, a unique buffer index can be + * when mparam.num_pbuffer != 0, a unique buffer index can be * assigned to each instance to buffer previous results of boosters, * this helps to speedup training, so consider assign buffer_index * for each training instances, if buffer_index = -1, the code * recalculate things from scratch and will still works correctly */ - class GBMBaseModel{ - public: - /*! \brief model parameters */ - struct Param{ - /*! \brief number of boosters */ - int num_boosters; - /*! \brief type of tree used */ - int booster_type; - /*! \brief number of root: default 0, means single tree */ - int num_roots; - /*! \brief number of features to be used by boosters */ - int num_feature; - /*! \brief size of predicton buffer allocated for buffering boosting computation */ - int num_pbuffer; - /*! - * \brief whether we repeatly update a single booster each round: default 0 - * set to 1 for linear booster, so that regularization term can be considered - */ - int do_reboost; - /*! \brief reserved parameters */ - int reserved[ 32 ]; - /*! \brief constructor */ - Param( void ){ - num_boosters = 0; - booster_type = 0; - num_roots = num_feature = 0; - do_reboost = 0; - num_pbuffer = 0; - memset( reserved, 0, sizeof( reserved ) ); - } - /*! - * \brief set parameters from outside - * \param name name of the parameter - * \param val value of the parameter - */ - inline void SetParam( const char *name, const char *val ){ - if( !strcmp("booster_type", name ) ){ - booster_type = atoi( val ); - // linear boost automatically set do reboost - if( booster_type == 1 ) do_reboost = 1; - } - if( !strcmp("num_pbuffer", name ) ) num_pbuffer = atoi( val ); - if( !strcmp("do_reboost", name ) ) do_reboost = atoi( val ); - if( !strcmp("bst:num_roots", name ) ) num_roots = atoi( val ); - if( !strcmp("bst:num_feature", name ) ) num_feature = atoi( val ); - } - }; - public: - /*! \brief model parameters */ - Param param; + class GBMBase{ public: /*! \brief number of thread used */ - GBMBaseModel( void ){ - this->nthread = 1; - } + GBMBase( void ){} /*! \brief destructor */ - virtual ~GBMBaseModel( void ){ + virtual ~GBMBase( void ){ this->FreeSpace(); } /*! @@ -117,8 +66,8 @@ namespace xgboost{ if( !strcmp( name, "silent") ){ cfg.PushBack( name, val ); } - if( !strcmp( name, "nthread") ) nthread = atoi( val ); - if( boosters.size() == 0 ) param.SetParam( name, val ); + tparam.SetParam( name, val ); + if( boosters.size() == 0 ) mparam.SetParam( name, val ); } /*! * \brief load model from stream @@ -126,21 +75,21 @@ namespace xgboost{ */ inline void LoadModel( utils::IStream &fi ){ if( boosters.size() != 0 ) this->FreeSpace(); - utils::Assert( fi.Read( ¶m, sizeof(Param) ) != 0 ); - boosters.resize( param.num_boosters ); + utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 ); + boosters.resize( mparam.num_boosters ); for( size_t i = 0; i < boosters.size(); i ++ ){ - boosters[ i ] = booster::CreateBooster( param.booster_type ); + boosters[ i ] = booster::CreateBooster( mparam.booster_type ); boosters[ i ]->LoadModel( fi ); } {// load info - booster_info.resize( param.num_boosters ); - if( param.num_boosters != 0 ){ - utils::Assert( fi.Read( &booster_info[0], sizeof(int)*param.num_boosters ) != 0 ); + booster_info.resize( mparam.num_boosters ); + if( mparam.num_boosters != 0 ){ + utils::Assert( fi.Read( &booster_info[0], sizeof(int)*mparam.num_boosters ) != 0 ); } } - if( param.num_pbuffer != 0 ){ - pred_buffer.resize ( param.num_pbuffer ); - pred_counter.resize( param.num_pbuffer ); + if( mparam.num_pbuffer != 0 ){ + pred_buffer.resize ( mparam.num_pbuffer ); + pred_counter.resize( mparam.num_pbuffer ); utils::Assert( fi.Read( &pred_buffer[0] , pred_buffer.size()*sizeof(float) ) != 0 ); utils::Assert( fi.Read( &pred_counter[0], pred_counter.size()*sizeof(unsigned) ) != 0 ); } @@ -150,15 +99,15 @@ namespace xgboost{ * \param fo output stream */ inline void SaveModel( utils::IStream &fo ) const { - utils::Assert( param.num_boosters == (int)boosters.size() ); - fo.Write( ¶m, sizeof(Param) ); + utils::Assert( mparam.num_boosters == (int)boosters.size() ); + fo.Write( &mparam, sizeof(ModelParam) ); for( size_t i = 0; i < boosters.size(); i ++ ){ boosters[ i ]->SaveModel( fo ); } if( booster_info.size() != 0 ){ fo.Write( &booster_info[0], sizeof(int) * booster_info.size() ); } - if( param.num_pbuffer != 0 ){ + if( mparam.num_pbuffer != 0 ){ fo.Write( &pred_buffer[0] , pred_buffer.size()*sizeof(float) ); fo.Write( &pred_counter[0], pred_counter.size()*sizeof(unsigned) ); } @@ -168,9 +117,9 @@ namespace xgboost{ */ inline void InitModel( void ){ pred_buffer.clear(); pred_counter.clear(); - pred_buffer.resize ( param.num_pbuffer, 0.0 ); - pred_counter.resize( param.num_pbuffer, 0 ); - utils::Assert( param.num_boosters == 0 ); + pred_buffer.resize ( mparam.num_pbuffer, 0.0 ); + pred_counter.resize( mparam.num_pbuffer, 0 ); + utils::Assert( mparam.num_boosters == 0 ); utils::Assert( boosters.size() == 0 ); } /*! @@ -178,8 +127,8 @@ namespace xgboost{ * this function is reserved for solver to allocate necessary space and do other preparation */ inline void InitTrainer( void ){ - if( nthread != 0 ){ - omp_set_num_threads( nthread ); + if( tparam.nthread != 0 ){ + omp_set_num_threads( tparam.nthread ); } // make sure all the boosters get the latest parameters for( size_t i = 0; i < this->boosters.size(); i ++ ){ @@ -233,10 +182,10 @@ namespace xgboost{ const std::vector &root_index ) { booster::IBooster *bst = this->GetUpdateBooster(); bst->DoBoost( grad, hess, feats, root_index ); - } + } /*! * \brief predict values for given sparse feature vector - * NOTE: in tree implementation, this is not threadsafe + * NOTE: in tree implementation, this is only OpenMP threadsafe, but not threadsafe * \param feats feature matrix * \param row_index row index in the feature matrix * \param buffer_index the buffer index of the current feature line, default -1 means no buffer assigned @@ -248,57 +197,51 @@ namespace xgboost{ float psum = 0.0f; // load buffered results if any - if( param.do_reboost == 0 && buffer_index >= 0 ){ - utils::Assert( buffer_index < param.num_pbuffer, "buffer index exceed num_pbuffer" ); + if( mparam.do_reboost == 0 && buffer_index >= 0 ){ + utils::Assert( buffer_index < mparam.num_pbuffer, "buffer index exceed num_pbuffer" ); istart = this->pred_counter[ buffer_index ]; psum = this->pred_buffer [ buffer_index ]; } for( size_t i = istart; i < this->boosters.size(); i ++ ){ psum += this->boosters[ i ]->Predict( feats, row_index, root_index ); - } - + } // updated the buffered results - if( param.do_reboost == 0 && buffer_index >= 0 ){ + if( mparam.do_reboost == 0 && buffer_index >= 0 ){ this->pred_counter[ buffer_index ] = static_cast( boosters.size() ); this->pred_buffer [ buffer_index ] = psum; } return psum; } + public: + //--------trial code for interactive update an existing booster------ + //-------- usually not needed, ignore this region --------- /*! - * \brief predict values for given dense feature vector - * \param feat feature vector in dense format - * \param funknown indicator that the feature is missing - * \param buffer_index the buffer index of the current feature line, default -1 means no buffer assigned - * \param rid root id of current instance, default = 0 - * \return prediction - */ - virtual float Predict( const std::vector &feat, - const std::vector &funknown, - int buffer_index = -1, - unsigned rid = 0 ){ - size_t istart = 0; - float psum = 0.0f; - - // load buffered results if any - if( param.do_reboost == 0 && buffer_index >= 0 ){ - utils::Assert( buffer_index < param.num_pbuffer, - "buffer index exceed num_pbuffer" ); - istart = this->pred_counter[ buffer_index ]; - psum = this->pred_buffer [ buffer_index ]; - } - - for( size_t i = istart; i < this->boosters.size(); i ++ ){ - psum += this->boosters[ i ]->Predict( feat, funknown, rid ); - } - - // updated the buffered results - if( param.do_reboost == 0 && buffer_index >= 0 ){ - this->pred_counter[ buffer_index ] = static_cast( boosters.size() ); - this->pred_buffer [ buffer_index ] = psum; + * \brief same as Predict, but removes the prediction of booster to be updated + * this function must be called once and only once for every data with pbuffer + */ + inline float InteractPredict( const FMatrixS &feats, bst_uint row_index, int buffer_index = -1, unsigned root_index = 0 ){ + float psum = this->Predict( feats, row_index, buffer_index, root_index ); + if( tparam.reupdate_booster != -1 ){ + const int bid = tparam.reupdate_booster; + utils::Assert( bid >= 0 && bid < (int)boosters.size(), "interact:booster_index exceed existing bound" ); + psum -= boosters[ bid ]->Predict( feats, row_index, root_index ); + if( mparam.do_reboost == 0 && buffer_index >= 0 ){ + this->pred_buffer[ buffer_index ] = psum; + } } return psum; } + /*! \brief update the prediction buffer, after booster have been updated */ + inline void InteractRePredict( const FMatrixS &feats, bst_uint row_index, int buffer_index = -1, unsigned root_index = 0 ){ + if( tparam.reupdate_booster != -1 ){ + const int bid = tparam.reupdate_booster; + utils::Assert( bid >= 0 && bid < (int)boosters.size(), "interact:booster_index exceed existing bound" ); + if( mparam.do_reboost == 0 && buffer_index >= 0 ){ + this->pred_buffer[ buffer_index ] += boosters[ bid ]->Predict( feats, row_index, root_index ); + } + } + } //-----------non public fields afterwards------------- protected: /*! \brief free space of the model */ @@ -306,7 +249,7 @@ namespace xgboost{ for( size_t i = 0; i < boosters.size(); i ++ ){ delete boosters[i]; } - boosters.clear(); booster_info.clear(); param.num_boosters = 0; + boosters.clear(); booster_info.clear(); mparam.num_boosters = 0; } /*! \brief configure a booster */ inline void ConfigBooster( booster::IBooster *bst ){ @@ -320,9 +263,16 @@ namespace xgboost{ * \return the booster created */ inline booster::IBooster *GetUpdateBooster( void ){ - if( param.do_reboost == 0 || boosters.size() == 0 ){ - param.num_boosters += 1; - boosters.push_back( booster::CreateBooster( param.booster_type ) ); + if( tparam.reupdate_booster != -1 ){ + const int bid = tparam.reupdate_booster; + utils::Assert( bid >= 0 && bid < (int)boosters.size(), "interact:booster_index exceed existing bound" ); + this->ConfigBooster( boosters[bid] ); + return boosters[ bid ]; + } + + if( mparam.do_reboost == 0 || boosters.size() == 0 ){ + mparam.num_boosters += 1; + boosters.push_back( booster::CreateBooster( mparam.booster_type ) ); booster_info.push_back( 0 ); this->ConfigBooster( boosters.back() ); boosters.back()->InitModel(); @@ -332,8 +282,81 @@ namespace xgboost{ return boosters.back(); } protected: - /*! \brief number of OpenMP threads */ - int nthread; + /*! \brief model parameters */ + struct ModelParam{ + /*! \brief number of boosters */ + int num_boosters; + /*! \brief type of tree used */ + int booster_type; + /*! \brief number of root: default 0, means single tree */ + int num_roots; + /*! \brief number of features to be used by boosters */ + int num_feature; + /*! \brief size of predicton buffer allocated for buffering boosting computation */ + int num_pbuffer; + /*! + * \brief whether we repeatly update a single booster each round: default 0 + * set to 1 for linear booster, so that regularization term can be considered + */ + int do_reboost; + /*! \brief reserved parameters */ + int reserved[ 32 ]; + /*! \brief constructor */ + ModelParam( void ){ + num_boosters = 0; + booster_type = 0; + num_roots = num_feature = 0; + do_reboost = 0; + num_pbuffer = 0; + memset( reserved, 0, sizeof( reserved ) ); + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + inline void SetParam( const char *name, const char *val ){ + if( !strcmp("booster_type", name ) ){ + booster_type = atoi( val ); + // linear boost automatically set do reboost + if( booster_type == 1 ) do_reboost = 1; + } + if( !strcmp("num_pbuffer", name ) ) num_pbuffer = atoi( val ); + if( !strcmp("do_reboost", name ) ) do_reboost = atoi( val ); + if( !strcmp("bst:num_roots", name ) ) num_roots = atoi( val ); + if( !strcmp("bst:num_feature", name ) ) num_feature = atoi( val ); + } + }; + /*! \brief training parameters */ + struct TrainParam{ + /*! \brief number of OpenMP threads */ + int nthread; + /*! + * \brief index of specific booster to be re-updated, default = -1: update new booster + * parameter this is part of trial interactive update mode + */ + int reupdate_booster; + /*! \brief constructor */ + TrainParam( void ) { + nthread = 1; + reupdate_booster = -1; + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + inline void SetParam( const char *name, const char *val ){ + if( !strcmp("nthread", name ) ) nthread = atoi( val ); + if( !strcmp("interact:booster_index", name ) ) reupdate_booster = atoi( val ); + } + }; + protected: + /*! \brief model parameters */ + ModelParam mparam; + /*! \brief training parameters */ + TrainParam tparam; + protected: /*! \brief component boosters */ std::vector boosters; /*! \brief some information indicator of the booster, reserved */ diff --git a/regression/xgboost_reg.h b/regression/xgboost_reg.h index 7d1a680dc..52d8cbd4e 100644 --- a/regression/xgboost_reg.h +++ b/regression/xgboost_reg.h @@ -60,13 +60,14 @@ namespace xgboost{ } char str_temp[25]; - if( num_feature > base_model.param.num_feature ){ + if( num_feature > mparam.num_feature ){ + mparam.num_feature = num_feature; sprintf( str_temp, "%d", num_feature ); - base_model.SetParam( "bst:num_feature", str_temp ); + base_gbm.SetParam( "bst:num_feature", str_temp ); } sprintf( str_temp, "%u", buffer_size ); - base_model.SetParam( "num_pbuffer", str_temp ); + base_gbm.SetParam( "num_pbuffer", str_temp ); if( !silent ){ printf( "buffer_size=%u\n", buffer_size ); } @@ -81,16 +82,16 @@ namespace xgboost{ */ inline void SetParam( const char *name, const char *val ){ if( !strcmp( name, "silent") ) silent = atoi( val ); - if( !strcmp( name, "eval_metric") ) evaluator_.AddEval( val ); + if( !strcmp( name, "eval_metric") ) evaluator_.AddEval( val ); mparam.SetParam( name, val ); - base_model.SetParam( name, val ); + base_gbm.SetParam( name, val ); } /*! * \brief initialize solver before training, called before training * this function is reserved for solver to allocate necessary space and do other preparation */ inline void InitTrainer( void ){ - base_model.InitTrainer(); + base_gbm.InitTrainer(); if( mparam.loss_type == kLogisticClassify ){ evaluator_.AddEval( "error" ); }else{ @@ -102,7 +103,7 @@ namespace xgboost{ * \brief initialize the current data storage for model, if the model is used first time, call this function */ inline void InitModel( void ){ - base_model.InitModel(); + base_gbm.InitModel(); mparam.AdjustBase(); } /*! @@ -110,7 +111,7 @@ namespace xgboost{ * \param fi input stream */ inline void LoadModel( utils::IStream &fi ){ - base_model.LoadModel( fi ); + base_gbm.LoadModel( fi ); utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 ); } /*! @@ -120,7 +121,7 @@ namespace xgboost{ * \param with_stats whether print statistics as well */ inline void DumpModel( FILE *fo, const utils::FeatMap& fmap, bool with_stats ){ - base_model.DumpModel( fo, fmap, with_stats ); + base_gbm.DumpModel( fo, fmap, with_stats ); } /*! * \brief Dump path of all trees @@ -128,14 +129,14 @@ namespace xgboost{ * \param data input data */ inline void DumpPath( FILE *fo, const DMatrix &data ){ - base_model.DumpPath( fo, data.data ); + base_gbm.DumpPath( fo, data.data ); } /*! * \brief save model to stream * \param fo output stream */ inline void SaveModel( utils::IStream &fo ) const{ - base_model.SaveModel( fo ); + base_gbm.SaveModel( fo ); fo.Write( &mparam, sizeof(ModelParam) ); } /*! @@ -146,7 +147,7 @@ namespace xgboost{ this->PredictBuffer( preds_, *train_, 0 ); this->GetGradient( preds_, train_->labels, grad_, hess_ ); std::vector root_index; - base_model.DoBoost( grad_, hess_, train_->data, root_index ); + base_gbm.DoBoost( grad_, hess_, train_->data, root_index ); } /*! * \brief evaluate the model for specific iteration @@ -165,7 +166,6 @@ namespace xgboost{ } fprintf( fo,"\n" ); } - /*! \brief get prediction, without buffering */ inline void Predict( std::vector &preds, const DMatrix &data ){ preds.resize( data.Size() ); @@ -174,7 +174,51 @@ namespace xgboost{ #pragma omp parallel for schedule( static ) for( unsigned j = 0; j < ndata; ++ j ){ preds[j] = mparam.PredTransform - ( mparam.base_score + base_model.Predict( data.data, j, -1 ) ); + ( mparam.base_score + base_gbm.Predict( data.data, j, -1 ) ); + } + } + public: + /*! + * \brief update the model for one iteration + * \param iteration iteration number + */ + inline void UpdateInteract( void ){ + this->InteractPredict( preds_, *train_, 0 ); + int buffer_offset = static_cast( train_->Size() ); + for( size_t i = 0; i < evals_.size(); ++i ){ + std::vector &preds = this->eval_preds_[ i ]; + this->InteractPredict( preds, *evals_[i], buffer_offset ); + buffer_offset += static_cast( evals_[i]->Size() ); + } + + this->GetGradient( preds_, train_->labels, grad_, hess_ ); + std::vector root_index; + base_gbm.DoBoost( grad_, hess_, train_->data, root_index ); + + this->InteractRePredict( *train_, 0 ); + buffer_offset = static_cast( train_->Size() ); + for( size_t i = 0; i < evals_.size(); ++i ){ + this->InteractRePredict( *evals_[i], buffer_offset ); + buffer_offset += static_cast( evals_[i]->Size() ); + } + } + private: + /*! \brief get the transformed predictions, given data */ + inline void InteractPredict( std::vector &preds, const DMatrix &data, unsigned buffer_offset ){ + preds.resize( data.Size() ); + const unsigned ndata = static_cast( data.Size() ); + #pragma omp parallel for schedule( static ) + for( unsigned j = 0; j < ndata; ++ j ){ + preds[j] = mparam.PredTransform + ( mparam.base_score + base_gbm.InteractPredict( data.data, j, buffer_offset + j ) ); + } + } + /*! \brief repredict trial */ + inline void InteractRePredict( const DMatrix &data, unsigned buffer_offset ){ + const unsigned ndata = static_cast( data.Size() ); + #pragma omp parallel for schedule( static ) + for( unsigned j = 0; j < ndata; ++ j ){ + base_gbm.InteractRePredict( data.data, j, buffer_offset + j ); } } private: @@ -186,7 +230,7 @@ namespace xgboost{ #pragma omp parallel for schedule( static ) for( unsigned j = 0; j < ndata; ++ j ){ preds[j] = mparam.PredTransform - ( mparam.base_score + base_model.Predict( data.data, j, buffer_offset + j ) ); + ( mparam.base_score + base_gbm.Predict( data.data, j, buffer_offset + j ) ); } } @@ -218,9 +262,16 @@ namespace xgboost{ float base_score; /* \brief type of loss function */ int loss_type; + /* \brief number of features */ + int num_feature; + /*! \brief reserved field */ + int reserved[ 16 ]; + /*! \brief constructor */ ModelParam( void ){ base_score = 0.5f; loss_type = 0; + num_feature = 0; + memset( reserved, 0, sizeof( reserved ) ); } /*! * \brief set parameters from outside @@ -230,6 +281,7 @@ namespace xgboost{ inline void SetParam( const char *name, const char *val ){ if( !strcmp("base_score", name ) ) base_score = (float)atof( val ); if( !strcmp("loss_type", name ) ) loss_type = atoi( val ); + if( !strcmp("bst:num_feature", name ) ) num_feature = atoi( val ); } /*! * \brief adjust base_score @@ -330,7 +382,7 @@ namespace xgboost{ private: int silent; EvalSet evaluator_; - booster::GBMBaseModel base_model; + booster::GBMBase base_gbm; ModelParam mparam; const DMatrix *train_; std::vector evals_; diff --git a/regression/xgboost_reg_main.cpp b/regression/xgboost_reg_main.cpp index 3b604f30e..7feab8055 100644 --- a/regression/xgboost_reg_main.cpp +++ b/regression/xgboost_reg_main.cpp @@ -39,6 +39,10 @@ namespace xgboost{ this->TaskDump(); return 0; } + if( task == "interactive" ){ + this->TaskInteractive(); + return 0; + } if( task == "dumppath" ){ this->TaskDumpPath(); return 0; @@ -60,6 +64,7 @@ namespace xgboost{ if( !strcmp("data", name ) ) train_path = val; if( !strcmp("test:data", name ) ) test_path = val; if( !strcmp("model_in", name ) ) model_in = val; + if( !strcmp("model_out", name ) ) model_out = val; if( !strcmp("model_dir", name ) ) model_dir_path = val; if( !strcmp("fmap", name ) ) name_fmap = val; if( !strcmp("name_dump", name ) ) name_dump = val; @@ -141,13 +146,30 @@ namespace xgboost{ } // always save final round if( save_period == 0 || num_round % save_period != 0 ){ - this->SaveModel( num_round ); + if( model_out == "NULL" ){ + this->SaveModel( num_round ); + }else{ + this->SaveModel( model_out.c_str() ); + } } if( !silent ){ printf("\nupdating end, %lu sec in all\n", elapsed ); } } + inline void TaskInteractive( void ){ + const time_t start = time( NULL ); + unsigned long elapsed = 0; + learner.UpdateInteract(); + utils::Assert( model_out != "NULL", "interactive mode must specify model_out" ); + this->SaveModel( model_out.c_str() ); + elapsed = (unsigned long)(time(NULL) - start); + + if( !silent ){ + printf("\ninteractive update, %lu sec in all\n", elapsed ); + } + } + inline void TaskDump( void ){ FILE *fo = utils::FopenCheck( name_dump.c_str(), "w" ); learner.DumpModel( fo, fmap, dump_model_stats != 0 ); @@ -158,13 +180,16 @@ namespace xgboost{ learner.DumpPath( fo, data ); fclose( fo ); } - inline void SaveModel( int i ) const{ - char fname[256]; - sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 ); + inline void SaveModel( const char *fname ) const{ utils::FileStream fo( utils::FopenCheck( fname, "wb" ) ); learner.SaveModel( fo ); fo.Close(); } + inline void SaveModel( int i ) const{ + char fname[256]; + sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 ); + this->SaveModel( fname ); + } inline void TaskPred( void ){ std::vector preds; if( !silent ) printf("start prediction...\n"); @@ -189,6 +214,8 @@ namespace xgboost{ std::string train_path, test_path; /* \brief the path of test model file, or file to restart training */ std::string model_in; + /* \brief the path of final model file, to be saved */ + std::string model_out; /* \brief the path of directory containing the saved models */ std::string model_dir_path; /* \brief task to perform */