From e52720976c1767e252f4c1c8e1c31e91985df396 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Thu, 20 Feb 2014 22:08:31 -0800
Subject: [PATCH] changes to reg booster

---
 Makefile                        |   3 +-
 booster/xgboost_data.h          |   1 +
 regression/xgboost_reg.h        | 215 +++++++++++++++++---------------
 regression/xgboost_reg_main.cpp | 191 ++++++++++++++++++++++++++--
 regression/xgboost_reg_test.h   |   5 +-
 regression/xgboost_reg_train.h  |  23 ++--
 regression/xgboost_regdata.h    |   4 +-
 7 files changed, 312 insertions(+), 130 deletions(-)

diff --git a/Makefile b/Makefile
index 668c4f9c8..cb142d3d8 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@
 export CXX = g++
 export CFLAGS = -Wall -O3 -msse2
 
 # specify tensor path
-BIN =
+BIN = xgboost
 OBJ = xgboost.o
 .PHONY: clean all
@@ -11,6 +11,7 @@ all: $(BIN) $(OBJ)
 export LDFLAGS= -pthread -lm
 
 xgboost.o: booster/xgboost.h booster/xgboost_data.h booster/xgboost.cpp booster/*/*.hpp booster/*/*.h
+xgboost: regression/xgboost_reg_main.cpp xgboost.o
 
 $(BIN) :
 	$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)

diff --git a/booster/xgboost_data.h b/booster/xgboost_data.h
index d2b66cebd..285b770b1 100644
--- a/booster/xgboost_data.h
+++ b/booster/xgboost_data.h
@@ -8,6 +8,7 @@
 */
 #include <vector>
+#include <cstring>
 #include "../utils/xgboost_utils.h"
 #include "../utils/xgboost_stream.h"
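With BIN set to xgboost and the new target depending on regression/xgboost_reg_main.cpp plus xgboost.o, a plain `make` now links the regression driver into a standalone binary. For reference, the generic link rule above expands to roughly the following command (a sketch assuming the default CXX, CFLAGS and LDFLAGS in this Makefile):

    g++ -Wall -O3 -msse2 -pthread -lm -o xgboost regression/xgboost_reg_main.cpp xgboost.o

diff --git a/regression/xgboost_reg.h b/regression/xgboost_reg.h
index 3fe576213..60500f615 100644
--- a/regression/xgboost_reg.h
+++ b/regression/xgboost_reg.h
@@ -6,6 +6,8 @@
 * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
 */
 #include <vector>
+#include <cmath>
+#include <cstdio>
 #include "xgboost_regdata.h"
 #include "../booster/xgboost_gbmbase.h"
 #include "../utils/xgboost_utils.h"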
@@ -16,11 +18,8 @@ namespace xgboost{
     /*! \brief class for gradient boosted regression */
     class RegBoostLearner{
     public:
-
-        RegBoostLearner(bool silent = false){
-            this->silent = silent;
-        }
-
+        /*! \brief constructor */
+        RegBoostLearner( void ){}
         /*!
         * \brief a regression booster associated with training and evaluating data
         * \param train pointer to the training data
         * \param evals array of evaluating data
         * \param evname name of evaluation data, used print statistics
         */
         RegBoostLearner( const DMatrix *train,
-                         std::vector<const DMatrix *> evals,
-                         std::vector<std::string> evname, bool silent = false ){
-            this->silent = silent;
-            SetData(train,evals,evname);
+                         const std::vector<DMatrix *> &evals,
+                         const std::vector<std::string> &evname ){
+            this->SetData(train,evals,evname);
         }
         /*!
         * \brief set data used in training process
         * \param train pointer to the training data
         * \param evals array of evaluating data
         * \param evname name of evaluation data, used print statistics
         */
-        inline void SetData(const DMatrix *train,
-                            std::vector<const DMatrix *> evals,
-                            std::vector<std::string> evname){
-            this->train_ = train;
-            this->evals_ = evals;
-            this->evname_ = evname;
-            //assign buffer index
-            int buffer_size = (*train).size();
-            for(int i = 0; i < evals.size(); i++){
-                buffer_size += (*evals[i]).size();
-            }
-            char str[25];
-            _itoa(buffer_size,str,10);
-            base_model.SetParam("num_pbuffer",str);
-            base_model.SetParam("num_pbuffer",str);
+        inline void SetData( const DMatrix *train,
+                             const std::vector<DMatrix *> &evals,
+                             const std::vector<std::string> &evname ){
+            this->train_ = train;
+            this->evals_ = evals;
+            this->evname_ = evname;
+            //assign buffer index
+            unsigned buffer_size = static_cast<unsigned>( train->Size() );
+
+            for( size_t i = 0; i < evals.size(); ++ i ){
+                buffer_size += static_cast<unsigned>( evals[i]->Size() );
+            }
+            char snum_pbuffer[25];
+            sprintf( snum_pbuffer, "%u", buffer_size );
+            base_model.SetParam( "num_pbuffer", snum_pbuffer );
         }
-
         /*!
         * \brief set parameters from outside
         * \param name name of the parameter
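A note on the buffer sizing in SetData above: num_pbuffer tells the base booster how many prediction slots to keep, and the buffered Predict further down assigns the training matrix the slots [0, train->Size()) while each evaluation set receives the next contiguous slice (this is the buffer_offset that EvalOneIter advances). A minimal sketch of the resulting layout; PrintBufferLayout is illustrative and not part of the patch:

    // illustrative: print which prediction-buffer slice each data set owns
    #include <cstdio>
    #include <vector>
    void PrintBufferLayout( unsigned train_size, const std::vector<unsigned> &eval_sizes ){
        unsigned offset = 0;
        std::printf( "train -> [%u,%u)\n", offset, offset + train_size );
        offset += train_size;
        for( size_t i = 0; i < eval_sizes.size(); ++ i ){
            std::printf( "eval[%lu] -> [%u,%u)\n", (unsigned long)i, offset, offset + eval_sizes[i] );
            offset += eval_sizes[i];
        }
    }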
@@ -72,17 +69,14 @@
         */
        inline void InitTrainer( void ){
            base_model.InitTrainer();
-           InitModel();
-           mparam.AdjustBase();
        }
-
        /*!
        * \brief initialize the current data storage for model, if the model is used first time, call this function
        */
        inline void InitModel( void ){
            base_model.InitModel();
-       }
-
+           mparam.AdjustBase();
+       }
        /*!
        * \brief load model from stream
        * \param fi input stream
@@ -99,57 +93,78 @@
            fo.Write( &mparam, sizeof(ModelParam) );
            base_model.SaveModel( fo );
        }
-       /*!
-       * \brief update the model for one iteration
-       * \param iteration the number of updating iteration
-       */
-       inline void UpdateOneIter( int iteration ){
-           std::vector<float> grad,hess,preds;
-           std::vector<unsigned> root_index;
-           booster::FMatrixS::Image train_image((*train_).data);
-           Predict(preds,*train_,0);
-           Gradient(preds,(*train_).labels,grad,hess);
-           base_model.DoBoost(grad,hess,train_image,root_index);
-           int buffer_index_offset = (*train_).size();
-           float loss = 0.0;
-           for(int i = 0; i < evals_.size();i++){
-               Predict(preds, *evals_[i], buffer_index_offset);
-               loss = mparam.Loss(preds,(*evals_[i]).labels);
-               if(!silent){
-                   printf("The loss of %s data set in %d the \
-                       iteration is %f",evname_[i].c_str(),&iteration,&loss);
-               }
-               buffer_index_offset += (*evals_[i]).size();
-           }
-       }
+       /*!
+       * \brief update the model for one iteration
+       * \param iter iteration number
+       */
+       inline void UpdateOneIter( int iter ){
+           std::vector<float> grad, hess, preds;
+           this->Predict( preds, *train_, 0 );
+           this->GetGradient( preds, train_->labels, grad, hess );
+           std::vector<unsigned> root_index;
+           booster::FMatrixS::Image train_image( train_->data );
+           base_model.DoBoost( grad, hess, train_image, root_index );
+       }
+       /*!
+       * \brief evaluate the model for specific iteration
+       * \param iter iteration number
+       * \param fo file to output log
+       */
+       inline void EvalOneIter( int iter, FILE *fo = stderr ){
+           std::vector<float> preds;
+           fprintf( fo, "[%d]", iter );
+           unsigned buffer_offset = static_cast<unsigned>( train_->Size() );
+
+           for( size_t i = 0; i < evals_.size(); i ++ ){
+               this->Predict( preds, *evals_[i], buffer_offset );
+               this->Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels );
+               buffer_offset += static_cast<unsigned>( evals_[i]->Size() );
+           }
+           fprintf( fo, "\n" );
+       }
+       /*! \brief get prediction, without buffering */
+       inline void Predict( std::vector<float> &preds, const DMatrix &data ){
+           preds.resize( data.Size() );
+           for( size_t j = 0; j < data.Size(); j ++ ){
+               preds[j] = mparam.PredTransform
+                   ( mparam.base_score + base_model.Predict( data.data[j], -1 ) );
+           }
+       }
+   private:
+       /*! \brief print evaluation results */
+       inline void Eval( FILE *fo, const char *evname,
+                         const std::vector<float> &preds,
+                         const std::vector<float> &labels ){
+           const float loss = mparam.Loss( preds, labels );
+           fprintf( fo, "\t%s:%f", evname, loss );
+       }
        /*! \brief get the transformed predictions, given data */
-       inline void Predict( std::vector<float> &preds, const DMatrix &data,int buffer_index_offset = 0 ){
-           int data_size = data.size();
-           preds.resize(data_size);
-           for(int j = 0; j < data_size; j++){
-               preds[j] = mparam.PredTransform(mparam.base_score +
-                   base_model.Predict(data.data[j],buffer_index_offset + j));
+       inline void Predict( std::vector<float> &preds, const DMatrix &data, unsigned buffer_offset ){
+           preds.resize( data.Size() );
+           for( size_t j = 0; j < data.Size(); j ++ ){
+               preds[j] = mparam.PredTransform
+                   ( mparam.base_score + base_model.Predict( data.data[j], buffer_offset + j ) );
            }
        }
+
+       /*! \brief get the first order and second order gradient, given the transformed predictions and labels */
+       inline void GetGradient( const std::vector<float> &preds,
+                                const std::vector<float> &labels,
+                                std::vector<float> &grad,
+                                std::vector<float> &hess ){
+           grad.clear(); hess.clear();
+           for( size_t j = 0; j < preds.size(); j ++ ){
+               grad.push_back( mparam.FirstOrderGradient (preds[j],labels[j]) );
+               hess.push_back( mparam.SecondOrderGradient(preds[j],labels[j]) );
+           }
+       }
    private:
-       /*! \brief get the first order and second order gradient, given the transformed predictions and labels*/
-       inline void Gradient(const std::vector<float> &preds, const std::vector<float> &labels, std::vector<float> &grad,
-                            std::vector<float> &hess){
-           grad.clear();
-           hess.clear();
-           for(int j = 0; j < preds.size(); j++){
-               grad.push_back(mparam.FirstOrderGradient(preds[j],labels[j]));
-               hess.push_back(mparam.SecondOrderGradient(preds[j],labels[j]));
-           }
-       }
-
-       enum LOSS_TYPE_LIST{
-           LINEAR_SQUARE,
-           LOGISTIC_NEGLOGLIKELIHOOD,
+       enum LossType{
+           kLinearSquare = 0,
+           kLogisticNeglik = 1,
        };
        /*! \brief training parameter for regression */
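The UpdateOneIter/EvalOneIter split above leaves the driver with a simple two-call round loop. A sketch of the intended calling sequence (it mirrors RegBoostTask::TaskTrain in xgboost_reg_main.cpp below; train_matrix, evals, evnames and num_round are placeholder names, error handling omitted):

    // illustrative round loop against the new RegBoostLearner interface
    xgboost::regression::RegBoostLearner learner;
    learner.SetData( &train_matrix, evals, evnames ); // also sizes the prediction buffer
    learner.InitModel();                              // fresh model; use LoadModel(fi) to resume
    learner.InitTrainer();
    for( int i = 0; i < num_round; ++ i ){
        learner.UpdateOneIter( i );   // one boosting round on the training data
        learner.EvalOneIter( i );     // prints "[i]\tname:loss" lines to stderr
    }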
@@ -181,6 +196,20 @@
                    base_score = - logf( 1.0f / base_score - 1.0f );
                }
            }
+
+           /*!
+           * \brief transform the linear sum to prediction
+           * \param x linear sum of boosting ensemble
+           * \return transformed prediction
+           */
+           inline float PredTransform( float x ){
+               switch( loss_type ){
+               case kLinearSquare: return x;
+               case kLogisticNeglik: return 1.0f/(1.0f + expf(-x));
+               default: utils::Error("unknown loss_type"); return 0.0f;
+               }
+           }
+
            /*!
            * \brief calculate first order gradient of loss, given transformed prediction
            * \param predt transformed prediction
            * \return first order gradient
@@ -189,7 +218,7 @@
            */
            inline float FirstOrderGradient( float predt, float label ) const{
                switch( loss_type ){
-               case LINEAR_SQUARE: return predt - label;
-               case 1: return predt - label;
+               case kLinearSquare: return predt - label;
+               case kLogisticNeglik: return predt - label;
                default: utils::Error("unknown loss_type"); return 0.0f;
                }
            }
@@ -202,8 +231,8 @@
            */
            inline float SecondOrderGradient( float predt, float label ) const{
                switch( loss_type ){
-               case LINEAR_SQUARE: return 1.0f;
-               case LOGISTIC_NEGLOGLIKELIHOOD: return predt * ( 1 - predt );
+               case kLinearSquare: return 1.0f;
+               case kLogisticNeglik: return predt * ( 1 - predt );
                default: utils::Error("unknown loss_type"); return 0.0f;
                }
            }
@@ -216,8 +245,8 @@
            */
            inline float Loss(const std::vector<float> &preds, const std::vector<float> &labels) const{
                switch( loss_type ){
-               case LINEAR_SQUARE: return SquareLoss(preds,labels);
-               case LOGISTIC_NEGLOGLIKELIHOOD: return NegLoglikelihoodLoss(preds,labels);
+               case kLinearSquare: return SquareLoss(preds,labels);
+               case kLogisticNeglik: return NegLoglikelihoodLoss(preds,labels);
                default: utils::Error("unknown loss_type"); return 0.0f;
                }
            }
@@ -230,8 +259,10 @@
            inline float SquareLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
                float ans = 0.0;
-               for(int i = 0; i < preds.size(); i++)
-                   ans += pow(preds[i] - labels[i], 2);
+               for(size_t i = 0; i < preds.size(); i++){
+                   float dif = preds[i] - labels[i];
+                   ans += dif * dif;
+               }
                return ans;
            }
@@ -243,34 +274,18 @@
            inline float NegLoglikelihoodLoss(const std::vector<float> &preds, const std::vector<float> &labels) const{
                float ans = 0.0;
-               for(int i = 0; i < preds.size(); i++)
-                   ans -= labels[i] * log(preds[i]) + ( 1 - labels[i] ) * log(1 - preds[i]);
+               for(size_t i = 0; i < preds.size(); i++)
+                   ans -= labels[i] * logf(preds[i]) + ( 1 - labels[i] ) * logf(1 - preds[i]);
                return ans;
            }
-
-
-           /*!
-           * \brief transform the linear sum to prediction
-           * \param x linear sum of boosting ensemble
-           * \return transformed prediction
-           */
-           inline float PredTransform( float x ){
-               switch( loss_type ){
-               case LINEAR_SQUARE: return x;
-               case LOGISTIC_NEGLOGLIKELIHOOD: return 1.0f/(1.0f + expf(-x));
-               default: utils::Error("unknown loss_type"); return 0.0f;
-               }
-           }
-
-       };
+       };
    private:
        booster::GBMBaseModel base_model;
        ModelParam mparam;
        const DMatrix *train_;
-       std::vector<const DMatrix *> evals_;
+       std::vector<DMatrix *> evals_;
        std::vector<std::string> evname_;
-       bool silent;
+       std::vector<unsigned> buffer_index_;
    };
 }
};
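Within ModelParam, the kLogisticNeglik branches of PredTransform, FirstOrderGradient and SecondOrderGradient encode p = 1/(1+exp(-x)), dl/dx = p - label and d2l/dx2 = p*(1-p) for the negative log-likelihood l. A standalone finite-difference check of those formulas (illustrative only, not part of the patch):

    // check: for l(x) = -( y*log(p) + (1-y)*log(1-p) ) with p = 1/(1+exp(-x)),
    // the analytic gradients are dl/dx = p - y and d2l/dx2 = p*(1-p)
    #include <cmath>
    #include <cstdio>
    int main( void ){
        const double x = 0.3, y = 1.0, eps = 1e-5;
        #define P(z) ( 1.0 / ( 1.0 + std::exp( -(z) ) ) )
        #define L(z) ( -( y * std::log( P(z) ) + ( 1.0 - y ) * std::log( 1.0 - P(z) ) ) )
        std::printf( "grad: numeric=%g analytic=%g\n",
                     ( L(x+eps) - L(x-eps) ) / ( 2.0 * eps ), P(x) - y );
        std::printf( "hess: numeric=%g analytic=%g\n",
                     ( L(x+eps) - 2.0 * L(x) + L(x-eps) ) / ( eps * eps ), P(x) * ( 1.0 - P(x) ) );
        return 0;
    }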
diff --git a/regression/xgboost_reg_main.cpp b/regression/xgboost_reg_main.cpp
index 16ef5f486..eb06cdabf 100644
--- a/regression/xgboost_reg_main.cpp
+++ b/regression/xgboost_reg_main.cpp
@@ -1,15 +1,180 @@
-#include"xgboost_reg_train.h"
-#include"xgboost_reg_test.h"
-using namespace xgboost::regression;
+#define _CRT_SECURE_NO_WARNINGS
+#define _CRT_SECURE_NO_DEPRECATE
 
-int main(int argc, char *argv[]){
-    //char* config_path = argv[1];
-    //bool silent = ( atoi(argv[2]) == 1 );
-    char* config_path = "c:\\cygwin64\\home\\chen\\github\\xgboost\\demo\\regression\\reg.conf";
-    bool silent = false;
-    RegBoostTrain train;
-    train.train(config_path,false);
+#include <ctime>
+#include <string>
+#include <cstring>
+#include "xgboost_reg.h"
+#include "../utils/xgboost_random.h"
+#include "../utils/xgboost_config.h"
 
-    RegBoostTest test;
-    test.test(config_path,false);
-}
\ No newline at end of file
+namespace xgboost{
+    namespace regression{
+        /*!
+        * \brief wrapping the training process of the gradient boosting regression model,
+        *        given the configuration
+        * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.chen@gmail.com
+        */
+        class RegBoostTask{
+        public:
+            inline int Run( int argc, char *argv[] ){
+                if( argc < 2 ){
+                    printf("Usage: <config>\n");
+                    return 0;
+                }
+                utils::ConfigIterator itr( argv[1] );
+                while( itr.Next() ){
+                    this->SetParam( itr.name(), itr.val() );
+                }
+                for( int i = 2; i < argc; i ++ ){
+                    char name[256], val[256];
+                    if( sscanf( argv[i], "%[^=]=%s", name, val ) == 2 ){
+                        this->SetParam( name, val );
+                    }
+                }
+                this->InitData();
+                this->InitLearner();
+                if( !strcmp( task.c_str(), "test") ){
+                    this->TaskTest();
+                }else{
+                    this->TaskTrain();
+                }
+                return 0;
+            }
+            inline void SetParam( const char *name, const char *val ){
+                if( !strcmp("silent", name ) )      silent = atoi( val );
+                if( !strcmp("seed", name ) )        random::Seed( atoi(val) );
+                if( !strcmp("num_round", name ) )   num_round = atoi( val );
+                if( !strcmp("save_period", name ) ) save_period = atoi( val );
+                if( !strcmp("task", name ) )        task = val;
+                if( !strcmp("data", name ) )        train_path = val;
+                if( !strcmp("test:data", name ) )   test_path = val;
+                if( !strcmp("model_in", name ) )    model_in = val;
+                if( !strcmp("model_dir", name ) )   model_dir_path = val;
+                if( !strncmp("eval[", name, 5 ) ) {
+                    char evname[ 256 ];
+                    utils::Assert( sscanf( name, "eval[%[^]]", evname ) == 1, "must specify evaluation name for display");
+                    eval_data_names.push_back( std::string( evname ) );
+                    eval_data_paths.push_back( std::string( val ) );
+                }
+                cfg.PushBack( name, val );
+            }
+        public:
+            RegBoostTask( void ){
+                // default parameters
+                silent = 0;
+                num_round = 10;
+                save_period = 0;
+                task = "train";
+                model_in = "NULL";
+                name_pred = "pred.txt";
+                model_dir_path = "./";
+            }
+            ~RegBoostTask( void ){
+                for( size_t i = 0; i < deval.size(); i ++ ){
+                    delete deval[i];
+                }
+            }
+        private:
+            inline void InitData( void ){
+                if( !strcmp( task.c_str(), "test") ){
+                    data.CacheLoad( test_path.c_str() );
+                }else{
+                    // training
+                    data.CacheLoad( train_path.c_str() );
+                    utils::Assert( eval_data_names.size() == eval_data_paths.size() );
+                    for( size_t i = 0; i < eval_data_names.size(); ++ i ){
+                        deval.push_back( new DMatrix() );
+                        deval.back()->CacheLoad( eval_data_paths[i].c_str() );
+                    }
+                }
+                learner.SetData( &data, deval, eval_data_names );
+            }
+            inline void InitLearner( void ){
+                cfg.BeforeFirst();
+                while( cfg.Next() ){
+                    learner.SetParam( cfg.name(), cfg.val() );
+                }
+                if( strcmp( model_in.c_str(), "NULL" ) != 0 ){
+                    utils::FileStream fi( utils::FopenCheck( model_in.c_str(), "rb") );
+                    learner.LoadModel( fi );
+                    fi.Close();
+                }else{
+                    utils::Assert( !strcmp( task.c_str(), "train"), "model_in not specified" );
+                    learner.InitModel();
+                }
+                learner.InitTrainer();
+            }
+            inline void TaskTrain( void ){
+                const time_t start = time( NULL );
+                unsigned long elapsed = 0;
+                for( int i = 0; i < num_round; ++ i ){
+                    elapsed = (unsigned long)(time(NULL) - start);
+                    if( !silent ) printf("boosting round %d, %lu sec elapsed\n", i , elapsed );
+                    learner.UpdateOneIter( i );
+                    learner.EvalOneIter( i );
+                    if( save_period != 0 && (i+1) % save_period == 0 ){
+                        SaveModel( i );
+                    }
+                    elapsed = (unsigned long)(time(NULL) - start);
+                }
+                // always save the final round, guarding against save_period == 0
+                if( save_period == 0 || num_round % save_period != 0 ){
+                    SaveModel( num_round - 1 );
+                }
+                if( !silent ){
+                    printf("\nupdating end, %lu sec in all\n", elapsed );
+                }
+            }
+            inline void SaveModel( int i ){
+                char fname[256];
+                sprintf( fname, "%s/%04d.model", model_dir_path.c_str(), i+1 );
+                utils::FileStream fo( utils::FopenCheck( fname, "wb" ) );
+                learner.SaveModel( fo );
+                fo.Close();
+            }
+            inline void TaskTest( void ){
+                std::vector<float> preds;
+                learner.Predict( preds, data );
+                FILE *fo = utils::FopenCheck( name_pred.c_str(), "w" );
+                for( size_t i = 0; i < preds.size(); i ++ ){
+                    fprintf( fo, "%f\n", preds[i] );
+                }
+                fclose( fo );
+            }
+        private:
+            /* \brief whether silent */
+            int silent;
+            /* \brief number of boosting iterations */
+            int num_round;
+            /* \brief the period to save the model, 0 means only save the final round model */
+            int save_period;
+            /* \brief the path of training/test data set */
+            std::string train_path, test_path;
+            /* \brief the path of test model file, or file to restart training */
+            std::string model_in;
+            /* \brief the path of directory containing the saved models */
+            std::string model_dir_path;
+            /* \brief task to perform */
+            std::string task;
+            /* \brief name of predict file */
+            std::string name_pred;
+            /* \brief the paths of validation data sets */
+            std::vector<std::string> eval_data_paths;
+            /* \brief the names of the evaluation data used in output log */
+            std::vector<std::string> eval_data_names;
+            /*! \brief saves configurations */
+            utils::ConfigSaver cfg;
+        private:
+            DMatrix data;
+            std::vector<DMatrix*> deval;
+            RegBoostLearner learner;
+        };
+    };
+};
+
+int main( int argc, char *argv[] ){
+    xgboost::random::Seed( 0 );
+    xgboost::regression::RegBoostTask tsk;
+    return tsk.Run( argc, argv );
+}
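RegBoostTask::Run first reads name=value pairs from the config file in argv[1], then lets later name=value arguments on the command line override them. An illustrative configuration for the keys handled by SetParam above (file names and values are placeholders; exact parsing rules are those of utils::ConfigIterator, and booster parameters such as loss_type are forwarded to the learner via the ConfigSaver):

    # reg.conf -- illustrative configuration
    task = train
    loss_type = 1
    num_round = 10
    save_period = 0
    model_dir = ./models
    data = train.txt
    eval[test] = validation.txt

Typical invocations would then look like:

    ./xgboost reg.conf
    ./xgboost reg.conf task=test model_in=./models/0010.model test:data=validation.txt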
diff --git a/regression/xgboost_reg_test.h b/regression/xgboost_reg_test.h
index cccdca5fa..ea831856e 100644
--- a/regression/xgboost_reg_test.h
+++ b/regression/xgboost_reg_test.h
@@ -27,7 +27,7 @@
             * \param silent whether to print feedback messages
             */
             void test(char* config_path,bool silent = false){
-                reg_boost_learner = new xgboost::regression::RegBoostLearner(silent);
+                reg_boost_learner = new xgboost::regression::RegBoostLearner();
                 ConfigIterator config_itr(config_path);
                 //Get the training data and validation data paths, config the Learner
                 while (config_itr.Next()){
@@ -42,10 +42,11 @@
                 reg_boost_learner->InitModel();
                 char model_path[256];
                 std::vector<float> preds;
-                for(int i = 0; i < test_param.test_paths.size(); i++){
+                for(size_t i = 0; i < test_param.test_paths.size(); i++){
                     xgboost::regression::DMatrix test_data;
                     test_data.LoadText(test_param.test_paths[i].c_str());
                     sprintf(model_path,"%s/final.model",test_param.model_dir_path);
+                    // BUG: the model file needs to be opened in binary mode ("rb")
                     FileStream fin(fopen(model_path,"r"));
                     reg_boost_learner->LoadModel(fin);
                     fin.Close();
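The BUG comment added above points at a real inconsistency: models are written by RegBoostTask::SaveModel with mode "wb", but reopened here in text mode "r", which can corrupt the stream on Windows. The eventual fix is a one-token change (sketch):

    FileStream fin(fopen(model_path, "rb"));  // binary mode, matching SaveModel's "wb"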
diff --git a/regression/xgboost_reg_train.h b/regression/xgboost_reg_train.h
index c994df566..05f7f42ee 100644
--- a/regression/xgboost_reg_train.h
+++ b/regression/xgboost_reg_train.h
@@ -1,13 +1,13 @@
 #ifndef _XGBOOST_REG_TRAIN_H_
 #define _XGBOOST_REG_TRAIN_H_
 
-#include<string>
-#include<fstream>
-#include<vector>
-#include"../utils/xgboost_config.h"
-#include"xgboost_reg.h"
-#include"xgboost_regdata.h"
-#include"../utils/xgboost_string.h"
+#include <string>
+#include <fstream>
+#include <vector>
+#include "../utils/xgboost_config.h"
+#include "xgboost_reg.h"
+#include "xgboost_regdata.h"
+#include "../utils/xgboost_string.h"
 
 using namespace xgboost::utils;
 
@@ -28,7 +28,8 @@
             * \param silent whether to print feedback messages
             */
             void train(char* config_path,bool silent = false){
-                reg_boost_learner = new xgboost::regression::RegBoostLearner(silent);
+                reg_boost_learner = new xgboost::regression::RegBoostLearner();
+
                 ConfigIterator config_itr(config_path);
                 //Get the training data and validation data paths, config the Learner
                 while (config_itr.Next()){
@@ -38,14 +39,14 @@
                 }
 
                 Assert(train_param.validation_data_paths.size() == train_param.validation_data_names.size(),
-                    "The number of validation paths is not the same as the number of validation data set names");
+                       "The number of validation paths is not the same as the number of validation data set names");
 
                 //Load Data
                 xgboost::regression::DMatrix train;
                 printf("%s",train_param.train_path);
                 train.LoadText(train_param.train_path);
                 std::vector<const DMatrix *> evals;
-                for(int i = 0; i < train_param.validation_data_paths.size(); i++){
+                for(size_t i = 0; i < train_param.validation_data_paths.size(); i++){
                     xgboost::regression::DMatrix eval;
                     eval.LoadText(train_param.validation_data_paths[i].c_str());
                     evals.push_back(&eval);
@@ -58,7 +59,7 @@
                 for(int i = 1; i <= train_param.boost_iterations; i++){
                     reg_boost_learner->UpdateOneIter(i);
                     if(train_param.save_period != 0 && i % train_param.save_period == 0){
-                        sscanf(suffix,"%d.model",i);
+                        sprintf(suffix,"%d.model",i);
                         SaveModel(suffix);
                     }
                 }
diff --git a/regression/xgboost_regdata.h b/regression/xgboost_regdata.h
index d400b74d7..275a785ca 100644
--- a/regression/xgboost_regdata.h
+++ b/regression/xgboost_regdata.h
@@ -31,12 +31,10 @@
 
             /*! \brief default constructor */
             DMatrix( void ){}
-            /*! \brief get the number of instances */
-            inline int size() const{
+            inline size_t Size() const{
                 return labels.size();
             }
-
             /*!
             * \brief load from text file
             * \param fname name of text data