diff --git a/README.md b/README.md
index bf6b20cca..32ceb2706 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-xgboost
+xgboost: A Gradient Boosting Library
 =======
 Creater: Tianqi Chen: tianqi.tchen AT gmail
 
@@ -7,16 +7,16 @@ General Purpose Gradient Boosting Library
 Goal: A stand-alone efficient library to do learning via boosting in functional space
 
 Features:
-(1) Sparse feature format, handling of missing features. This allows efficient categorical feature encoding as indicators. The speed of booster only depens on number of existing features.
-(2) Layout of gradient boosting algorithm to support generic tasks, see project wiki.
+* Sparse feature format, with handling of missing features. This allows efficient categorical feature encoding as indicators. The speed of the booster depends only on the number of existing features.
+* Layout of the gradient boosting algorithm to support generic tasks; see the project wiki.
 
 Planned key components:
-(1) Gradient boosting models:
+* Gradient boosting models:
   - regression tree (GBRT)
   - linear model/lasso
-(2) Objectives to support tasks:
+* Objectives to support tasks:
   - regression
   - classification
   - ranking
diff --git a/booster/xgboost_data.h b/booster/xgboost_data.h
index 7f5b5f25f..d2b66cebd 100644
--- a/booster/xgboost_data.h
+++ b/booster/xgboost_data.h
@@ -9,6 +9,7 @@
 #include <vector>
 #include "../utils/xgboost_utils.h"
+#include "../utils/xgboost_stream.h"
 
 namespace xgboost{
     namespace booster{
@@ -143,7 +144,7 @@ namespace xgboost{
          * the function is not consistent between 64bit and 32bit machine
          * \param fo output stream
          */
-        inline void SaveBinary( utils::IStream &fo ) const{
+        inline void SaveBinary(utils::IStream &fo ) const{
             size_t nrow = this->NumRow();
             fo.Write( &nrow, sizeof(size_t) );
             fo.Write( &row_ptr[0], row_ptr.size() * sizeof(size_t) );
diff --git a/regression/xgboost_reg.h b/regression/xgboost_reg.h
index ebc574a85..c015298b7 100644
--- a/regression/xgboost_reg.h
+++ b/regression/xgboost_reg.h
@@ -1,10 +1,10 @@
 #ifndef _XGBOOST_REG_H_
 #define _XGBOOST_REG_H_
 /*!
- * \file xgboost_reg.h
- * \brief class for gradient boosted regression
- * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
- */
+* \file xgboost_reg.h
+* \brief class for gradient boosted regression
+* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com
+*/
 #include <cmath>
+#include <cstdio>
 #include "xgboost_regdata.h"
 #include "../booster/xgboost_gbmbase.h"
@@ -12,143 +12,265 @@
 #include "../utils/xgboost_stream.h"
 
 namespace xgboost{
-    namespace regression{
-        /*! \brief class for gradient boosted regression */
-        class RegBoostLearner{
-        public:
-            /*!
-             * \brief a regression booter associated with training and evaluating data
-             * \param train pointer to the training data
-             * \param evals array of evaluating data
-             * \param evname name of evaluation data, used print statistics
-             */
+    namespace regression{
+        /*! \brief class for gradient boosted regression */
+        class RegBoostLearner{
+        public:
+
+            RegBoostLearner( bool silent = false ){
+                this->silent = silent;
+            }
+
+            /*!
+             * \brief a regression booster associated with training and evaluating data
+             * \param train pointer to the training data
+             * \param evals array of evaluating data
+             * \param evname name of evaluation data, used to print statistics
+             */
+            RegBoostLearner( const DMatrix *train,
+                             std::vector<const DMatrix *> evals,
+                             std::vector<std::string> evname, bool silent = false ){
+                this->silent = silent;
+                SetData( train, evals, evname );
+            }
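+
+            // A minimal usage sketch, implied by the interface below (illustrative
+            // only, not enforced by the class):
+            //   RegBoostLearner learner( &train, evals, evname );
+            //   learner.SetParam( "loss_type", "1" );
+            //   learner.InitTrainer(); learner.InitModel();
+            //   for( int i = 1; i <= num_round; i ++ ) learner.UpdateOneIter( i );
+            //   learner.SaveModel( fout );
+            // where num_round and fout are placeholders for user-supplied values.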
+
+            /*!
+             * \brief associate regression booster with training and evaluating data
+             * \param train pointer to the training data
+             * \param evals array of evaluating data
+             * \param evname name of evaluation data, used to print statistics
+             */
+            inline void SetData( const DMatrix *train,
+                                 std::vector<const DMatrix *> evals,
+                                 std::vector<std::string> evname ){
+                this->train_ = train;
+                this->evals_ = evals;
+                this->evname_ = evname;
+                // assign buffer index: one prediction-buffer slot per instance
+                int buffer_size = (*train).size();
+                for( int i = 0; i < (int)evals.size(); i ++ ){
+                    buffer_size += (*evals[i]).size();
+                }
+                char str[25];
+                // use sprintf instead of the non-standard itoa
+                sprintf( str, "%d", buffer_size );
+                base_model.SetParam( "num_pbuffer", str );
+            }
+
+            /*!
+             * \brief set parameters from outside
+             * \param name name of the parameter
+             * \param val value of the parameter
+             */
+            inline void SetParam( const char *name, const char *val ){
+                mparam.SetParam( name, val );
+                base_model.SetParam( name, val );
+            }
+            /*!
+             * \brief initialize solver before training, called before training
+             * this function is reserved for solver to allocate necessary space and do other preparation
+             */
+            inline void InitTrainer( void ){
+                base_model.InitTrainer();
+                mparam.AdjustBase();
+            }
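+
+            // Note on the prediction buffer (an explanatory summary, inferred from
+            // SetData and Predict in this file): every training and evaluation
+            // instance gets one slot in the booster's prediction buffer
+            // (num_pbuffer), so the booster can cache partial margin sums between
+            // boosting rounds; the buffer_index_offset argument of Predict selects
+            // the slot range that belongs to a given data set.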
+
+            /*!
+             * \brief initialize the current data storage for the model; call this once when the model is used for the first time
+             */
-            RegBoostLearner( const DMatrix *train,
-                             std::vector<const DMatrix *> evals,
-                             std::vector<std::string> evname ){
-                this->train_ = train;
-                this->evals_ = evals;
-                this->evname_ = evname;
-                //TODO: assign buffer index
-            }
-            /*!
-             * \brief set parameters from outside
-             * \param name name of the parameter
-             * \param val value of the parameter
-             */
-            inline void SetParam( const char *name, const char *val ){
-                mparam.SetParam( name, val );
-                base_model.SetParam( name, val );
-            }
-            /*!
-             * \brief initialize solver before training, called before training
-             * this function is reserved for solver to allocate necessary space and do other preparation
-             */
-            inline void InitTrainer( void ){
-                base_model.InitTrainer();
-                mparam.AdjustBase();
-            }
-            /*!
-             * \brief load model from stream
-             * \param fi input stream
-             */
-            inline void LoadModel( utils::IStream &fi ){
-                utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
-                base_model.LoadModel( fi );
-            }
-            /*!
-             * \brief save model to stream
-             * \param fo output stream
-             */
-            inline void SaveModel( utils::IStream &fo ) const{
-                fo.Write( &mparam, sizeof(ModelParam) );
-                base_model.SaveModel( fo );
-            }
-            /*!
-             * \brief update the model for one iteration
-             */
-            inline void UpdateOneIter( void ){
-                //TODO
-            }
-            /*! \brief predict the results, given data */
-            inline void Predict( std::vector<float> &preds, const DMatrix &data ){
-                //TODO
-            }
-        private:
-            /*! \brief training parameter for regression */
-            struct ModelParam{
-                /* \brief global bias */
-                float base_score;
-                /* \brief type of loss function */
-                int loss_type;
-                ModelParam( void ){
-                    base_score = 0.5f;
-                    loss_type = 0;
-                }
-                /*!
-                 * \brief set parameters from outside
-                 * \param name name of the parameter
-                 * \param val value of the parameter
-                 */
-                inline void SetParam( const char *name, const char *val ){
-                    if( !strcmp("base_score", name ) ) base_score = (float)atof( val );
-                    if( !strcmp("loss_type", name ) )  loss_type = atoi( val );
-                }
-                /*!
-                 * \brief adjust base_score
-                 */
-                inline void AdjustBase( void ){
-                    if( loss_type == 1 ){
-                        utils::Assert( base_score > 0.0f && base_score < 1.0f, "sigmoid range constrain" );
-                        base_score = - logf( 1.0f / base_score - 1.0f );
-                    }
-                }
-                /*!
-                 * \brief calculate first order gradient of loss, given transformed prediction
-                 * \param predt transformed prediction
-                 * \param label true label
-                 * \return first order gradient
-                 */
-                inline float FirstOrderGradient( float predt, float label ) const{
-                    switch( loss_type ){
-                    case 0: return predt - label;
-                    case 1: return predt - label;
-                    default: utils::Error("unknown loss_type"); return 0.0f;
-                    }
-                }
-                /*!
-                 * \brief calculate second order gradient of loss, given transformed prediction
-                 * \param predt transformed prediction
-                 * \param label true label
-                 * \return second order gradient
-                 */
-                inline float SecondOrderGradient( float predt, float label ) const{
-                    switch( loss_type ){
-                    case 0: return 1.0f;
-                    case 1: return predt * ( 1 - predt );
-                    default: utils::Error("unknown loss_type"); return 0.0f;
-                    }
-                }
-                /*!
-                 * \brief transform the linear sum to prediction
-                 * \param x linear sum of boosting ensemble
-                 * \return transformed prediction
-                 */
-                inline float PredTransform( float x ){
-                    switch( loss_type ){
-                    case 0: return x;
-                    case 1: return 1.0f/(1.0f + expf(-x));
-                    default: utils::Error("unknown loss_type"); return 0.0f;
-                    }
-                }
-            };
-        private:
-            booster::GBMBaseModel base_model;
-            ModelParam mparam;
-            const DMatrix *train_;
-            std::vector<const DMatrix *> evals_;
-            std::vector<std::string> evname_;
-        };
-    };
+            inline void InitModel( void ){
+                base_model.InitModel();
+            }
+
+            /*!
+             * \brief load model from stream
+             * \param fi input stream
+             */
+            inline void LoadModel( utils::IStream &fi ){
+                utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 );
+                base_model.LoadModel( fi );
+            }
+            /*!
+             * \brief save model to stream
+             * \param fo output stream
+             */
+            inline void SaveModel( utils::IStream &fo ) const{
+                fo.Write( &mparam, sizeof(ModelParam) );
+                base_model.SaveModel( fo );
+            }
+
+            /*!
+             * \brief update the model for one iteration
+             * \param iteration the number of the current updating iteration
+             */
+            inline void UpdateOneIter( int iteration ){
+                std::vector<float> grad, hess, preds;
+                std::vector<unsigned> root_index;
+                booster::FMatrixS::Image train_image( (*train_).data );
+                Predict( preds, *train_, 0 );
+                Gradient( preds, (*train_).labels, grad, hess );
+                base_model.DoBoost( grad, hess, train_image, root_index );
+                int buffer_index_offset = (*train_).size();
+                float loss = 0.0f;
+                for( int i = 0; i < (int)evals_.size(); i ++ ){
+                    Predict( preds, *evals_[i], buffer_index_offset );
+                    loss = mparam.Loss( preds, (*evals_[i]).labels );
+                    if( !silent ){
+                        // pass the values, not their addresses, to printf
+                        printf( "The loss of the %s data set in the %d-th iteration is %f\n",
+                                evname_[i].c_str(), iteration, loss );
+                    }
+                    buffer_index_offset += (*evals_[i]).size();
+                }
+            }
+
+            /*! \brief get the transformed predictions, given data */
+            inline void Predict( std::vector<float> &preds, const DMatrix &data, int buffer_index_offset = 0 ){
+                int data_size = data.size();
+                preds.resize( data_size );
+                for( int j = 0; j < data_size; j ++ ){
+                    preds[j] = mparam.PredTransform( mparam.base_score +
+                        base_model.Predict( data.data[j], buffer_index_offset + j ) );
+                }
+            }
+
+        private:
+            /*! \brief get the first order and second order gradient, given the transformed predictions and labels */
+            inline void Gradient( const std::vector<float> &preds, const std::vector<float> &labels,
+                                  std::vector<float> &grad, std::vector<float> &hess ){
+                grad.clear();
+                hess.clear();
+                for( int j = 0; j < (int)preds.size(); j ++ ){
+                    grad.push_back( mparam.FirstOrderGradient( preds[j], labels[j] ) );
+                    hess.push_back( mparam.SecondOrderGradient( preds[j], labels[j] ) );
+                }
+            }
+
+            enum LOSS_TYPE_LIST{
+                LINEAR_SQUARE,
+                LOGISTIC_NEGLOGLIKELIHOOD,
+            };
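+
+            // For reference (a derivation not spelled out in the original comments):
+            // with LOGISTIC_NEGLOGLIKELIHOOD the per-instance loss for p = sigmoid(x)
+            // is -y*log(p) - (1-y)*log(1-p), whose gradient w.r.t. the margin x is
+            // p - y and whose hessian is p*(1-p); with LINEAR_SQUARE the loss
+            // 0.5*(x-y)^2 gives gradient x - y and hessian 1. These match
+            // FirstOrderGradient and SecondOrderGradient in ModelParam below.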
+            /*! \brief training parameter for regression */
+            struct ModelParam{
+                /* \brief global bias */
+                float base_score;
+                /* \brief type of loss function */
+                int loss_type;
+
+                ModelParam( void ){
+                    base_score = 0.5f;
+                    loss_type = 0;
+                }
+                /*!
+                 * \brief set parameters from outside
+                 * \param name name of the parameter
+                 * \param val value of the parameter
+                 */
+                inline void SetParam( const char *name, const char *val ){
+                    if( !strcmp("base_score", name ) ) base_score = (float)atof( val );
+                    if( !strcmp("loss_type", name ) )  loss_type = atoi( val );
+                }
+                /*!
+                 * \brief adjust base_score
+                 */
+                inline void AdjustBase( void ){
+                    if( loss_type == LOGISTIC_NEGLOGLIKELIHOOD ){
+                        utils::Assert( base_score > 0.0f && base_score < 1.0f, "sigmoid range constraint" );
+                        base_score = - logf( 1.0f / base_score - 1.0f );
+                    }
+                }
+                /*!
+                 * \brief calculate first order gradient of loss, given transformed prediction
+                 * \param predt transformed prediction
+                 * \param label true label
+                 * \return first order gradient
+                 */
+                inline float FirstOrderGradient( float predt, float label ) const{
+                    switch( loss_type ){
+                    case LINEAR_SQUARE: return predt - label;
+                    case LOGISTIC_NEGLOGLIKELIHOOD: return predt - label;
+                    default: utils::Error("unknown loss_type"); return 0.0f;
+                    }
+                }
+                /*!
+                 * \brief calculate second order gradient of loss, given transformed prediction
+                 * \param predt transformed prediction
+                 * \param label true label
+                 * \return second order gradient
+                 */
+                inline float SecondOrderGradient( float predt, float label ) const{
+                    switch( loss_type ){
+                    case LINEAR_SQUARE: return 1.0f;
+                    case LOGISTIC_NEGLOGLIKELIHOOD: return predt * ( 1 - predt );
+                    default: utils::Error("unknown loss_type"); return 0.0f;
+                    }
+                }
+
+                /*!
+                 * \brief calculate the loss, given the predictions, labels and the loss type
+                 * \param preds the given predictions
+                 * \param labels the given labels
+                 * \return the specified loss
+                 */
+                inline float Loss( const std::vector<float> &preds, const std::vector<float> &labels ) const{
+                    switch( loss_type ){
+                    case LINEAR_SQUARE: return SquareLoss( preds, labels );
+                    case LOGISTIC_NEGLOGLIKELIHOOD: return NegLoglikelihoodLoss( preds, labels );
+                    default: utils::Error("unknown loss_type"); return 0.0f;
+                    }
+                }
+
+                /*!
+                 * \brief calculate the square loss, given the predictions and labels
+                 * \param preds the given predictions
+                 * \param labels the given labels
+                 * \return the sum of square loss
+                 */
+                inline float SquareLoss( const std::vector<float> &preds, const std::vector<float> &labels ) const{
+                    float ans = 0.0f;
+                    for( int i = 0; i < (int)preds.size(); i ++ ){
+                        float dif = preds[i] - labels[i];
+                        ans += dif * dif;
+                    }
+                    return ans;
+                }
+
+                /*!
+                 * \brief calculate the negative log-likelihood loss, given the predictions and labels
+                 * \param preds the given predictions
+                 * \param labels the given labels
+                 * \return the sum of negative log-likelihood loss
+                 */
+                inline float NegLoglikelihoodLoss( const std::vector<float> &preds, const std::vector<float> &labels ) const{
+                    float ans = 0.0f;
+                    for( int i = 0; i < (int)preds.size(); i ++ )
+                        ans -= labels[i] * logf( preds[i] ) + ( 1 - labels[i] ) * logf( 1 - preds[i] );
+                    return ans;
+                }
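+
+                // Worked note (added for clarity): AdjustBase stores base_score on
+                // the margin (logit) scale, so PredTransform( base_score + boosting
+                // sum ) recovers a probability. E.g. base_score = 0.5 becomes
+                // -logf( 1.0f/0.5f - 1.0f ) = 0, i.e. a neutral initial margin.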
+
+                /*!
+                 * \brief transform the linear sum to prediction
+                 * \param x linear sum of boosting ensemble
+                 * \return transformed prediction
+                 */
+                inline float PredTransform( float x ){
+                    switch( loss_type ){
+                    case LINEAR_SQUARE: return x;
+                    case LOGISTIC_NEGLOGLIKELIHOOD: return 1.0f/(1.0f + expf(-x));
+                    default: utils::Error("unknown loss_type"); return 0.0f;
+                    }
+                }
+            };
+        private:
+            booster::GBMBaseModel base_model;
+            ModelParam mparam;
+            const DMatrix *train_;
+            std::vector<const DMatrix *> evals_;
+            std::vector<std::string> evname_;
+            bool silent;
+        };
+    }
 };
 #endif
diff --git a/regression/xgboost_reg_main.cpp b/regression/xgboost_reg_main.cpp
new file mode 100644
index 000000000..6a3fb61a0
--- /dev/null
+++ b/regression/xgboost_reg_main.cpp
@@ -0,0 +1,14 @@
+#include <cstdio>
+#include <cstdlib>
+#include "xgboost_reg_train.h"
+#include "xgboost_reg_test.h"
+using namespace xgboost::regression;
+
+int main(int argc, char *argv[]){
+    // the configuration path comes from the command line
+    if( argc < 2 ){
+        printf( "usage: xgboost_reg <config_path> [silent]\n" );
+        return 0;
+    }
+    char *config_path = argv[1];
+    bool silent = ( argc > 2 && atoi(argv[2]) == 1 );
+    RegBoostTrain train;
+    RegBoostTest test;
+    train.train( config_path, silent );
+    test.test( config_path, silent );
+    return 0;
+}
diff --git a/regression/xgboost_reg_test.h b/regression/xgboost_reg_test.h
new file mode 100644
index 000000000..b6183d412
--- /dev/null
+++ b/regression/xgboost_reg_test.h
@@ -0,0 +1,99 @@
+#ifndef _XGBOOST_REG_TEST_H_
+#define _XGBOOST_REG_TEST_H_
+
+#include <cstdio>
+#include <cstring>
+#include <vector>
+#include "../utils/xgboost_config.h"
+#include "xgboost_reg.h"
+#include "xgboost_regdata.h"
+#include "../utils/xgboost_string.h"
+
+using namespace xgboost::utils;
+namespace xgboost{
+    namespace regression{
+        /*!
+         * \brief wraps the testing process of the gradient
+         *        boosting regression model, given the configuration
+         * \author Kailong Chen: chenkl198812@gmail.com
+         */
+        class RegBoostTest{
+        public:
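+            // Flow of test(), summarized here for readability: read (name, val)
+            // pairs from the configuration into the learner and into test_param,
+            // then for every configured test set load the data, load
+            // <model_dir_path>/final.model and compute predictions. Note that this
+            // version only computes the predictions; writing them out to
+            // pred_dir_path is not implemented yet.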
+            /*!
+             * \brief start the testing process of the gradient boosting regression
+             *        model given the configuration, and finally save the prediction
+             *        results to the specified paths
+             * \param config_path the location of the configuration
+             * \param silent whether to print feedback messages
+             */
+            void test( char *config_path, bool silent = false ){
+                reg_boost_learner = new xgboost::regression::RegBoostLearner( silent );
+                ConfigIterator config_itr( config_path );
+                // get the test data and model paths, configure the learner
+                while( config_itr.Next() ){
+                    reg_boost_learner->SetParam( config_itr.name(), config_itr.val() );
+                    test_param.SetParam( config_itr.name(), config_itr.val() );
+                }
+
+                Assert( test_param.test_paths.size() == test_param.test_names.size(),
+                    "The number of test data set paths is not the same as the number of test data set names" );
+
+                // begin testing
+                reg_boost_learner->InitModel();
+                char model_path[256];
+                std::vector<float> preds;
+                for( int i = 0; i < (int)test_param.test_paths.size(); i ++ ){
+                    xgboost::regression::DMatrix test_data;
+                    test_data.LoadText( test_param.test_paths[i].c_str() );
+                    // compose the path with sprintf; sscanf would read from the
+                    // uninitialized buffer instead of writing to it
+                    sprintf( model_path, "%s/final.model", test_param.model_dir_path.c_str() );
+                    FileStream fin( fopen( model_path, "rb" ) );
+                    reg_boost_learner->LoadModel( fin );
+                    fin.Close();
+                    reg_boost_learner->Predict( preds, test_data );
+                }
+                delete reg_boost_learner;
+            }
+
+        private:
+            struct TestParam{
+                /* \brief upper bound of the number of boosters */
+                int boost_iterations;
+
+                /* \brief the period to save the model, -1 means only save the final round model */
+                int save_period;
+
+                /* \brief the path of the directory containing the saved models */
+                std::string model_dir_path;
+
+                /* \brief the path of the directory containing the output prediction results */
+                std::string pred_dir_path;
+
+                /* \brief the paths of the test data sets */
+                std::vector<std::string> test_paths;
+
+                /* \brief the names of the test data sets */
+                std::vector<std::string> test_names;
+
+                /*!
+                 * \brief set parameters from outside
+                 * \param name name of the parameter
+                 * \param val value of the parameter
+                 */
+                inline void SetParam( const char *name, const char *val ){
+                    // copy val into std::string members: the const char* handed out
+                    // by the config iterator is a reused buffer and must not be
+                    // stored directly
+                    if( !strcmp("model_dir_path", name ) ) model_dir_path = val;
+                    if( !strcmp("pred_dir_path", name ) ) pred_dir_path = val;
+                    if( !strcmp("test_paths", name ) ){
+                        test_paths = StringProcessing::split( val, ';' );
+                    }
+                    if( !strcmp("test_names", name ) ){
+                        test_names = StringProcessing::split( val, ';' );
+                    }
+                }
+            };
+
+            TestParam test_param;
+            xgboost::regression::RegBoostLearner *reg_boost_learner;
+        };
+    }
+}
+
+#endif
diff --git a/regression/xgboost_reg_train.h b/regression/xgboost_reg_train.h
new file mode 100644
index 000000000..64c60250d
--- /dev/null
+++ b/regression/xgboost_reg_train.h
@@ -0,0 +1,127 @@
+#ifndef _XGBOOST_REG_TRAIN_H_
+#define _XGBOOST_REG_TRAIN_H_
+
+#include <cstdio>
+#include <cstring>
+#include <vector>
+#include "../utils/xgboost_config.h"
+#include "xgboost_reg.h"
+#include "xgboost_regdata.h"
+#include "../utils/xgboost_string.h"
+
+using namespace xgboost::utils;
+
+namespace xgboost{
+    namespace regression{
+        /*!
+         * \brief wraps the training process of the gradient
+         *        boosting regression model, given the configuration
+         * \author Kailong Chen: chenkl198812@gmail.com
+         */
+        class RegBoostTrain{
+        public:
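+            // For illustration only: the parameters consumed below are plain
+            // (name, val) pairs produced by ConfigIterator. Assuming a simple
+            // name=value line syntax for the parser (the syntax itself is not shown
+            // in this change), a minimal reg.conf could look like:
+            //   train_path=train.txt
+            //   validation_paths=val.txt
+            //   validation_names=val
+            //   test_paths=test.txt
+            //   test_names=test
+            //   model_dir_path=models
+            //   pred_dir_path=preds
+            //   boost_iterations=10
+            //   save_period=-1
+            //   loss_type=1
+            //   base_score=0.5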
+ * \brief to start the training process of gradient boosting regression + * model given the configuation, and finally saved the models + * to the specified model directory + * \param config_path the location of the configuration + * \param silent whether to print feedback messages + */ + void train(char* config_path,bool silent = false){ + reg_boost_learner = new xgboost::regression::RegBoostLearner(silent); + ConfigIterator config_itr(config_path); + //Get the training data and validation data paths, config the Learner + while (config_itr.Next()){ + reg_boost_learner->SetParam(config_itr.name(),config_itr.val()); + train_param.SetParam(config_itr.name(),config_itr.val()); + } + + Assert(train_param.validation_data_paths.size() == train_param.validation_data_names.size(), + "The number of validation paths is not the same as the number of validation data set names"); + + //Load Data + xgboost::regression::DMatrix train; + train.LoadText(train_param.train_path); + std::vector evals; + for(int i = 0; i < train_param.validation_data_paths.size(); i++){ + xgboost::regression::DMatrix eval; + eval.LoadText(train_param.validation_data_paths[i].c_str()); + evals.push_back(&eval); + } + reg_boost_learner->SetData(&train,evals,train_param.validation_data_names); + + //begin training + reg_boost_learner->InitTrainer(); + char suffix[256]; + for(int i = 1; i <= train_param.boost_iterations; i++){ + reg_boost_learner->UpdateOneIter(i); + if(train_param.save_period != 0 && i % train_param.save_period == 0){ + sscanf(suffix,"%d.model",i); + SaveModel(suffix); + } + } + + //save the final round model + SaveModel("final.model"); + } + + private: + /*! \brief save model in the model directory with specified suffix*/ + void SaveModel(const char* suffix){ + char model_path[256]; + //save the final round model + sscanf(model_path,"%s/%s",train_param.model_dir_path,suffix); + FILE* file = fopen(model_path,"w"); + FileStream fin(file); + reg_boost_learner->SaveModel(fin); + fin.Close(); + } + + struct TrainParam{ + /* \brief upperbound of the number of boosters */ + int boost_iterations; + + /* \brief the period to save the model, -1 means only save the final round model */ + int save_period; + + /* \brief the path of training data set */ + const char* train_path; + + /* \brief the path of directory containing the saved models */ + const char* model_dir_path; + + /* \brief the paths of validation data sets */ + std::vector validation_data_paths; + + /* \brief the names of the validation data sets */ + std::vector validation_data_names; + + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + inline void SetParam(const char *name,const char *val ){ + if( !strcmp("boost_iterations", name ) ) boost_iterations = (float)atof( val ); + if( !strcmp("save_period", name ) ) save_period = atoi( val ); + if( !strcmp("train_path", name ) ) train_path = val; + if( !strcmp("model_dir_path", name ) ) model_dir_path = val; + if( !strcmp("validation_paths", name) ) { + validation_data_paths = StringProcessing::split(val,';'); + } + if( !strcmp("validation_names", name) ) { + validation_data_names = StringProcessing::split(val,';'); + } + } + }; + + /*! \brief the parameters of the training process*/ + TrainParam train_param; + + /*! 
+
+            /*! \brief the gradient boosting regression learner */
+            xgboost::regression::RegBoostLearner *reg_boost_learner;
+        };
+    }
+}
+
+#endif
diff --git a/regression/xgboost_regdata.h b/regression/xgboost_regdata.h
index 075a53f23..e7f22d0c9 100644
--- a/regression/xgboost_regdata.h
+++ b/regression/xgboost_regdata.h
@@ -30,6 +30,13 @@ namespace xgboost{
         public:
             /*! \brief default constructor */
             DMatrix( void ){}
+
+            /*! \brief get the number of instances */
+            inline int size() const{
+                return (int)labels.size();
+            }
+
             /*!
              * \brief load from text file
              * \param fname name of text data
diff --git a/utils/xgboost_config.h b/utils/xgboost_config.h
index 473c38e78..737f6f486 100644
--- a/utils/xgboost_config.h
+++ b/utils/xgboost_config.h
@@ -10,6 +10,7 @@
 #include <cstdio>
 #include <cstring>
 #include "xgboost_utils.h"
+#include <string>
 
 namespace xgboost{
     namespace utils{
diff --git a/utils/xgboost_string.h b/utils/xgboost_string.h
new file mode 100644
index 000000000..1ce056d33
--- /dev/null
+++ b/utils/xgboost_string.h
@@ -0,0 +1,31 @@
+#ifndef _XGBOOST_STRING_H_
+#define _XGBOOST_STRING_H_
+#include <string>
+#include <sstream>
+#include <vector>
+
+namespace xgboost{
+    namespace utils{
+        class StringProcessing{
+        public:
+            static std::vector<std::string> &split( const std::string &s, char delim, std::vector<std::string> &elems ){
+                std::stringstream ss( s );
+                std::string item;
+                while( std::getline( ss, item, delim ) ){
+                    elems.push_back( item );
+                }
+                return elems;
+            }
+
+            static std::vector<std::string> split( const std::string &s, char delim ){
+                std::vector<std::string> elems;
+                split( s, delim, elems );
+                return elems;
+            }
+        };
+    }
+}
+
+#endif