#ifndef _XGBOOST_BASE_MODEL_H_ #define _XGBOOST_BASE_MODEL_H_ #include #include "../booster/xgboost.h" #include "../utils/xgboost_config.h" /*! * \file xgboost_base_model.h * \brief a base model class, * that assembles the ensembles of booster together and do model update * this class can be used as base code to create booster variants * \author Tianqi Chen: tianqi.tchen@gmail.com */ namespace xgboost{ /*! \brief namespace for base class library */ namespace gbm_base{ /*! * \brief a base model class, * that assembles the ensembles of booster together and provide single routines to do prediction buffer and update * this class can be used as base code to create booster variants */ class BaseGBMModel{ public: /*! \brief model parameters */ struct Param{ /*! \brief number of boosters */ int num_boosters; /*! \brief type of tree used */ int booster_type; /*! \brief number of root: default 0, means single tree */ int num_root; /*! \brief number of features to be used by boosters */ int num_feature; /*! \brief size of predicton buffer allocated for buffering boosting computation */ int num_pbuffer; /*! * \brief whether we repeatly update a single booster each round: default 0 * set to 1 for linear booster, so that regularization term can be considered */ int do_reboost; /*! \brief reserved parameters */ int reserved[ 32 ]; /*! \brief constructor */ Param( void ){ num_boosters = 0; booster_type = 0; num_root = num_feature = 0; do_reboost = 0; num_pbuffer = 0; memset( reserved, 0, sizeof( reserved ) ); } /*! * \brief set parameters from outside * \param name name of the parameter * \param val value of the parameter */ inline void SetParam( const char *name, const char *val ){ if( !strcmp("booster_type", name ) ) booster_type = atoi( val ); if( !strcmp("num_pbuffer", name ) ) num_pbuffer = atoi( val ); if( !strcmp("do_reboost", name ) ) do_reboost = atoi( val ); if( !strcmp("bst:num_root", name ) ) num_root = atoi( val ); if( !strcmp("bst:num_feature", name ) ) num_feature = atoi( val ); } }; public: /*! \brief destructor */ virtual ~BaseGBMModel( void ){ this->FreeSpace(); } /*! * \brief set parameters from outside * \param name name of the parameter * \param val value of the parameter */ inline void SetParam( const char *name, const char *val ){ if( strncmp( name, "bst:", 4 ) == 0 ){ cfg.PushBack( name + 4, val ); } if( boosters.size() == 0 ) param.SetParam( name, val ); } /*! * \brief load model from stream * \param fi input stream */ inline void LoadModel( utils::IStream &fi ){ if( boosters.size() != 0 ) this->FreeSpace(); utils::Assert( fi.Read( ¶m, sizeof(Param) ) != 0 ); boosters.resize( param.num_boosters ); for( size_t i = 0; i < boosters.size(); i ++ ){ boosters[ i ] = booster::CreateBooster( param.booster_type ); boosters[ i ]->LoadModel( fi ); } {// load info booster_info.resize( param.num_boosters ); if( param.num_boosters != 0 ){ utils::Assert( fi.Read( &booster_info[0], sizeof(int)*param.num_boosters ) != 0 ); } } if( param.num_pbuffer != 0 ){ pred_buffer.resize ( param.num_pbuffer ); pred_counter.resize( param.num_pbuffer ); utils::Assert( fi.Read( &pred_buffer[0] , pred_buffer.size()*sizeof(float) ) != 0 ); utils::Assert( fi.Read( &pred_counter[0], pred_counter.size()*sizeof(unsigned) ) != 0 ); } } /*! * \brief save model to stream * \param fo output stream */ inline void SaveModel( utils::IStream &fo ) const { utils::Assert( param.num_boosters == (int)boosters.size() ); fo.Write( ¶m, sizeof(Param) ); for( size_t i = 0; i < boosters.size(); i ++ ){ boosters[ i ]->SaveModel( fo ); } if( booster_info.size() != 0 ){ fo.Write( &booster_info[0], sizeof(int) * booster_info.size() ); } if( param.num_pbuffer != 0 ){ fo.Write( &pred_buffer[0] , pred_buffer.size()*sizeof(float) ); fo.Write( &pred_counter[0], pred_counter.size()*sizeof(unsigned) ); } } /*! * \brief initialize the current data storage for model, if the model is used first time, call this function */ inline void InitModel( void ){ pred_buffer.clear(); pred_counter.clear(); pred_buffer.resize ( param.num_pbuffer, 0.0 ); pred_counter.resize( param.num_pbuffer, 0 ); utils::Assert( param.num_boosters == 0 ); utils::Assert( boosters.size() == 0 ); } /*! * \brief initialize solver before training, called before training * this function is reserved for solver to allocate necessary space and do other preparation */ inline void InitTrainer( void ){ // make sure all the boosters get the latest parameters for( size_t i = 0; i < this->boosters.size(); i ++ ){ this->ConfigBooster( this->boosters[i] ); } } public: /*! * \brief do gradient boost training for one step, using the information given * \param grad first order gradient of each instance * \param hess second order gradient of each instance * \param feats features of each instance * \param root_index pre-partitioned root index of each instance, * root_index.size() can be 0 which indicates that no pre-partition involved */ inline void DoBoost( std::vector &grad, std::vector &hess, const booster::FMatrixS::Image &feats, const std::vector &root_index ) { booster::IBooster *bst = this->GetUpdateBooster(); bst->DoBoost( grad, hess, feats, root_index ); } /*! * \brief predict values for given sparse feature vector * NOTE: in tree implementation, this is not threadsafe * \param feat vector in sparse format * \param buffer_index the buffer index of the current feature line, default -1 means no buffer assigned * \param rid root id of current instance, default = 0 * \return prediction */ virtual float Predict( const booster::FMatrixS::Line &feat, int buffer_index = -1, unsigned rid = 0 ){ size_t istart = 0; float psum = 0.0f; // load buffered results if any if( param.do_reboost == 0 && buffer_index >= 0 ){ utils::Assert( buffer_index < param.num_pbuffer, "buffer index exceed num_pbuffer" ); istart = this->pred_counter[ buffer_index ]; psum = this->pred_buffer [ buffer_index ]; } for( size_t i = istart; i < this->boosters.size(); i ++ ){ psum += this->boosters[ i ]->Predict( feat, rid ); } // updated the buffered results if( param.do_reboost == 0 && buffer_index >= 0 ){ this->pred_counter[ buffer_index ] = static_cast( boosters.size() ); this->pred_buffer [ buffer_index ] = psum; } return psum; } /*! * \brief predict values for given dense feature vector * \param feat feature vector in dense format * \param funknown indicator that the feature is missing * \param buffer_index the buffer index of the current feature line, default -1 means no buffer assigned * \param rid root id of current instance, default = 0 * \return prediction */ virtual float Predict( const std::vector &feat, const std::vector &funknown, int buffer_index = -1, unsigned rid = 0 ){ size_t istart = 0; float psum = 0.0f; // load buffered results if any if( param.do_reboost == 0 && buffer_index >= 0 ){ utils::Assert( buffer_index < param.num_pbuffer, "buffer index exceed num_pbuffer" ); istart = this->pred_counter[ buffer_index ]; psum = this->pred_buffer [ buffer_index ]; } for( size_t i = istart; i < this->boosters.size(); i ++ ){ psum += this->boosters[ i ]->Predict( feat, funknown, rid ); } // updated the buffered results if( param.do_reboost == 0 && buffer_index >= 0 ){ this->pred_counter[ buffer_index ] = static_cast( boosters.size() ); this->pred_buffer [ buffer_index ] = psum; } return psum; } //-----------non public fields afterwards------------- protected: /*! \brief free space of the model */ inline void FreeSpace( void ){ for( size_t i = 0; i < boosters.size(); i ++ ){ delete boosters[i]; } boosters.clear(); booster_info.clear(); param.num_boosters = 0; } /*! \brief configure a booster */ inline void ConfigBooster( booster::IBooster *bst ){ cfg.BeforeFirst(); while( cfg.Next() ){ bst->SetParam( cfg.name(), cfg.val() ); } } /*! * \brief get a booster to update * \return the booster created */ inline booster::IBooster *GetUpdateBooster( void ){ if( param.do_reboost == 0 || boosters.size() == 0 ){ boosters.push_back( booster::CreateBooster( param.booster_type ) ); booster_info.push_back( 0 ); this->ConfigBooster( boosters.back() ); boosters.back()->InitModel(); }else{ this->ConfigBooster( boosters.back() ); } return boosters.back(); } protected: /*! \brief model parameters */ Param param; /*! \brief component boosters */ std::vector boosters; /*! \brief some information indicator of the booster, reserved */ std::vector booster_info; /*! \brief prediction buffer */ std::vector pred_buffer; /*! \brief prediction buffer counter, record the progress so fart of the buffer */ std::vector pred_counter; /*! \brief configurations saved for each booster */ utils::ConfigSaver cfg; }; }; }; #endif