From 1748e4517aa1810c863b294ebf890b80651a694a Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Sat, 1 Mar 2014 20:56:25 -0800
Subject: [PATCH] full omp support for regression

---
 Makefile                       |   2 +-
 booster/tree/xgboost_tree.hpp  |  61 ++++++++-------
 booster/xgboost.h              |   4 +-
 booster/xgboost_gbmbase.h      |  11 +++
 regression/xgboost_reg.h       | 123 +++++++++++++++---------
 regression/xgboost_reg_test.h  | 100 ------------------------
 regression/xgboost_reg_train.h | 132 ---------------------------------
 regression/xgboost_regeval.h   |  96 ++++++++++++++++++++++++
 utils/xgboost_string.h         |  31 --------
 9 files changed, 206 insertions(+), 354 deletions(-)
 delete mode 100644 regression/xgboost_reg_test.h
 delete mode 100644 regression/xgboost_reg_train.h
 create mode 100644 regression/xgboost_regeval.h
 delete mode 100644 utils/xgboost_string.h

diff --git a/Makefile b/Makefile
index 819355989..8db7e81c8 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@ all: $(BIN) $(OBJ)
 export LDFLAGS= -pthread -lm

 xgboost.o: booster/xgboost.h booster/xgboost_data.h booster/xgboost.cpp booster/*/*.hpp booster/*/*.h
-xgboost: regression/xgboost_reg_main.cpp xgboost.o
+xgboost: regression/xgboost_reg_main.cpp regression/*.h xgboost.o

 $(BIN) :
 	$(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^)

diff --git a/booster/tree/xgboost_tree.hpp b/booster/tree/xgboost_tree.hpp
index 0e13c78a4..e6caa51ce 100644
--- a/booster/tree/xgboost_tree.hpp
+++ b/booster/tree/xgboost_tree.hpp
@@ -32,7 +32,9 @@ namespace xgboost{
     class RegTreeTrainer : public IBooster{
     public:
         RegTreeTrainer( void ){
-            silent = 0; tree_maker = 1;
+            silent = 0; tree_maker = 1;
+            // normally we won't have more than 64 OpenMP threads
+            threadtemp.resize( 64, ThreadEntry() );
         }
         virtual ~RegTreeTrainer( void ){}
     public:
@@ -74,25 +76,25 @@ namespace xgboost{
         virtual void PredPath( std::vector<int> &path, const FMatrixS::Line &feat, unsigned gid = 0 ){
             path.clear();
-            this->InitTmp();
-            this->PrepareTmp( feat );
+            ThreadEntry &e = this->InitTmp();
+            this->PrepareTmp( feat, e );
             int pid = (int)gid;
             path.push_back( pid );
             // traverse tree
             while( !tree[ pid ].is_leaf() ){
                 unsigned split_index = tree[ pid ].split_index();
-                pid = this->GetNext( pid, tmp_feat[ split_index ], tmp_funknown[ split_index ] );
+                pid = this->GetNext( pid, e.feat[ split_index ], e.funknown[ split_index ] );
                 path.push_back( pid );
             }
-            this->DropTmp( feat );
+            this->DropTmp( feat, e );
         }
-
+        // make it OpenMP thread safe, but not thread safe in general
         virtual float Predict( const FMatrixS::Line &feat, unsigned gid = 0 ){
-            this->InitTmp();
-            this->PrepareTmp( feat );
-            int pid = this->GetLeafIndex( tmp_feat, tmp_funknown, gid );
-            this->DropTmp( feat );
+            ThreadEntry &e = this->InitTmp();
+            this->PrepareTmp( feat, e );
+            int pid = this->GetLeafIndex( e.feat, e.funknown, gid );
+            this->DropTmp( feat, e );
             return tree[ pid ].leaf_value();
         }
         virtual float Predict( const std::vector<float> &feat,
@@ -102,8 +104,7 @@ namespace xgboost{
                            "input data smaller than num feature" );
             int pid = this->GetLeafIndex( feat, funknown, gid );
             return tree[ pid ].leaf_value();
-        }
-
+        }
         virtual void DumpModel( FILE *fo ){
             tree.DumpModel( fo );
         }
@@ -137,25 +138,34 @@ namespace xgboost{
         RegTree tree;
         TreeParamTrain param;
     private:
-        std::vector<float> tmp_feat;
-        std::vector<bool>  tmp_funknown;
-        inline void InitTmp( void ){
-            if( tmp_feat.size() != (size_t)tree.param.num_feature ){
-                tmp_feat.resize( tree.param.num_feature );
-                tmp_funknown.resize( tree.param.num_feature );
-                std::fill( tmp_funknown.begin(), tmp_funknown.end(), true );
+        struct ThreadEntry{
+            std::vector<float> feat;
+            std::vector<bool>  funknown;
+        };
+        std::vector<ThreadEntry> threadtemp;
+    private:
+
+        inline ThreadEntry& InitTmp( void ){
+            const int tid = omp_get_thread_num();
+            utils::Assert( tid < (int)threadtemp.size(), "RTreeUpdater: threadtemp pool is too small" );
+            ThreadEntry &e = threadtemp[ tid ];
+            if( e.feat.size() != (size_t)tree.param.num_feature ){
+                e.feat.resize( tree.param.num_feature );
+                e.funknown.resize( tree.param.num_feature );
+                std::fill( e.funknown.begin(), e.funknown.end(), true );
             }
+            return e;
         }
-        inline void PrepareTmp( const FMatrixS::Line &feat ){
+        inline void PrepareTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
             for( unsigned i = 0; i < feat.len; i ++ ){
-                utils::Assert( feat[i].findex < (unsigned)tmp_funknown.size() , "input feature exceeds bound" );
-                tmp_funknown[ feat[i].findex ] = false;
-                tmp_feat[ feat[i].findex ] = feat[i].fvalue;
+                utils::Assert( feat[i].findex < (unsigned)tree.param.num_feature , "input feature exceeds bound" );
+                e.funknown[ feat[i].findex ] = false;
+                e.feat[ feat[i].findex ] = feat[i].fvalue;
             }
         }
-        inline void DropTmp( const FMatrixS::Line &feat ){
+        inline void DropTmp( const FMatrixS::Line &feat, ThreadEntry &e ){
             for( unsigned i = 0; i < feat.len; i ++ ){
-                tmp_funknown[ feat[i].findex ] = true;
+                e.funknown[ feat[i].findex ] = true;
             }
         }
@@ -174,4 +184,3 @@ namespace xgboost{
     };
 };
 #endif
-
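The heart of the tree change above is the per-thread scratch buffer: each OpenMP thread indexes its own ThreadEntry via omp_get_thread_num(), so concurrent calls to the sparse Predict never share temporary state. A minimal standalone sketch of the same idea follows (simplified illustration, not part of the patch; ThreadEntry mirrors the struct introduced in RegTreeTrainer; build with -fopenmp):

// Sketch: per-thread scratch buffers keyed by omp_get_thread_num().
// Assumes at most 64 threads, matching the pool size used in the patch.
#include <omp.h>
#include <cstdio>
#include <vector>

struct ThreadEntry{
    std::vector<float> feat;
    std::vector<bool>  funknown;
};

int main( void ){
    const int num_feature = 16;
    std::vector<ThreadEntry> threadtemp( 64 );
    #pragma omp parallel for schedule( static )
    for( int i = 0; i < 1000; ++ i ){
        // each thread touches only its own entry, so no locking is needed
        ThreadEntry &e = threadtemp[ omp_get_thread_num() ];
        if( e.feat.size() != (size_t)num_feature ){
            e.feat.resize( num_feature );
            e.funknown.resize( num_feature, true );
        }
        // ... fill e from a sparse row, walk the tree, then reset e ...
    }
    printf( "ran with up to %d threads\n", omp_get_max_threads() );
    return 0;
}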
diff --git a/booster/xgboost.h b/booster/xgboost.h
index 731b869bb..0cf2b9e19 100644
--- a/booster/xgboost.h
+++ b/booster/xgboost.h
@@ -75,7 +75,9 @@ namespace xgboost{
         }
         /*!
          * \brief predict values for given sparse feature vector
-         * NOTE: in tree implementation, this is not threadsafe, used dense version to ensure threadsafety
+         *
+         * NOTE: in the tree implementation, the sparse Predict is OpenMP thread-safe, but not thread-safe
+         * in general; use the dense version of Predict to ensure thread safety
          * \param feat vector in sparse format
         * \param rid root id of current instance, default = 0
         * \return prediction

diff --git a/booster/xgboost_gbmbase.h b/booster/xgboost_gbmbase.h
index a846061dd..8e30f8069 100644
--- a/booster/xgboost_gbmbase.h
+++ b/booster/xgboost_gbmbase.h
@@ -1,6 +1,7 @@
 #ifndef _XGBOOST_GBMBASE_H_
 #define _XGBOOST_GBMBASE_H_

+#include <omp.h>
 #include <vector>
 #include "xgboost.h"
 #include "../utils/xgboost_config.h"
@@ -88,6 +89,10 @@ namespace xgboost{
             }
         };
     public:
+        /*! \brief constructor, default to a single thread */
+        GBMBaseModel( void ){
+            this->nthread = 1;
+        }
         /*! \brief destructor */
         virtual ~GBMBaseModel( void ){
             this->FreeSpace();
@@ -104,6 +109,7 @@ namespace xgboost{
             if( !strcmp( name, "silent") ){
                 cfg.PushBack( name, val );
             }
+            if( !strcmp( name, "nthread") ) nthread = atoi( val );
             if( boosters.size() == 0 ) param.SetParam( name, val );
         }
         /*!
@@ -164,6 +170,9 @@ namespace xgboost{
         * this function is reserved for solver to allocate necessary space and do other preparation
         */
        inline void InitTrainer( void ){
+            if( nthread != 0 ){
+                omp_set_num_threads( nthread );
+            }
            // make sure all the boosters get the latest parameters
            for( size_t i = 0; i < this->boosters.size(); i ++ ){
                this->ConfigBooster( this->boosters[i] );
@@ -312,6 +321,8 @@ namespace xgboost{
            return boosters.back();
        }
    protected:
+        /*! \brief number of OpenMP threads */
+        int nthread;
        /*! \brief model parameters */
        Param param;
        /*! \brief component boosters */
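For reference, the way the new nthread parameter takes effect can be sketched outside the class (names follow the patch, but this MiniModel is an illustration, not the actual GBMBaseModel):

// Sketch: nthread is stored by SetParam and applied once in InitTrainer;
// nthread == 0 leaves the OpenMP runtime default untouched.
#include <omp.h>
#include <cstdlib>
#include <cstring>

struct MiniModel{
    int nthread;
    MiniModel( void ){ nthread = 1; }
    inline void SetParam( const char *name, const char *val ){
        if( !strcmp( name, "nthread") ) nthread = atoi( val );
    }
    inline void InitTrainer( void ){
        if( nthread != 0 ) omp_set_num_threads( nthread );
    }
};

int main( void ){
    MiniModel m;
    m.SetParam( "nthread", "4" );
    m.InitTrainer();  // later omp parallel regions now use 4 threads
    return 0;
}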
diff --git a/regression/xgboost_reg.h b/regression/xgboost_reg.h
index 542185a77..ed2e0a89e 100644
--- a/regression/xgboost_reg.h
+++ b/regression/xgboost_reg.h
@@ -8,7 +8,9 @@
 #include <cmath>
 #include <cstdio>
 #include <vector>
+#include <omp.h>
 #include "xgboost_regdata.h"
+#include "xgboost_regeval.h"
 #include "../booster/xgboost_gbmbase.h"
 #include "../utils/xgboost_utils.h"
 #include "../utils/xgboost_stream.h"
@@ -16,11 +18,11 @@ namespace xgboost{
     namespace regression{
         /*! \brief class for gradient boosted regression */
-        class RegBoostLearner{
+        class RegBoostLearner{
         public:
            /*! \brief constructor */
            RegBoostLearner( void ){
-                silent = 0;
+                silent = 0;
            }
            /*!
            * \brief a regression booster associated with training and evaluating data
@@ -59,6 +61,8 @@ namespace xgboost{
                    printf( "buffer_size=%u\n", buffer_size );
                }
                base_model.SetParam( "num_pbuffer",snum_pbuffer );
+
+                this->eval_preds_.resize( evals.size(), std::vector<float>() );
            }
            /*!
            * \brief set parameters from outside
@@ -66,7 +70,8 @@ namespace xgboost{
            * \param val  value of the parameter
            */
            inline void SetParam( const char *name, const char *val ){
-                if( !strcmp( name, "silent") )      silent = atoi( val );
+                if( !strcmp( name, "silent") )      silent = atoi( val );
+                if( !strcmp( name, "eval_metric") ) evaluator_.AddEval( val );
                mparam.SetParam( name, val );
                base_model.SetParam( name, val );
            }
@@ -76,6 +81,12 @@ namespace xgboost{
            */
            inline void InitTrainer( void ){
                base_model.InitTrainer();
+                if( mparam.loss_type == kLogisticClassify ){
+                    evaluator_.AddEval( "error" );
+                }else{
+                    evaluator_.AddEval( "rmse" );
+                }
+                evaluator_.Init();
            }
            /*!
            * \brief initialize the current data storage for model, if the model is used first time, call this function
@@ -120,12 +131,10 @@ namespace xgboost{
            * \param iter iteration number
            */
            inline void UpdateOneIter( int iter ){
-                std::vector<float> grad, hess, preds;
-                this->Predict( preds, *train_, 0 );
-                this->GetGradient( preds, train_->labels, grad, hess );
-
+                this->PredictBuffer( preds_, *train_, 0 );
+                this->GetGradient( preds_, train_->labels, grad_, hess_ );
                std::vector<unsigned> root_index;
-                base_model.DoBoost(grad,hess,train_->data,root_index);
+                base_model.DoBoost( grad_, hess_, train_->data, root_index );
            }
            /*!
            * \brief evaluate the model for specific iteration
@@ -133,13 +142,13 @@ namespace xgboost{
            * \param fo file to output log
            */
            inline void EvalOneIter( int iter, FILE *fo = stderr ){
-                std::vector<float> preds;
                fprintf( fo, "[%d]", iter );
                int buffer_offset = static_cast<int>( train_->Size() );
-
-                for(size_t i = 0; i < evals_.size();i++){
-                    this->Predict(preds, *evals_[i], buffer_offset);
-                    this->Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels );
+
+                for( size_t i = 0; i < evals_.size(); ++i ){
+                    std::vector<float> &preds = this->eval_preds_[ i ];
+                    this->PredictBuffer( preds, *evals_[i], buffer_offset);
+                    evaluator_.Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels );
                    buffer_offset += static_cast<int>( evals_[i]->Size() );
                }
                fprintf( fo,"\n" );
@@ -148,23 +157,22 @@ namespace xgboost{
            /*! \brief get prediction, without buffering */
            inline void Predict( std::vector<float> &preds, const DMatrix &data ){
                preds.resize( data.Size() );
-                for( size_t j = 0; j < data.Size(); j++ ){
+
+                const unsigned ndata = static_cast<unsigned>( data.Size() );
+                #pragma omp parallel for schedule( static )
+                for( unsigned j = 0; j < ndata; ++ j ){
                    preds[j] = mparam.PredTransform
                        ( mparam.base_score + base_model.Predict( data.data[j], -1 ) );
                }
            }
        private:
-            /*! \brief print evaluation results */
-            inline void Eval( FILE *fo, const char *evname,
-                              const std::vector<float> &preds,
-                              const std::vector<float> &labels ){
-                const float loss = mparam.Loss( preds, labels );
-                fprintf( fo, "\t%s:%f", evname, loss );
-            }
            /*! \brief get the transformed predictions, given data */
-            inline void Predict( std::vector<float> &preds, const DMatrix &data, unsigned buffer_offset ){
+            inline void PredictBuffer( std::vector<float> &preds, const DMatrix &data, unsigned buffer_offset ){
                preds.resize( data.Size() );
-                for( size_t j = 0; j < data.Size(); j++ ){
+
+                const unsigned ndata = static_cast<unsigned>( data.Size() );
+                #pragma omp parallel for schedule( static )
+                for( unsigned j = 0; j < ndata; ++ j ){
                    preds[j] = mparam.PredTransform
                        ( mparam.base_score + base_model.Predict( data.data[j], buffer_offset + j ) );
                }
@@ -175,13 +183,16 @@ namespace xgboost{
                                     const std::vector<float> &labels,
                                     std::vector<float> &grad,
                                     std::vector<float> &hess ){
-                grad.clear(); hess.clear();
-                for( size_t j = 0; j < preds.size(); j++ ){
-                    grad.push_back( mparam.FirstOrderGradient (preds[j],labels[j]) );
-                    hess.push_back( mparam.SecondOrderGradient(preds[j],labels[j]) );
+                grad.resize( preds.size() ); hess.resize( preds.size() );
+
+                const unsigned ndata = static_cast<unsigned>( preds.size() );
+                #pragma omp parallel for schedule( static )
+                for( unsigned j = 0; j < ndata; ++ j ){
+                    grad[j] = mparam.FirstOrderGradient( preds[j], labels[j] );
+                    hess[j] = mparam.SecondOrderGradient( preds[j], labels[j] );
                }
            }
-
+
        private:
            enum LossType{
                kLinearSquare = 0,
@@ -262,26 +273,26 @@ namespace xgboost{
            }
            /*!
-            * \brief calculating the loss, given the predictions, labels and the loss type
-            * \param preds the given predictions
-            * \param labels the given labels
-            * \return the specified loss
-            */
+            * \brief calculating the loss, given the predictions, labels and the loss type
+            * \param preds the given predictions
+            * \param labels the given labels
+            * \return the specified loss
+            */
            inline float Loss(const std::vector<float> &preds,
                              const std::vector<float> &labels) const{
                switch( loss_type ){
                case kLinearSquare:     return SquareLoss(preds,labels);
-                case kLogisticNeglik:   return NegLoglikelihoodLoss(preds,labels);
-                case kLogisticClassify: return ClassificationError(preds, labels);
+                case kLogisticNeglik:
+                case kLogisticClassify: return NegLoglikelihoodLoss(preds,labels);
                default: utils::Error("unknown loss_type"); return 0.0f;
                }
            }
            /*!
-            * \brief calculating the square loss, given the predictions and labels
-            * \param preds the given predictions
-            * \param labels the given labels
-            * \return the summation of square loss
-            */
+            * \brief calculating the square loss, given the predictions and labels
+            * \param preds the given predictions
+            * \param labels the given labels
+            * \return the summation of square loss
+            */
            inline float SquareLoss(const std::vector<float> &preds,
                                    const std::vector<float> &labels) const{
                float ans = 0.0;
                for(size_t i = 0; i < preds.size(); i++){
@@ -292,44 +303,30 @@ namespace xgboost{
                    float dif = preds[i] - labels[i];
                    ans += dif * dif;
                }
                return ans;
            }
            /*!
-            * \brief calculating the square loss, given the predictions and labels
-            * \param preds the given predictions
-            * \param labels the given labels
-            * \return the summation of square loss
-            */
+            * \brief calculating the negative log-likelihood loss, given the predictions and labels
+            * \param preds the given predictions
+            * \param labels the given labels
+            * \return the summation of negative log-likelihood loss
+            */
            inline float NegLoglikelihoodLoss(const std::vector<float> &preds,
                                              const std::vector<float> &labels) const{
                float ans = 0.0;
                for(size_t i = 0; i < preds.size(); i++)
                    ans -= labels[i] * logf(preds[i]) + ( 1 - labels[i] ) * logf(1 - preds[i]);
                return ans;
            }
-
-            /*!
-            * \brief calculating the classification error, given the predictions and labels
-            * \param preds the given predictions
-            * \param labels the given labels
-            * \return the classification error rate
-            */
-            inline float ClassificationError(const std::vector<float> &preds,
-                                             const std::vector<float> &labels) const{
-                int nerr = 0;
-                for(size_t i = 0; i < preds.size(); i++){
-                    if( preds[i] > 0.5f ){
-                        if( labels[i] < 0.5f ) nerr ++;
-                    }else{
-                        if( labels[i] > 0.5f ) nerr ++;
-                    }
-                }
-                return (float)nerr/preds.size();
-            }
        };
    private:
        int silent;
+        EvalSet evaluator_;
        booster::GBMBaseModel base_model;
        ModelParam mparam;
        const DMatrix *train_;
        std::vector<const DMatrix *> evals_;
        std::vector<std::string> evname_;
        std::vector<unsigned> buffer_index_;
+    private:
+        std::vector<float> grad_, hess_, preds_;
+        std::vector< std::vector<float> > eval_preds_;
    };
 }
};
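The loops parallelized in RegBoostLearner all follow one pattern: pre-size the output vectors, then let each iteration write a distinct slot, so a plain parallel-for is race-free. A standalone sketch of that pattern with a logistic-loss gradient follows (illustrative values; not the patch's ModelParam; note that unsigned loop indices, as used in the patch, require OpenMP 3.0 or later):

// Sketch: resize first, then parallel-for; grad[j]/hess[j] are independent.
#include <omp.h>
#include <cstdio>
#include <vector>

int main( void ){
    std::vector<float> preds( 8, 0.3f ), labels( 8, 1.0f );
    std::vector<float> grad( preds.size() ), hess( preds.size() );
    const unsigned ndata = static_cast<unsigned>( preds.size() );
    #pragma omp parallel for schedule( static )
    for( unsigned j = 0; j < ndata; ++ j ){
        grad[j] = preds[j] - labels[j];            // first-order, logistic loss
        hess[j] = preds[j] * ( 1.0f - preds[j] );  // second-order, logistic loss
    }
    printf( "grad[0]=%f hess[0]=%f\n", grad[0], hess[0] );
    return 0;
}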
diff --git a/regression/xgboost_reg_test.h b/regression/xgboost_reg_test.h
deleted file mode 100644
index ea831856e..000000000
--- a/regression/xgboost_reg_test.h
+++ /dev/null
@@ -1,100 +0,0 @@
-#ifndef _XGBOOST_REG_TEST_H_
-#define _XGBOOST_REG_TEST_H_
-
-#include <cstdio>
-#include <cstring>
-#include <vector>
-#include"../utils/xgboost_config.h"
-#include"xgboost_reg.h"
-#include"xgboost_regdata.h"
-#include"../utils/xgboost_string.h"
-
-using namespace xgboost::utils;
-namespace xgboost{
-    namespace regression{
-        /*!
-        * \brief wrapping the testing process of the gradient
-        *        boosting regression model, given the configuration
-        * \author Kailong Chen: chenkl198812@gmail.com
-        */
-        class RegBoostTest{
-        public:
-            /*!
-            * \brief start the testing process of the gradient boosting regression
-            *        model given the configuration, and finally save the prediction
-            *        results to the specified paths
-            * \param config_path the location of the configuration
-            * \param silent whether to print feedback messages
-            */
-            void test(char* config_path,bool silent = false){
-                reg_boost_learner = new xgboost::regression::RegBoostLearner();
-                ConfigIterator config_itr(config_path);
-                //Get the training data and validation data paths, config the Learner
-                while (config_itr.Next()){
-                    reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
-                    test_param.SetParam(config_itr.name(),config_itr.val());
-                }
-
-                Assert(test_param.test_paths.size() == test_param.test_names.size(),
-                       "The number of test data set paths is not the same as the number of test data set names");
-
-                //begin testing
-                reg_boost_learner->InitModel();
-                char model_path[256];
-                std::vector<float> preds;
-                for(size_t i = 0; i < test_param.test_paths.size(); i++){
-                    xgboost::regression::DMatrix test_data;
-                    test_data.LoadText(test_param.test_paths[i].c_str());
-                    sprintf(model_path,"%s/final.model",test_param.model_dir_path);
-                    // BUG: model need to be rb
-                    FileStream fin(fopen(model_path,"r"));
-                    reg_boost_learner->LoadModel(fin);
-                    fin.Close();
-                    reg_boost_learner->Predict(preds,test_data);
-                }
-            }
-
-        private:
-            struct TestParam{
-                /* \brief upperbound of the number of boosters */
-                int boost_iterations;
-
-                /* \brief the period to save the model, -1 means only save the final round model */
-                int save_period;
-
-                /* \brief the path of directory containing the saved models */
-                char model_dir_path[256];
-
-                /* \brief the path of directory containing the output prediction results */
-                char pred_dir_path[256];
-
-                /* \brief the paths of test data sets */
-                std::vector<std::string> test_paths;
-
-                /* \brief the names of the test data sets */
-                std::vector<std::string> test_names;
-
-                /*!
-                * \brief set parameters from outside
-                * \param name name of the parameter
-                * \param val  value of the parameter
-                */
-                inline void SetParam(const char *name,const char *val ){
-                    if( !strcmp("model_dir_path", name ) ) strcpy(model_dir_path,val);
-                    if( !strcmp("pred_dir_path", name ) )  strcpy(pred_dir_path,val);
-                    if( !strcmp("test_paths", name) ) {
-                        test_paths = StringProcessing::split(val,';');
-                    }
-                    if( !strcmp("test_names", name) ) {
-                        test_names = StringProcessing::split(val,';');
-                    }
-                }
-            };
-
-            TestParam test_param;
-            xgboost::regression::RegBoostLearner* reg_boost_learner;
-        };
-    }
-}
-
-#endif
diff --git a/regression/xgboost_reg_train.h b/regression/xgboost_reg_train.h
deleted file mode 100644
index 05f7f42ee..000000000
--- a/regression/xgboost_reg_train.h
+++ /dev/null
@@ -1,132 +0,0 @@
-#ifndef _XGBOOST_REG_TRAIN_H_
-#define _XGBOOST_REG_TRAIN_H_
-
-#include <cstdio>
-#include <cstring>
-#include <vector>
-#include "../utils/xgboost_config.h"
-#include "xgboost_reg.h"
-#include "xgboost_regdata.h"
-#include "../utils/xgboost_string.h"
-
-using namespace xgboost::utils;
-
-namespace xgboost{
-    namespace regression{
-        /*!
-        * \brief wrapping the training process of the gradient
-        *        boosting regression model, given the configuration
-        * \author Kailong Chen: chenkl198812@gmail.com
-        */
-        class RegBoostTrain{
-        public:
-            /*!
-            * \brief start the training process of the gradient boosting regression
-            *        model given the configuration, and finally save the models
-            *        to the specified model directory
-            * \param config_path the location of the configuration
-            * \param silent whether to print feedback messages
-            */
-            void train(char* config_path,bool silent = false){
-                reg_boost_learner = new xgboost::regression::RegBoostLearner();
-
-                ConfigIterator config_itr(config_path);
-                //Get the training data and validation data paths, config the Learner
-                while (config_itr.Next()){
-                    printf("%s %s\n",config_itr.name(),config_itr.val());
-                    reg_boost_learner->SetParam(config_itr.name(),config_itr.val());
-                    train_param.SetParam(config_itr.name(),config_itr.val());
-                }
-
-                Assert(train_param.validation_data_paths.size() == train_param.validation_data_names.size(),
-                       "The number of validation paths is not the same as the number of validation data set names");
-
-                //Load Data
-                xgboost::regression::DMatrix train;
-                printf("%s",train_param.train_path);
-                train.LoadText(train_param.train_path);
-                std::vector<const DMatrix *> evals;
-                for(size_t i = 0; i < train_param.validation_data_paths.size(); i++){
-                    xgboost::regression::DMatrix eval;
-                    eval.LoadText(train_param.validation_data_paths[i].c_str());
-                    evals.push_back(&eval);
-                }
-                reg_boost_learner->SetData(&train,evals,train_param.validation_data_names);
-
-                //begin training
-                reg_boost_learner->InitTrainer();
-                char suffix[256];
-                for(int i = 1; i <= train_param.boost_iterations; i++){
-                    reg_boost_learner->UpdateOneIter(i);
-                    if(train_param.save_period != 0 && i % train_param.save_period == 0){
-                        sprintf(suffix,"%d.model",i);
-                        SaveModel(suffix);
-                    }
-                }
-
-                //save the final round model
-                SaveModel("final.model");
-            }
-
-        private:
-            /*! \brief save model in the model directory with specified suffix*/
-            void SaveModel(const char* suffix){
-                char model_path[256];
-                //save the final round model
-                sprintf(model_path,"%s/%s",train_param.model_dir_path,suffix);
-                FILE* file = fopen(model_path,"w");
-                FileStream fin(file);
-                reg_boost_learner->SaveModel(fin);
-                fin.Close();
-            }
-
-            struct TrainParam{
-                /* \brief upperbound of the number of boosters */
-                int boost_iterations;
-
-                /* \brief the period to save the model, -1 means only save the final round model */
-                int save_period;
-
-                /* \brief the path of training data set */
-                char train_path[256];
-
-                /* \brief the path of directory containing the saved models */
-                char model_dir_path[256];
-
-                /* \brief the paths of validation data sets */
-                std::vector<std::string> validation_data_paths;
-
-                /* \brief the names of the validation data sets */
-                std::vector<std::string> validation_data_names;
-
-                /*!
-                * \brief set parameters from outside
-                * \param name name of the parameter
-                * \param val  value of the parameter
-                */
-                inline void SetParam(const char *name,const char *val ){
-                    if( !strcmp("boost_iterations", name ) ) boost_iterations = atoi( val );
-                    if( !strcmp("save_period", name ) )      save_period = atoi( val );
-                    if( !strcmp("train_path", name ) )       strcpy(train_path,val);
-                    if( !strcmp("model_dir_path", name ) ) {
-                        strcpy(model_dir_path,val);
-                    }
-                    if( !strcmp("validation_paths", name) ) {
-                        validation_data_paths = StringProcessing::split(val,';');
-                    }
-                    if( !strcmp("validation_names", name) ) {
-                        validation_data_names = StringProcessing::split(val,';');
-                    }
-                }
-            };
-
-            /*! \brief the parameters of the training process*/
-            TrainParam train_param;
-
-            /*! \brief the gradient boosting regression tree model*/
-            xgboost::regression::RegBoostLearner* reg_boost_learner;
-        };
-    }
-}
-
-#endif
diff --git a/regression/xgboost_regeval.h b/regression/xgboost_regeval.h
new file mode 100644
index 000000000..e5b0ce791
--- /dev/null
+++ b/regression/xgboost_regeval.h
@@ -0,0 +1,96 @@
+#ifndef _XGBOOST_REGEVAL_H_
+#define _XGBOOST_REGEVAL_H_
+/*!
+* \file xgboost_regeval.h
+* \brief evaluation metrics for regression and classification
+* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
+*/
+#include <cmath>
+#include <vector>
+#include <cstring>
+#include <algorithm>
+#include "../utils/xgboost_utils.h"
+
+namespace xgboost{
+    namespace regression{
+        /*! \brief evaluator that evaluates the loss metrics */
+        struct IEvaluator{
+            /*!
+            * \brief evaluate a specific metric
+            * \param preds prediction
+            * \param labels label
+            */
+            virtual float Eval( const std::vector<float> &preds,
+                                const std::vector<float> &labels ) const = 0;
+            /*! \return name of metric */
+            virtual const char *Name( void ) const = 0;
+        };
+
+        /*! \brief RMSE */
+        struct EvalRMSE : public IEvaluator{
+            virtual float Eval( const std::vector<float> &preds,
+                                const std::vector<float> &labels ) const{
+                const unsigned ndata = static_cast<unsigned>( preds.size() );
+                float sum = 0.0;
+                #pragma omp parallel for reduction(+:sum) schedule( static )
+                for( unsigned i = 0; i < ndata; ++ i ){
+                    float diff = preds[i] - labels[i];
+                    sum += diff * diff;
+                }
+                return sqrtf( sum / ndata );
+            }
+            virtual const char *Name( void ) const{
+                return "rmse";
+            }
+        };
+
+        /*! \brief Error */
+        struct EvalError : public IEvaluator{
+            virtual float Eval( const std::vector<float> &preds,
+                                const std::vector<float> &labels ) const{
+                const unsigned ndata = static_cast<unsigned>( preds.size() );
+                unsigned nerr = 0;
+                #pragma omp parallel for reduction(+:nerr) schedule( static )
+                for( unsigned i = 0; i < ndata; ++ i ){
+                    if( preds[i] > 0.5f ){
+                        if( labels[i] < 0.5f ) nerr += 1;
+                    }else{
+                        if( labels[i] > 0.5f ) nerr += 1;
+                    }
+                }
+                return static_cast<float>(nerr) / ndata;
+            }
+            virtual const char *Name( void ) const{
+                return "error";
+            }
+        };
+    };
+
+    namespace regression{
+        /*! \brief a set of evaluators */
+        struct EvalSet{
+        public:
+            inline void AddEval( const char *name ){
+                if( !strcmp( name, "rmse") )  evals_.push_back( &rmse_ );
+                if( !strcmp( name, "error") ) evals_.push_back( &error_ );
+            }
+            inline void Init( void ){
+                std::sort( evals_.begin(), evals_.end() );
+                evals_.resize( std::unique( evals_.begin(), evals_.end() ) - evals_.begin() );
+            }
+            inline void Eval( FILE *fo, const char *evname,
+                              const std::vector<float> &preds,
+                              const std::vector<float> &labels ) const{
+                for( size_t i = 0; i < evals_.size(); ++ i ){
+                    float res = evals_[i]->Eval( preds, labels );
+                    fprintf( fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res );
+                }
+            }
+        private:
+            EvalRMSE  rmse_;
+            EvalError error_;
+            std::vector<const IEvaluator *> evals_;
+        };
+    };
+};
+#endif

diff --git a/utils/xgboost_string.h b/utils/xgboost_string.h
deleted file mode 100644
index 1ce056d33..000000000
--- a/utils/xgboost_string.h
+++ /dev/null
@@ -1,31 +0,0 @@
-#ifndef _XGBOOST_STRING_H_
-#define _XGBOOST_STRING_H_
-#include <sstream>
-#include <vector>
-
-namespace xgboost{
-    namespace utils{
-        class StringProcessing{
-        public:
-            static std::vector<std::string> &split(const std::string &s, char delim, std::vector<std::string> &elems) {
-                std::stringstream ss(s);
-                std::string item;
-                while (std::getline(ss, item, delim)) {
-                    elems.push_back(item);
-                }
-                return elems;
-            }
-
-            static std::vector<std::string> split(const std::string &s, char delim) {
-                std::vector<std::string> elems;
-                split(s, delim, elems);
-                return elems;
-            }
-        };
-    }
-}
-
-#endif
\ No newline at end of file
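Putting the new pieces together, a learner configured with eval_metric now routes its per-iteration log through EvalSet. Used on its own, the new header would look roughly like this (the include path is an assumption, not from the patch; AddEval/Init/Eval are as defined above):

// Sketch: EvalSet usage — AddEval registers metrics, Init() dedups,
// Eval() prints "\t<name>-<metric>:<value>" for each registered metric.
#include <cstdio>
#include <vector>
#include "regression/xgboost_regeval.h"  // hypothetical include path

int main( void ){
    xgboost::regression::EvalSet evals;
    evals.AddEval( "rmse" );
    evals.AddEval( "rmse" );   // duplicate; removed by Init()
    evals.AddEval( "error" );
    evals.Init();
    std::vector<float> preds( 4, 0.7f ), labels( 4, 1.0f );
    evals.Eval( stderr, "train", preds, labels );  // e.g. "\ttrain-rmse:0.300000\ttrain-error:0.000000"
    fprintf( stderr, "\n" );
    return 0;
}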