diff --git a/dev/README.md b/dev/README.md deleted file mode 100644 index 406cf66ff..000000000 --- a/dev/README.md +++ /dev/null @@ -1 +0,0 @@ -this folder contains codes under development diff --git a/dev/base/xgboost_boost_task.h b/dev/base/xgboost_boost_task.h new file mode 100644 index 000000000..1234eee8f --- /dev/null +++ b/dev/base/xgboost_boost_task.h @@ -0,0 +1,324 @@ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE + +#include +#include +#include +#include "xgboost_data_instance.h" +#include "xgboost_learner.h" +#include "../utils/xgboost_fmap.h" +#include "../utils/xgboost_random.h" +#include "../utils/xgboost_config.h" + +namespace xgboost{ + namespace base{ + /*! + * \brief wrapping the training process of the gradient boosting model, + * given the configuation + * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.chen@gmail.com + */ + class BoostTask{ + public: + inline int Run(int argc, char *argv[]){ + if (argc < 2){ + printf("Usage: \n"); + return 0; + } + utils::ConfigIterator itr(argv[1]); + while (itr.Next()){ + this->SetParam(itr.name(), itr.val()); + } + for (int i = 2; i < argc; i++){ + char name[256], val[256]; + if (sscanf(argv[i], "%[^=]=%s", name, val) == 2){ + this->SetParam(name, val); + } + } + this->InitData(); + this->InitLearner(); + if (task == "dump"){ + this->TaskDump(); + return 0; + } + if (task == "interact"){ + this->TaskInteractive(); return 0; + } + if (task == "dumppath"){ + this->TaskDumpPath(); return 0; + } + if (task == "eval"){ + this->TaskEval(); return 0; + } + if (task == "pred"){ + this->TaskPred(); + } + else{ + this->TaskTrain(); + } + return 0; + } + + enum learning_tasks{ + REGRESSION = 0, + BINARY_CLASSIFICATION = 1, + RANKING = 2 + }; + + /* \brief set learner + * \param learner the passed in learner + */ + inline void SetLearner(BoostLearner* learner){ + learner_ = learner; + } + + inline void SetParam(const char *name, const char *val){ + if (!strcmp("learning_task", name)) learning_task = atoi(val); + if (!strcmp("silent", name)) silent = atoi(val); + if (!strcmp("use_buffer", name)) use_buffer = atoi(val); + if (!strcmp("seed", name)) random::Seed(atoi(val)); + if (!strcmp("num_round", name)) num_round = atoi(val); + if (!strcmp("save_period", name)) save_period = atoi(val); + if (!strcmp("task", name)) task = val; + if (!strcmp("data", name)) train_path = val; + if (!strcmp("test:data", name)) test_path = val; + if (!strcmp("model_in", name)) model_in = val; + if (!strcmp("model_out", name)) model_out = val; + if (!strcmp("model_dir", name)) model_dir_path = val; + if (!strcmp("fmap", name)) name_fmap = val; + if (!strcmp("name_dump", name)) name_dump = val; + if (!strcmp("name_dumppath", name)) name_dumppath = val; + if (!strcmp("name_pred", name)) name_pred = val; + if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val); + if (!strcmp("interact:action", name)) interact_action = val; + if (!strncmp("batch:", name, 6)){ + cfg_batch.PushBack(name + 6, val); + } + if (!strncmp("eval[", name, 5)) { + char evname[256]; + utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display"); + eval_data_names.push_back(std::string(evname)); + eval_data_paths.push_back(std::string(val)); + } + cfg.PushBack(name, val); + } + public: + BoostTask(void){ + // default parameters + silent = 0; + use_buffer = 1; + num_round = 10; + save_period = 0; + dump_model_stats = 0; + task = "train"; + model_in = "NULL"; + model_out = "NULL"; + name_fmap = "NULL"; + name_pred = "pred.txt"; + name_dump = "dump.txt"; + name_dumppath = "dump.path.txt"; + model_dir_path = "./"; + interact_action = "update"; + } + ~BoostTask(void){ + for (size_t i = 0; i < deval.size(); i++){ + delete deval[i]; + } + } + private: + + + inline void InitData(void){ + if (name_fmap != "NULL") fmap.LoadText(name_fmap.c_str()); + if (task == "dump") return; + if (learning_task == RANKING){ + char instance_path[256], group_path[256]; + if (task == "pred" || task == "dumppath"){ + sscanf(test_path.c_str(), "%[^;];%s", instance_path, group_path); + data.CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + } + else{ + // training + sscanf(train_path.c_str(), "%[^;];%s", instance_path, group_path); + data.CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + utils::Assert(eval_data_names.size() == eval_data_paths.size()); + for (size_t i = 0; i < eval_data_names.size(); ++i){ + deval.push_back(new DMatrix()); + sscanf(eval_data_paths[i].c_str(), "%[^;];%s", instance_path, group_path); + deval.back()->CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + } + } + + + } + else{ + if (task == "pred" || task == "dumppath"){ + data.CacheLoad(test_path.c_str(), "", silent != 0, use_buffer != 0); + } + else{ + // training + data.CacheLoad(train_path.c_str(), "", silent != 0, use_buffer != 0); + utils::Assert(eval_data_names.size() == eval_data_paths.size()); + for (size_t i = 0; i < eval_data_names.size(); ++i){ + deval.push_back(new DMatrix()); + deval.back()->CacheLoad(eval_data_paths[i].c_str(), "", silent != 0, use_buffer != 0); + } + } + } + + learner_->SetData(&data, deval, eval_data_names); + } + inline void InitLearner(void){ + cfg.BeforeFirst(); + while (cfg.Next()){ + learner_->SetParam(cfg.name(), cfg.val()); + } + if (model_in != "NULL"){ + utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb")); + learner_->LoadModel(fi); + fi.Close(); + } + else{ + utils::Assert(task == "train", "model_in not specified"); + learner_->InitModel(); + } + learner_->InitTrainer(); + } + + inline void TaskTrain(void){ + const time_t start = time(NULL); + unsigned long elapsed = 0; + for (int i = 0; i < num_round; ++i){ + elapsed = (unsigned long)(time(NULL) - start); + if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); + learner_->UpdateOneIter(i); + learner_->EvalOneIter(i); + if (save_period != 0 && (i + 1) % save_period == 0){ + this->SaveModel(i); + } + elapsed = (unsigned long)(time(NULL) - start); + } + // always save final round + if (save_period == 0 || num_round % save_period != 0){ + if (model_out == "NULL"){ + this->SaveModel(num_round - 1); + } + else{ + this->SaveModel(model_out.c_str()); + } + } + if (!silent){ + printf("\nupdating end, %lu sec in all\n", elapsed); + } + } + inline void TaskEval(void){ + learner_->EvalOneIter(0); + } + inline void TaskInteractive(void){ + const time_t start = time(NULL); + unsigned long elapsed = 0; + int batch_action = 0; + + cfg_batch.BeforeFirst(); + while (cfg_batch.Next()){ + if (!strcmp(cfg_batch.name(), "run")){ + learner_->UpdateInteract(interact_action); + batch_action += 1; + } + else{ + learner_->SetParam(cfg_batch.name(), cfg_batch.val()); + } + } + + if (batch_action == 0){ + learner_->UpdateInteract(interact_action); + } + utils::Assert(model_out != "NULL", "interactive mode must specify model_out"); + this->SaveModel(model_out.c_str()); + elapsed = (unsigned long)(time(NULL) - start); + + if (!silent){ + printf("\ninteractive update, %d batch actions, %lu sec in all\n", batch_action, elapsed); + } + } + + inline void TaskDump(void){ + FILE *fo = utils::FopenCheck(name_dump.c_str(), "w"); + learner_->DumpModel(fo, fmap, dump_model_stats != 0); + fclose(fo); + } + inline void TaskDumpPath(void){ + FILE *fo = utils::FopenCheck(name_dumppath.c_str(), "w"); + learner_->DumpPath(fo, data); + fclose(fo); + } + inline void SaveModel(const char *fname) const{ + utils::FileStream fo(utils::FopenCheck(fname, "wb")); + learner_->SaveModel(fo); + fo.Close(); + } + inline void SaveModel(int i) const{ + char fname[256]; + sprintf(fname, "%s/%04d.model", model_dir_path.c_str(), i + 1); + this->SaveModel(fname); + } + inline void TaskPred(void){ + std::vector preds; + if (!silent) printf("start prediction...\n"); + learner_->Predict(preds, data); + if (!silent) printf("writing prediction to %s\n", name_pred.c_str()); + FILE *fo = utils::FopenCheck(name_pred.c_str(), "w"); + for (size_t i = 0; i < preds.size(); i++){ + fprintf(fo, "%f\n", preds[i]); + } + fclose(fo); + } + private: + /* \brief specify the learning task*/ + int learning_task; + /* \brief whether silent */ + int silent; + /* \brief whether use auto binary buffer */ + int use_buffer; + /* \brief number of boosting iterations */ + int num_round; + /* \brief the period to save the model, 0 means only save the final round model */ + int save_period; + /*! \brief interfact action */ + std::string interact_action; + /* \brief the path of training/test data set */ + std::string train_path, test_path; + /* \brief the path of test model file, or file to restart training */ + std::string model_in; + /* \brief the path of final model file, to be saved */ + std::string model_out; + /* \brief the path of directory containing the saved models */ + std::string model_dir_path; + /* \brief task to perform, choosing training or testing */ + std::string task; + /* \brief name of predict file */ + std::string name_pred; + /* \brief whether dump statistics along with model */ + int dump_model_stats; + /* \brief name of feature map */ + std::string name_fmap; + /* \brief name of dump file */ + std::string name_dump; + /* \brief name of dump path file */ + std::string name_dumppath; + /* \brief the paths of validation data sets */ + std::vector eval_data_paths; + /* \brief the names of the evaluation data used in output log */ + std::vector eval_data_names; + /*! \brief saves configurations */ + utils::ConfigSaver cfg; + /*! \brief batch configurations */ + utils::ConfigSaver cfg_batch; + private: + DMatrix data; + std::vector deval; + utils::FeatMap fmap; + BoostLearner* learner_; + + }; + }; +}; diff --git a/dev/base/xgboost_data_instance.h b/dev/base/xgboost_data_instance.h new file mode 100644 index 000000000..e33ca687a --- /dev/null +++ b/dev/base/xgboost_data_instance.h @@ -0,0 +1,191 @@ +#ifndef XGBOOST_DATA_INSTANCE_H +#define XGBOOST_DATA_INSTANCE_H + +#include +#include +#include "../booster/xgboost_data.h" +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_stream.h" + + +namespace xgboost{ + namespace base{ + /*! \brief data matrix for regression,classification,rank content */ + struct DMatrix{ + public: + /*! \brief maximum feature dimension */ + unsigned num_feature; + /*! \brief feature data content */ + booster::FMatrixS data; + /*! \brief label of each instance */ + std::vector labels; + /*! \brief the index of begin and end of a group, + * needed when the learning task is ranking*/ + std::vector group_index; + public: + /*! \brief default constructor */ + DMatrix(void){} + + /*! \brief get the number of instances */ + inline size_t Size() const{ + return labels.size(); + } + /*! + * \brief load from text file + * \param fname file of instances data + * \param fgroup file of the group data + * \param silent whether print information or not + */ + inline void LoadText(const char* fname, const char* fgroup, bool silent = false){ + data.Clear(); + FILE* file = utils::FopenCheck(fname, "r"); + float label; bool init = true; + char tmp[1024]; + std::vector findex; + std::vector fvalue; + + while (fscanf(file, "%s", tmp) == 1){ + unsigned index; float value; + if (sscanf(tmp, "%u:%f", &index, &value) == 2){ + findex.push_back(index); fvalue.push_back(value); + } + else{ + if (!init){ + labels.push_back(label); + data.AddRow(findex, fvalue); + } + findex.clear(); fvalue.clear(); + utils::Assert(sscanf(tmp, "%f", &label) == 1, "invalid format"); + init = false; + } + } + + labels.push_back(label); + data.AddRow(findex, fvalue); + // initialize column support as well + data.InitData(); + + if (!silent){ + printf("%ux%u matrix with %lu entries is loaded from %s\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + fclose(file); + + //if exists group data load it in + FILE *file_group = fopen64(fgroup, "r"); + if (file_group != NULL){ + group_index.push_back(0); + int tmp = 0, acc = 0; + while (fscanf(file_group, "%d", tmp) == 1){ + acc += tmp; + group_index.push_back(acc); + } + } + } + /*! + * \brief load from binary file + * \param fname name of binary data + * \param silent whether print information or not + * \return whether loading is success + */ + inline bool LoadBinary(const char* fname, const char* fgroup, bool silent = false){ + FILE *fp = fopen64(fname, "rb"); + if (fp == NULL) return false; + utils::FileStream fs(fp); + data.LoadBinary(fs); + labels.resize(data.NumRow()); + utils::Assert(fs.Read(&labels[0], sizeof(float)* data.NumRow()) != 0, "DMatrix LoadBinary"); + fs.Close(); + // initialize column support as well + data.InitData(); + + if (!silent){ + printf("%ux%u matrix with %lu entries is loaded from %s\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + + //if group data exists load it in + FILE *file_group = fopen64(fgroup, "r"); + if (file_group != NULL){ + int group_index_size = 0; + utils::FileStream group_stream(file_group); + utils::Assert(group_stream.Read(&group_index_size, sizeof(int)) != 0, "Load group indice size"); + group_index.resize(group_index_size); + utils::Assert(group_stream.Read(&group_index, sizeof(int)* group_index_size) != 0, "Load group indice"); + + if (!silent){ + printf("the group index of %d groups is loaded from %s\n", + group_index_size - 1, fgroup); + } + } + return true; + } + /*! + * \brief save to binary file + * \param fname name of binary data + * \param silent whether print information or not + */ + inline void SaveBinary(const char* fname, const char* fgroup, bool silent = false){ + // initialize column support as well + data.InitData(); + + utils::FileStream fs(utils::FopenCheck(fname, "wb")); + data.SaveBinary(fs); + fs.Write(&labels[0], sizeof(float)* data.NumRow()); + fs.Close(); + if (!silent){ + printf("%ux%u matrix with %lu entries is saved to %s\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + + //save group data + if (group_index.size() > 0){ + utils::FileStream file_group(utils::FopenCheck(fgroup, "wb")); + int group_index_size = group_index.size(); + file_group.Write(&(group_index_size), sizeof(int)); + file_group.Write(&group_index[0], sizeof(int) * group_index_size); + } + + } + /*! + * \brief cache load data given a file name, if filename ends with .buffer, direct load binary + * otherwise the function will first check if fname + '.buffer' exists, + * if binary buffer exists, it will reads from binary buffer, otherwise, it will load from text file, + * and try to create a buffer file + * \param fname name of binary data + * \param silent whether print information or not + * \param savebuffer whether do save binary buffer if it is text + */ + inline void CacheLoad(const char *fname, const char *fgroup, bool silent = false, bool savebuffer = true){ + int len = strlen(fname); + if (len > 8 && !strcmp(fname + len - 7, ".buffer")){ + this->LoadBinary(fname, fgroup, silent); return; + } + char bname[1024]; + sprintf(bname, "%s.buffer", fname); + if (!this->LoadBinary(bname, fgroup, silent)){ + this->LoadText(fname, fgroup, silent); + if (savebuffer) this->SaveBinary(bname, fgroup, silent); + } + } + private: + /*! \brief update num_feature info */ + inline void UpdateInfo(void){ + this->num_feature = 0; + for (size_t i = 0; i < data.NumRow(); i++){ + booster::FMatrixS::Line sp = data[i]; + for (unsigned j = 0; j < sp.len; j++){ + if (num_feature <= sp[j].findex){ + num_feature = sp[j].findex + 1; + } + } + } + } + }; + + + + } +}; + +#endif \ No newline at end of file diff --git a/dev/base/xgboost_learner.h b/dev/base/xgboost_learner.h new file mode 100644 index 000000000..0b02030f8 --- /dev/null +++ b/dev/base/xgboost_learner.h @@ -0,0 +1,275 @@ +#ifndef XGBOOST_LEARNER_H +#define XGBOOST_LEARNER_H +/*! +* \file xgboost_learner.h +* \brief class for gradient boosting learner +* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com +*/ +#include +#include +#include +#include "xgboost_data_instance.h" +#include "../utils/xgboost_omp.h" +#include "../booster/xgboost_gbmbase.h" +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_stream.h" + +namespace xgboost { + namespace base { + /*! \brief class for gradient boosting learner */ + class BoostLearner { + public: + /*! \brief constructor */ + BoostLearner(void) { + silent = 0; + } + /*! + * \brief booster associated with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + BoostLearner(const DMatrix *train, + const std::vector &evals, + const std::vector &evname) { + silent = 0; + this->SetData(train, evals, evname); + } + + /*! + * \brief associate booster with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + inline void SetData(const DMatrix *train, + const std::vector &evals, + const std::vector &evname) { + this->train_ = train; + this->evals_ = evals; + this->evname_ = evname; + // estimate feature bound + int num_feature = (int)(train->data.NumCol()); + // assign buffer index + unsigned buffer_size = static_cast(train->Size()); + + for (size_t i = 0; i < evals.size(); ++i) { + buffer_size += static_cast(evals[i]->Size()); + num_feature = std::max(num_feature, (int)(evals[i]->data.NumCol())); + } + + char str_temp[25]; + if (num_feature > mparam.num_feature) { + mparam.num_feature = num_feature; + sprintf(str_temp, "%d", num_feature); + base_gbm.SetParam("bst:num_feature", str_temp); + } + + sprintf(str_temp, "%u", buffer_size); + base_gbm.SetParam("num_pbuffer", str_temp); + if (!silent) { + printf("buffer_size=%u\n", buffer_size); + } + + // set eval_preds tmp sapce + this->eval_preds_.resize(evals.size(), std::vector()); + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + virtual inline void SetParam(const char *name, const char *val) { + if (!strcmp(name, "silent")) silent = atoi(val); + mparam.SetParam(name, val); + base_gbm.SetParam(name, val); + } + /*! + * \brief initialize solver before training, called before training + * this function is reserved for solver to allocate necessary space and do other preparation + */ + inline void InitTrainer(void) { + base_gbm.InitTrainer(); + } + /*! + * \brief initialize the current data storage for model, if the model is used first time, call this function + */ + inline void InitModel(void) { + base_gbm.InitModel(); + } + /*! + * \brief load model from stream + * \param fi input stream + */ + inline void LoadModel(utils::IStream &fi) { + base_gbm.LoadModel(fi); + utils::Assert(fi.Read(&mparam, sizeof(ModelParam)) != 0); + } + /*! + * \brief DumpModel + * \param fo text file + * \param fmap feature map that may help give interpretations of feature + * \param with_stats whether print statistics as well + */ + inline void DumpModel(FILE *fo, const utils::FeatMap& fmap, bool with_stats) { + base_gbm.DumpModel(fo, fmap, with_stats); + } + /*! + * \brief Dump path of all trees + * \param fo text file + * \param data input data + */ + inline void DumpPath(FILE *fo, const DMatrix &data) { + base_gbm.DumpPath(fo, data.data); + } + + /*! + * \brief save model to stream + * \param fo output stream + */ + inline void SaveModel(utils::IStream &fo) const { + base_gbm.SaveModel(fo); + fo.Write(&mparam, sizeof(ModelParam)); + } + + virtual void EvalOneIter(int iter, FILE *fo = stderr) {} + + /*! + * \brief update the model for one iteration + * \param iteration iteration number + */ + inline void UpdateOneIter(int iter) { + this->PredictBuffer(preds_, *train_, 0); + this->GetGradient(preds_, train_->labels, train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + } + + /*! \brief get intransformed prediction, without buffering */ + inline void Predict(std::vector &preds, const DMatrix &data) { + preds.resize(data.Size()); + + const unsigned ndata = static_cast(data.Size()); +#pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.Predict(data.data, j, -1); + } + } + + public: + /*! + * \brief update the model for one iteration + * \param iteration iteration number + */ + virtual inline void UpdateInteract(std::string action){ + this->InteractPredict(preds_, *train_, 0); + + int buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i) { + std::vector &preds = this->eval_preds_[i]; + this->InteractPredict(preds, *evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + + if (action == "remove") { + base_gbm.DelteBooster(); + return; + } + + this->GetGradient(preds_, train_->labels, train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + + this->InteractRePredict(*train_, 0); + buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i) { + this->InteractRePredict(*evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + }; + + protected: + /*! \brief get the intransformed predictions, given data */ + inline void InteractPredict(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { + preds.resize(data.Size()); + const unsigned ndata = static_cast(data.Size()); +#pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.InteractPredict(data.data, j, buffer_offset + j); + } + } + /*! \brief repredict trial */ + inline void InteractRePredict(const xgboost::base::DMatrix &data, unsigned buffer_offset) { + const unsigned ndata = static_cast(data.Size()); +#pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + base_gbm.InteractRePredict(data.data, j, buffer_offset + j); + } + } + + /*! \brief get intransformed predictions, given data */ + virtual inline void PredictBuffer(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { + preds.resize(data.Size()); + + const unsigned ndata = static_cast(data.Size()); +#pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.Predict(data.data, j, buffer_offset + j); + } + } + + /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ + virtual inline void GetGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess) {}; + + + protected: + + /*! \brief training parameter for regression */ + struct ModelParam { + /* \brief type of loss function */ + int loss_type; + /* \brief number of features */ + int num_feature; + /*! \brief reserved field */ + int reserved[16]; + /*! \brief constructor */ + ModelParam(void) { + loss_type = 0; + num_feature = 0; + memset(reserved, 0, sizeof(reserved)); + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + inline void SetParam(const char *name, const char *val) { + if (!strcmp("loss_type", name)) loss_type = atoi(val); + if (!strcmp("bst:num_feature", name)) num_feature = atoi(val); + } + + }; + + int silent; + booster::GBMBase base_gbm; + ModelParam mparam; + const DMatrix *train_; + std::vector evals_; + std::vector evname_; + std::vector buffer_index_; + std::vector grad_, hess_, preds_; + std::vector< std::vector > eval_preds_; + }; + } +}; + +#endif + + + + + diff --git a/dev/rank/xgboost_rank.h b/dev/rank/xgboost_rank.h index 3f0986e25..82320f2a9 100644 --- a/dev/rank/xgboost_rank.h +++ b/dev/rank/xgboost_rank.h @@ -9,347 +9,142 @@ #include #include #include "xgboost_sample.h" -#include "xgboost_rank_data.h" #include "xgboost_rank_eval.h" +#include "../base/xgboost_data_instance.h" #include "../utils/xgboost_omp.h" #include "../booster/xgboost_gbmbase.h" #include "../utils/xgboost_utils.h" #include "../utils/xgboost_stream.h" +#include "../base/xgboost_learner.h" namespace xgboost { -namespace rank { -/*! \brief class for gradient boosted regression */ -class RankBoostLearner { -public: - /*! \brief constructor */ - RegBoostLearner( void ) { - silent = 0; - } - /*! - * \brief a rank booster associated with training and evaluating data - * \param train pointer to the training data - * \param evals array of evaluating data - * \param evname name of evaluation data, used print statistics - */ - RankBoostLearner( const RMatrix *train, - const std::vector &evals, - const std::vector &evname ) { - silent = 0; - this->SetData(train,evals,evname); - } + namespace rank { + /*! \brief class for gradient boosted regression */ + class RankBoostLearner :public base::BoostLearner{ + public: + /*! \brief constructor */ + RankBoostLearner(void) { + BoostLearner(); + } + /*! + * \brief a rank booster associated with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + RankBoostLearner(const base::DMatrix *train, + const std::vector &evals, + const std::vector &evname) { - /*! - * \brief associate rank booster with training and evaluating data - * \param train pointer to the training data - * \param evals array of evaluating data - * \param evname name of evaluation data, used print statistics - */ - inline void SetData( const RMatrix *train, - const std::vector &evals, - const std::vector &evname ) { - this->train_ = train; - this->evals_ = evals; - this->evname_ = evname; - // estimate feature bound - int num_feature = (int)(train->data.NumCol()); - // assign buffer index - unsigned buffer_size = static_cast( train->Size() ); + BoostLearner(train, evals, evname); + } - for( size_t i = 0; i < evals.size(); ++ i ) { - buffer_size += static_cast( evals[i]->Size() ); - num_feature = std::max( num_feature, (int)(evals[i]->data.NumCol()) ); - } + /*! + * \brief initialize solver before training, called before training + * this function is reserved for solver to allocate necessary space + * and do other preparation + */ + inline void InitTrainer(void) { + BoostLearner::InitTrainer(); + if (mparam.loss_type == PAIRWISE) { + evaluator_.AddEval("PAIR"); + } + else if (mparam.loss_type == MAP) { + evaluator_.AddEval("MAP"); + } + else { + evaluator_.AddEval("NDCG"); + } + evaluator_.Init(); + } - char str_temp[25]; - if( num_feature > mparam.num_feature ) { - mparam.num_feature = num_feature; - sprintf( str_temp, "%d", num_feature ); - base_gbm.SetParam( "bst:num_feature", str_temp ); - } + void EvalOneIter(int iter, FILE *fo = stderr) { + fprintf(fo, "[%d]", iter); + int buffer_offset = static_cast(train_->Size()); - sprintf( str_temp, "%u", buffer_size ); - base_gbm.SetParam( "num_pbuffer", str_temp ); - if( !silent ) { - printf( "buffer_size=%u\n", buffer_size ); - } + for (size_t i = 0; i < evals_.size(); ++i) { + std::vector &preds = this->eval_preds_[i]; + this->PredictBuffer(preds, *evals_[i], buffer_offset); + evaluator_.Eval(fo, evname_[i].c_str(), preds, (*evals_[i]).labels, (*evals_[i]).group_index); + buffer_offset += static_cast(evals_[i]->Size()); + } + fprintf(fo, "\n"); + } - // set eval_preds tmp sapce - this->eval_preds_.resize( evals.size(), std::vector() ); - } - /*! - * \brief set parameters from outside - * \param name name of the parameter - * \param val value of the parameter - */ - inline void SetParam( const char *name, const char *val ) { - if( !strcmp( name, "silent") ) silent = atoi( val ); - if( !strcmp( name, "eval_metric") ) evaluator_.AddEval( val ); - mparam.SetParam( name, val ); - base_gbm.SetParam( name, val ); - } - /*! - * \brief initialize solver before training, called before training - * this function is reserved for solver to allocate necessary space and do other preparation - */ - inline void InitTrainer( void ) { - base_gbm.InitTrainer(); - if( mparam.loss_type == PAIRWISE) { - evaluator_.AddEval( "PAIR" ); - } else if( mparam.loss_type == MAP) { - evaluator_.AddEval( "MAP" ); - } else { - evaluator_.AddEval( "NDCG" ); - } - evaluator_.Init(); - sampler.AssignSampler(mparam.sampler_type); - } - /*! - * \brief initialize the current data storage for model, if the model is used first time, call this function - */ - inline void InitModel( void ) { - base_gbm.InitModel(); - } - /*! - * \brief load model from stream - * \param fi input stream - */ - inline void LoadModel( utils::IStream &fi ) { - base_gbm.LoadModel( fi ); - utils::Assert( fi.Read( &mparam, sizeof(ModelParam) ) != 0 ); - } - /*! - * \brief DumpModel - * \param fo text file - * \param fmap feature map that may help give interpretations of feature - * \param with_stats whether print statistics as well - */ - inline void DumpModel( FILE *fo, const utils::FeatMap& fmap, bool with_stats ) { - base_gbm.DumpModel( fo, fmap, with_stats ); - } - /*! - * \brief Dump path of all trees - * \param fo text file - * \param data input data - */ - inline void DumpPath( FILE *fo, const RMatrix &data ) { - base_gbm.DumpPath( fo, data.data ); - } - - /*! - * \brief save model to stream - * \param fo output stream - */ - inline void SaveModel( utils::IStream &fo ) const { - base_gbm.SaveModel( fo ); - fo.Write( &mparam, sizeof(ModelParam) ); - } - - /*! - * \brief update the model for one iteration - * \param iteration iteration number - */ - inline void UpdateOneIter( int iter ) { - this->PredictBuffer( preds_, *train_, 0 ); - this->GetGradient( preds_, train_->labels,train_->group_index, grad_, hess_ ); - std::vector root_index; - base_gbm.DoBoost( grad_, hess_, train_->data, root_index ); - } - /*! - * \brief evaluate the model for specific iteration - * \param iter iteration number - * \param fo file to output log - */ - inline void EvalOneIter( int iter, FILE *fo = stderr ) { - fprintf( fo, "[%d]", iter ); - int buffer_offset = static_cast( train_->Size() ); + inline void SetParam(const char *name, const char *val){ + if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val); + if (!strcmp(name, "rank:sampler")) sampler.AssignSampler(atoi(val)); + } + /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ + inline void GetGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess) { + grad.resize(preds.size()); + hess.resize(preds.size()); + bool j_better; + float pred_diff, pred_diff_exp, first_order_gradient, second_order_gradient; + for (int i = 0; i < group_index.size() - 1; i++){ + sample::Pairs pairs = sampler.GenPairs(preds, labels, group_index[i], group_index[i + 1]); + for (int j = group_index[i]; j < group_index[i + 1]; j++){ + std::vector pair_instance = pairs.GetPairs(j); + for (int k = 0; k < pair_instance.size(); k++){ + j_better = labels[j] > labels[pair_instance[k]]; + if (j_better){ + pred_diff = preds[preds[j] - pair_instance[k]]; + pred_diff_exp = j_better ? expf(-pred_diff) : expf(pred_diff); + first_order_gradient = FirstOrderGradient(pred_diff_exp); + second_order_gradient = 2 * SecondOrderGradient(pred_diff_exp); + hess[j] += second_order_gradient; + grad[j] += first_order_gradient; + hess[pair_instance[k]] += second_order_gradient; + grad[pair_instance[k]] += -first_order_gradient; + } + } + } + } + } - for( size_t i = 0; i < evals_.size(); ++i ) { - std::vector &preds = this->eval_preds_[ i ]; - this->PredictBuffer( preds, *evals_[i], buffer_offset); - evaluator_.Eval( fo, evname_[i].c_str(), preds, (*evals_[i]).labels ); - buffer_offset += static_cast( evals_[i]->Size() ); - } - fprintf( fo,"\n" ); - } - - /*! \brief get intransformed prediction, without buffering */ - inline void Predict( std::vector &preds, const DMatrix &data ) { - preds.resize( data.Size() ); - - const unsigned ndata = static_cast( data.Size() ); - #pragma omp parallel for schedule( static ) - for( unsigned j = 0; j < ndata; ++ j ) { - preds[j] = base_gbm.Predict( data.data, j, -1 ); - } - } - -public: - /*! - * \brief update the model for one iteration - * \param iteration iteration number - */ - inline void UpdateInteract( std::string action ) { - this->InteractPredict( preds_, *train_, 0 ); - - int buffer_offset = static_cast( train_->Size() ); - for( size_t i = 0; i < evals_.size(); ++i ) { - std::vector &preds = this->eval_preds_[ i ]; - this->InteractPredict( preds, *evals_[i], buffer_offset ); - buffer_offset += static_cast( evals_[i]->Size() ); - } - - if( action == "remove" ) { - base_gbm.DelteBooster(); - return; - } - - this->GetGradient( preds_, train_->labels, grad_, hess_ ); - std::vector root_index; - base_gbm.DoBoost( grad_, hess_, train_->data, root_index ); - - this->InteractRePredict( *train_, 0 ); - buffer_offset = static_cast( train_->Size() ); - for( size_t i = 0; i < evals_.size(); ++i ) { - this->InteractRePredict( *evals_[i], buffer_offset ); - buffer_offset += static_cast( evals_[i]->Size() ); - } - } -private: - /*! \brief get the transformed predictions, given data */ - inline void InteractPredict( std::vector &preds, const DMatrix &data, unsigned buffer_offset ) { - preds.resize( data.Size() ); - const unsigned ndata = static_cast( data.Size() ); - #pragma omp parallel for schedule( static ) - for( unsigned j = 0; j < ndata; ++ j ) { - preds[j] = base_gbm.InteractPredict( data.data, j, buffer_offset + j ); - } - } - /*! \brief repredict trial */ - inline void InteractRePredict( const DMatrix &data, unsigned buffer_offset ) { - const unsigned ndata = static_cast( data.Size() ); - #pragma omp parallel for schedule( static ) - for( unsigned j = 0; j < ndata; ++ j ) { - base_gbm.InteractRePredict( data.data, j, buffer_offset + j ); - } - } -private: - /*! \brief get intransformed predictions, given data */ - inline void PredictBuffer( std::vector &preds, const RMatrix &data, unsigned buffer_offset ) { - preds.resize( data.Size() ); - - const unsigned ndata = static_cast( data.Size() ); - #pragma omp parallel for schedule( static ) - for( unsigned j = 0; j < ndata; ++ j ) { - preds[j] = base_gbm.Predict( data.data, j, buffer_offset + j ); - - } - } - - /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ - inline void GetGradient( const std::vector &preds, - const std::vector &labels, - const std::vector &group_index, - std::vector &grad, - std::vector &hess ) { - grad.resize( preds.size() ); - hess.resize( preds.size() ); - bool j_better; - float pred_diff,pred_diff_exp,first_order_gradient,second_order_gradient; - for(int i = 0; i < group_index.size() - 1; i++){ - - sample::Pairs pairs = sampler.GenPairs(preds,labels,group_index[i],group_index[i+1]); - for(int j = group_index[i]; j < group_index[i + 1]; j++){ - std::vector pair_instance = pairs.GetPairs(j); - for(int k = 0; k < pair_instance.size(); k++){ - j_better = labels[j] > labels[pair_instance[k]]; - if(j_better){ - pred_diff = preds[preds[j] - pair_instance[k]]; - pred_diff_exp = j_better? expf(-pred_diff):expf(pred_diff); - first_order_gradient = mparam.FirstOrderGradient(pred_diff_exp); - second_order_gradient = 2 * mparam.SecondOrderGradient(pred_diff_exp); - hess[j] += second_order_gradient; - grad[j] += first_order_gradient; - hess[pair_instance[k]] += second_order_gradient; - grad[pair_instance[k]] += -first_order_gradient; - } - } - } - } - - } - -private: - enum LossType { - PAIRWISE = 0, - MAP = 1, - NDCG = 2 - }; - - /*! \brief training parameter for regression */ - struct ModelParam { - /* \brief type of loss function */ - int loss_type; - /* \brief number of features */ - int num_feature; - /*! \brief reserved field */ - int reserved[ 16 ]; - /*! \brief sampler type */ - int sampler_type; - /*! \brief constructor */ - ModelParam( void ) { - loss_type = 0; - num_feature = 0; - memset( reserved, 0, sizeof( reserved ) ); - } - /*! - * \brief set parameters from outside - * \param name name of the parameter - * \param val value of the parameter - */ - inline void SetParam( const char *name, const char *val ) { - if( !strcmp("loss_type", name ) ) loss_type = atoi( val ); - if( !strcmp("bst:num_feature", name ) ) num_feature = atoi( val ); - if( !strcmp("rank:sampler",name)) sampler = atoi( val ); - } + inline void UpdateInteract(std::string action) { + + } + private: + enum LossType { + PAIRWISE = 0, + MAP = 1, + NDCG = 2 + }; - /*! - * \brief calculate first order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), - * given the exponential of the difference of intransformed pair predictions - * \param the intransformed prediction of positive instance - * \param the intransformed prediction of negative instance - * \return first order gradient - */ - inline float FirstOrderGradient( float pred_diff_exp) const { - return -pred_diff_exp/(1 + pred_diff_exp); - } - - /*! - * \brief calculate second order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), - * given the exponential of the difference of intransformed pair predictions - * \param the intransformed prediction of positive instance - * \param the intransformed prediction of negative instance - * \return second order gradient - */ - inline float SecondOrderGradient( float pred_diff_exp ) const { - return pred_diff_exp/pow(1 + pred_diff_exp,2); - } - }; -private: - int silent; - RankEvalSet evaluator_; - sample::PairSamplerWrapper sampler; - booster::GBMBase base_gbm; - ModelParam mparam; - const RMatrix *train_; - std::vector evals_; - std::vector evname_; - std::vector buffer_index_; -private: - std::vector grad_, hess_, preds_; - std::vector< std::vector > eval_preds_; -}; -} + + /*! + * \brief calculate first order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return first order gradient + */ + inline float FirstOrderGradient(float pred_diff_exp) const { + return -pred_diff_exp / (1 + pred_diff_exp); + } + + /*! + * \brief calculate second order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return second order gradient + */ + inline float SecondOrderGradient(float pred_diff_exp) const { + return pred_diff_exp / pow(1 + pred_diff_exp, 2); + } + + private: + RankEvalSet evaluator_; + sample::PairSamplerWrapper sampler; + }; + }; }; #endif diff --git a/dev/rank/xgboost_rank_data.h b/dev/rank/xgboost_rank_data.h deleted file mode 100644 index 40eeb2ac7..000000000 --- a/dev/rank/xgboost_rank_data.h +++ /dev/null @@ -1,179 +0,0 @@ -#ifndef XGBOOST_RANK_DATA_H -#define XGBOOST_RANK_DATA_H - -/*! - * \file xgboost_rank_data.h - * \brief input data structure for rank task. - * Format: - * The data should contains groups of rank data, a group here may refer to - * the rank list of a query, or the browsing history of a user, etc. - * Each group first contains the size of the group in a single line, - * then following is the line data with the same format with the regression data: - * label [feature index:feature value]+ - * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com - */ -#include -#include -#include "../booster/xgboost_data.h" -#include "../utils/xgboost_utils.h" -#include "../utils/xgboost_stream.h" - -namespace xgboost { -namespace rank { -/*! \brief data matrix for regression content */ -struct RMatrix { -public: - /*! \brief maximum feature dimension */ - unsigned num_feature; - /*! \brief feature data content */ - booster::FMatrixS data; - /*! \brief label of each instance */ - std::vector labels; - /*! \brief The index of begin and end of each group */ - std::vector group_index; -public: - /*! \brief default constructor */ - RMatrix( void ) {} - - /*! \brief get the number of instances */ - inline size_t Size() const { - return labels.size(); - } - - /*! - * \brief load from text file - * \param fname name of text data - * \param silent whether print information or not - */ - inline void LoadText( const char* fname, bool silent = false ) { - data.Clear(); - FILE* file = utils::FopenCheck( fname, "r" ); - float label; - bool init = true; - char tmp[ 1024 ]; - int group_size,group_size_acc = 0; - std::vector findex; - std::vector fvalue; - group_index.push_back(0); - while(fscanf(file, "%d",group_size) == 1) { - group_size_acc += group_size; - group_index.push_back(group_size_acc); - unsigned index; - float value; - if( sscanf( tmp, "%u:%f", &index, &value ) == 2 ) { - findex.push_back( index ); - fvalue.push_back( value ); - } else { - if( !init ) { - labels.push_back( label ); - data.AddRow( findex, fvalue ); - } - findex.clear(); - fvalue.clear(); - utils::Assert( sscanf( tmp, "%f", &label ) == 1, "invalid format" ); - init = false; - } - } - - - labels.push_back( label ); - data.AddRow( findex, fvalue ); - // initialize column support as well - data.InitData(); - - if( !silent ) { - printf("%ux%u matrix with %lu entries is loaded from %s\n", - (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname ); - } - fclose(file); - } - /*! - * \brief load from binary file - * \param fname name of binary data - * \param silent whether print information or not - * \return whether loading is success - */ - inline bool LoadBinary( const char* fname, bool silent = false ) { - FILE *fp = fopen64( fname, "rb" ); - int group_index_size = 0; - if( fp == NULL ) return false; - utils::FileStream fs( fp ); - data.LoadBinary( fs ); - labels.resize( data.NumRow() ); - utils::Assert( fs.Read( &labels[0], sizeof(float) * data.NumRow() ) != 0, "DMatrix LoadBinary" ); - - utils::Assert( fs.Read( &group_index_size, sizeof(int) ) != 0, "Load group indice size" ); - group_index.resize(group_index_size); - utils::Assert( fs.Read( &group_index, sizeof(int) * group_index_size) != 0, "Load group indice" d); - - fs.Close(); - // initialize column support as well - data.InitData(); - - if( !silent ) { - printf("%ux%u matrix with %lu entries is loaded from %s\n", - (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname ); - } - return true; - } - /*! - * \brief save to binary file - * \param fname name of binary data - * \param silent whether print information or not - */ - inline void SaveBinary( const char* fname, bool silent = false ) { - // initialize column support as well - data.InitData(); - - utils::FileStream fs( utils::FopenCheck( fname, "wb" ) ); - data.SaveBinary( fs ); - fs.Write( &labels[0], sizeof(float) * data.NumRow() ); - - fs.Write( &(group_index.size()), sizeof(int)); - fs.Write( &group_index[0], sizeof(int) * group_index.size() ); - - fs.Close(); - if( !silent ) { - printf("%ux%u matrix with %lu entries is saved to %s\n", - (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname ); - } - } - /*! - * \brief cache load data given a file name, if filename ends with .buffer, direct load binary - * otherwise the function will first check if fname + '.buffer' exists, - * if binary buffer exists, it will reads from binary buffer, otherwise, it will load from text file, - * and try to create a buffer file - * \param fname name of binary data - * \param silent whether print information or not - * \param savebuffer whether do save binary buffer if it is text - */ - inline void CacheLoad( const char *fname, bool silent = false, bool savebuffer = true ) { - int len = strlen( fname ); - if( len > 8 && !strcmp( fname + len - 7, ".buffer") ) { - this->LoadBinary( fname, silent ); - return; - } - char bname[ 1024 ]; - sprintf( bname, "%s.buffer", fname ); - if( !this->LoadBinary( bname, silent ) ) { - this->LoadText( fname, silent ); - if( savebuffer ) this->SaveBinary( bname, silent ); - } - } -private: - /*! \brief update num_feature info */ - inline void UpdateInfo( void ) { - this->num_feature = 0; - for( size_t i = 0; i < data.NumRow(); i ++ ) { - booster::FMatrixS::Line sp = data[i]; - for( unsigned j = 0; j < sp.len; j ++ ) { - if( num_feature <= sp[j].findex ) { - num_feature = sp[j].findex + 1; - } - } - } - } -}; -}; -}; -#endif diff --git a/dev/rank/xgboost_rank_eval.h b/dev/rank/xgboost_rank_eval.h index ab36cbafd..73a7664ca 100644 --- a/dev/rank/xgboost_rank_eval.h +++ b/dev/rank/xgboost_rank_eval.h @@ -13,157 +13,170 @@ #include "../utils/xgboost_omp.h" namespace xgboost { -namespace rank { -/*! \brief evaluator that evaluates the loss metrics */ -struct IRankEvaluator { - /*! - * \brief evaluate a specific metric - * \param preds prediction - * \param labels label - */ - virtual float Eval( const std::vector &preds, - const std::vector &labels, - const std::vector &group_index) const= 0; - /*! \return name of metric */ - virtual const char *Name( void ) const= 0; -}; + namespace rank { + /*! \brief evaluator that evaluates the loss metrics */ + class IRankEvaluator { + public: + /*! + * \brief evaluate a specific metric + * \param preds prediction + * \param labels label + */ + virtual float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const = 0; + /*! \return name of metric */ + virtual const char *Name(void) const = 0; + }; -struct Pair{ - float key_; - float value_; - - Pair(float key,float value){ - key_ = key; - value_ = value_; - } -}; + class Pair{ + public: + float key_; + float value_; -bool PairKeyComparer(const Pair &a, const Pair &b){ - return a.key_ < b.key_; -} + Pair(float key, float value){ + key_ = key; + value_ = value_; + } + }; -bool PairValueComparer(const Pair &a, const Pair &b){ - return a.value_ < b.value_; -} + bool PairKeyComparer(const Pair &a, const Pair &b){ + return a.key_ < b.key_; + } -struct EvalPair : public IRankEvaluator{ - virtual float Eval( const std::vector &preds, - const std::vector &labels, - const std::vector &group_index ) const { - return 0; - } -}; - -/*! \brief Mean Average Precision */ -struct EvalMAP : public IRankEvaluator { - virtual float Eval( const std::vector &preds, - const std::vector &labels, - const std::vector &group_index ) const { - float acc = 0; - std::vector pairs_sort; - for(int i = 0; i < group_index.size() - 1; i++){ - for(int j = group_index[i]; j < group_index[i+1];j++){ - Pair pair(preds[j],labels[j]); - pairs_sort.push_back(pair); - } - acc += average_precision(pairs_sort); - } - return acc / (group_index.size() - 1); - } - - float float average_precision(std::vector pairs_sort){ - std::sort(pairs_sort.begin(),pairs_sort.end(),PairKeyComparer); - float hits = 0; - float average_precision = 0; - for(int j = 0; j < pairs_sort.size(); j++){ - if(pairs_sort[j].value_ == 1){ - hits++; - average_precision += hits/(j+1); - } - } - if(hits != 0) average_precision /= hits; - return average_precision; - } - - virtual const char *Name( void ) const { - return "MAP"; - } - -}; + bool PairValueComparer(const Pair &a, const Pair &b){ + return a.value_ < b.value_; + } -/*! \brief Normalized DCG */ -struct EvalNDCG : public IRankEvaluator { - virtual float Eval( const std::vector &preds, - const std::vector &labels, - const std::vector &group_index ) const { - float acc = 0; - std::vector pairs_sort; - for(int i = 0; i < group_index.size() - 1; i++){ - for(int j = group_index[i]; j < group_index[i+1];j++){ - Pair pair(preds[j],labels[j]); - pairs_sort.push_back(pair); - } - acc += NDCG(pairs_sort); - } - } - - float NDCG(std::vector pairs_sort){ - std::sort(pairs_sort.begin(),pairs_sort.end(),PairKeyComparer); - float DCG = DCG(pairs_sort); - std::sort(pairs_sort.begin(),pairs_sort.end(),PairValueComparer); - float IDCG = DCG(pairs_sort); - if(IDCG == 0) return 0; - return DCG/IDCG; - } - - float DCG(std::vector pairs_sort){ - float ans = 0.0; - ans += pairs_sort[0].value_; - for(int i = 1; i < pairs_sort.size(); i++){ - ans += pairs_sort[i].value_/log(i + 1); - } - return ans; - } - - virtual const char *Name( void ) const { - return "NDCG"; - } -}; + /*! \brief Mean Average Precision */ + class EvalMAP : public IRankEvaluator { + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + float acc = 0; + std::vector pairs_sort; + for (int i = 0; i < group_index.size() - 1; i++){ + for (int j = group_index[i]; j < group_index[i + 1]; j++){ + Pair pair(preds[j], labels[j]); + pairs_sort.push_back(pair); + } + acc += average_precision(pairs_sort); + } + return acc / (group_index.size() - 1); + } -}; -namespace rank { -/*! \brief a set of evaluators */ -struct RankEvalSet { -public: - inline void AddEval( const char *name ) { - if( !strcmp( name, "PAIR" )) evals_.push_back( &pair_); - if( !strcmp( name, "MAP") ) evals_.push_back( &map_ ); - if( !strcmp( name, "NDCG") ) evals_.push_back( &ndcg_ ); - } - - inline void Init( void ) { - std::sort( evals_.begin(), evals_.end() ); - evals_.resize( std::unique( evals_.begin(), evals_.end() ) - evals_.begin() ); - } - - inline void Eval( FILE *fo, const char *evname, - const std::vector &preds, - const std::vector &labels, - const std::vector &group_index ) const { - for( size_t i = 0; i < evals_.size(); ++ i ) { - float res = evals_[i]->Eval( preds, labels,group_index ); - fprintf( fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res ); - } - } - -private: - EvalPair pair_; - EvalMAP map_; - EvalNDCG ndcg_; - std::vector evals_; -}; -}; + virtual const char *Name(void) const { + return "MAP"; + } + + float average_precision(std::vector pairs_sort) const{ + + std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); + float hits = 0; + float average_precision = 0; + for (int j = 0; j < pairs_sort.size(); j++){ + if (pairs_sort[j].value_ == 1){ + hits++; + average_precision += hits / (j + 1); + } + } + if (hits != 0) average_precision /= hits; + return average_precision; + } + }; + + + class EvalPair : public IRankEvaluator{ + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + return 0; + } + + const char *Name(void) const { + return "PAIR"; + } + }; + + /*! \brief Normalized DCG */ + class EvalNDCG : public IRankEvaluator { + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + if (group_index.size() <= 1) return 0; + float acc = 0; + std::vector pairs_sort; + for (int i = 0; i < group_index.size() - 1; i++){ + for (int j = group_index[i]; j < group_index[i + 1]; j++){ + Pair pair(preds[j], labels[j]); + pairs_sort.push_back(pair); + } + acc += NDCG(pairs_sort); + } + return acc / (group_index.size() - 1); + } + + float NDCG(std::vector pairs_sort) const{ + std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); + float dcg = DCG(pairs_sort); + std::sort(pairs_sort.begin(), pairs_sort.end(), PairValueComparer); + float IDCG = DCG(pairs_sort); + if (IDCG == 0) return 0; + return dcg / IDCG; + } + + float DCG(std::vector pairs_sort) const{ + float ans = 0.0; + ans += pairs_sort[0].value_; + for (int i = 1; i < pairs_sort.size(); i++){ + ans += pairs_sort[i].value_ / log(i + 1); + } + return ans; + } + + virtual const char *Name(void) const { + return "NDCG"; + } + }; + + }; + + namespace rank { + /*! \brief a set of evaluators */ + class RankEvalSet { + public: + inline void AddEval(const char *name) { + if (!strcmp(name, "PAIR")) evals_.push_back(&pair_); + if (!strcmp(name, "MAP")) evals_.push_back(&map_); + if (!strcmp(name, "NDCG")) evals_.push_back(&ndcg_); + } + + inline void Init(void) { + std::sort(evals_.begin(), evals_.end()); + evals_.resize(std::unique(evals_.begin(), evals_.end()) - evals_.begin()); + } + + inline void Eval(FILE *fo, const char *evname, + const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + for (size_t i = 0; i < evals_.size(); ++i) { + float res = evals_[i]->Eval(preds, labels, group_index); + fprintf(fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res); + } + } + + private: + EvalPair pair_; + EvalMAP map_; + EvalNDCG ndcg_; + std::vector evals_; + }; + }; }; #endif diff --git a/dev/rank/xgboost_rank_main.cpp b/dev/rank/xgboost_rank_main.cpp index 2652c7ff0..05580e280 100644 --- a/dev/rank/xgboost_rank_main.cpp +++ b/dev/rank/xgboost_rank_main.cpp @@ -1,283 +1,30 @@ #define _CRT_SECURE_NO_WARNINGS #define _CRT_SECURE_NO_DEPRECATE - #include #include #include -#include "xgboost_rank.h" +#include "../base/xgboost_learner.h" #include "../utils/xgboost_fmap.h" #include "../utils/xgboost_random.h" #include "../utils/xgboost_config.h" +#include "../base/xgboost_learner.h" +#include "../base/xgboost_boost_task.h" +#include "xgboost_rank.h" +#include "../regression/xgboost_reg.h" -namespace xgboost { -namespace rank { -/*! -* \brief wrapping the training process of the gradient boosting regression model, -* given the configuation -* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.chen@gmail.com -*/ -class RankBoostTask { -public: - inline int Run( int argc, char *argv[] ) { - if( argc < 2 ) { - printf("Usage: \n"); - return 0; - } - utils::ConfigIterator itr( argv[1] ); - while( itr.Next() ) { - this->SetParam( itr.name(), itr.val() ); - } - for( int i = 2; i < argc; i ++ ) { - char name[256], val[256]; - if( sscanf( argv[i], "%[^=]=%s", name, val ) == 2 ) { - this->SetParam( name, val ); - } - } - this->InitData(); - this->InitLearner(); - if( task == "dump" ) { - this->TaskDump(); - return 0; - } - if( task == "interact" ) { - this->TaskInteractive(); - return 0; - } - if( task == "dumppath" ) { - this->TaskDumpPath(); - return 0; - } - if( task == "eval" ) { - this->TaskEval(); - return 0; - } - if( task == "pred" ) { - this->TaskPred(); - } else { - this->TaskTrain(); - } - return 0; - } - inline void SetParam( const char *name, const char *val ) { - if( !strcmp("silent", name ) ) silent = atoi( val ); - if( !strcmp("use_buffer", name ) ) use_buffer = atoi( val ); - if( !strcmp("seed", name ) ) random::Seed( atoi(val) ); - if( !strcmp("num_round", name ) ) num_round = atoi( val ); - if( !strcmp("save_period", name ) ) save_period = atoi( val ); - if( !strcmp("task", name ) ) task = val; - if( !strcmp("data", name ) ) train_path = val; - if( !strcmp("test:data", name ) ) test_path = val; - if( !strcmp("model_in", name ) ) model_in = val; - if( !strcmp("model_out", name ) ) model_out = val; - if( !strcmp("model_dir", name ) ) model_dir_path = val; - if( !strcmp("fmap", name ) ) name_fmap = val; - if( !strcmp("name_dump", name ) ) name_dump = val; - if( !strcmp("name_dumppath", name ) ) name_dumppath = val; - if( !strcmp("name_pred", name ) ) name_pred = val; - if( !strcmp("dump_stats", name ) ) dump_model_stats = atoi( val ); - if( !strcmp("interact:action", name ) ) interact_action = val; - if( !strncmp("batch:", name, 6 ) ) { - cfg_batch.PushBack( name + 6, val ); - } - if( !strncmp("eval[", name, 5 ) ) { - char evname[ 256 ]; - utils::Assert( sscanf( name, "eval[%[^]]", evname ) == 1, "must specify evaluation name for display"); - eval_data_names.push_back( std::string( evname ) ); - eval_data_paths.push_back( std::string( val ) ); - } - cfg.PushBack( name, val ); - } -public: - RankBoostTask( void ) { - // default parameters - silent = 0; - use_buffer = 1; - num_round = 10; - save_period = 0; - dump_model_stats = 0; - task = "train"; - model_in = "NULL"; - model_out = "NULL"; - name_fmap = "NULL"; - name_pred = "pred.txt"; - name_dump = "dump.txt"; - name_dumppath = "dump.path.txt"; - model_dir_path = "./"; - interact_action = "update"; - } - ~RankBoostTask( void ) { - for( size_t i = 0; i < deval.size(); i ++ ) { - delete deval[i]; - } - } -private: - inline void InitData( void ) { - if( name_fmap != "NULL" ) fmap.LoadText( name_fmap.c_str() ); - if( task == "dump" ) return; - if( task == "pred" || task == "dumppath" ) { - data.CacheLoad( test_path.c_str(), silent!=0, use_buffer!=0 ); - } else { - // training - data.CacheLoad( train_path.c_str(), silent!=0, use_buffer!=0 ); - utils::Assert( eval_data_names.size() == eval_data_paths.size() ); - for( size_t i = 0; i < eval_data_names.size(); ++ i ) { - deval.push_back( new RMatrix() ); - deval.back()->CacheLoad( eval_data_paths[i].c_str(), silent!=0, use_buffer!=0 ); - } - } - learner.SetData( &data, deval, eval_data_names ); - } - inline void InitLearner( void ) { - cfg.BeforeFirst(); - while( cfg.Next() ) { - learner.SetParam( cfg.name(), cfg.val() ); - } - if( model_in != "NULL" ) { - utils::FileStream fi( utils::FopenCheck( model_in.c_str(), "rb") ); - learner.LoadModel( fi ); - fi.Close(); - } else { - utils::Assert( task == "train", "model_in not specified" ); - learner.InitModel(); - } - learner.InitTrainer(); - } - inline void TaskTrain( void ) { - const time_t start = time( NULL ); - unsigned long elapsed = 0; - for( int i = 0; i < num_round; ++ i ) { - elapsed = (unsigned long)(time(NULL) - start); - if( !silent ) printf("boosting round %d, %lu sec elapsed\n", i , elapsed ); - learner.UpdateOneIter( i ); - learner.EvalOneIter( i ); - if( save_period != 0 && (i+1) % save_period == 0 ) { - this->SaveModel( i ); - } - elapsed = (unsigned long)(time(NULL) - start); - } - // always save final round - if( save_period == 0 || num_round % save_period != 0 ) { - if( model_out == "NULL" ) { - this->SaveModel( num_round - 1 ); - } else { - this->SaveModel( model_out.c_str() ); - } - } - if( !silent ) { - printf("\nupdating end, %lu sec in all\n", elapsed ); - } - } - inline void TaskEval( void ) { - learner.EvalOneIter( 0 ); - } - inline void TaskInteractive( void ) { - const time_t start = time( NULL ); - unsigned long elapsed = 0; - int batch_action = 0; - - cfg_batch.BeforeFirst(); - while( cfg_batch.Next() ) { - if( !strcmp( cfg_batch.name(), "run" ) ) { - learner.UpdateInteract( interact_action ); - batch_action += 1; - } else { - learner.SetParam( cfg_batch.name(), cfg_batch.val() ); - } - } - - if( batch_action == 0 ) { - learner.UpdateInteract( interact_action ); - } - utils::Assert( model_out != "NULL", "interactive mode must specify model_out" ); - this->SaveModel( model_out.c_str() ); - elapsed = (unsigned long)(time(NULL) - start); - - if( !silent ) { - printf("\ninteractive update, %d batch actions, %lu sec in all\n", batch_action, elapsed ); - } - } - - inline void TaskDump( void ) { - FILE *fo = utils::FopenCheck( name_dump.c_str(), "w" ); - learner.DumpModel( fo, fmap, dump_model_stats != 0 ); - fclose( fo ); - } - inline void TaskDumpPath( void ) { - FILE *fo = utils::FopenCheck( name_dumppath.c_str(), "w" ); - learner.DumpPath( fo, data ); - fclose( fo ); - } - inline void SaveModel( const char *fname ) const { - utils::FileStream fo( utils::FopenCheck( fname, "wb" ) ); - learner.SaveModel( fo ); - fo.Close(); - } - inline void SaveModel( int i ) const { - char fname[256]; - sprintf( fname ,"%s/%04d.model", model_dir_path.c_str(), i+1 ); - this->SaveModel( fname ); - } - inline void TaskPred( void ) { - std::vector preds; - if( !silent ) printf("start prediction...\n"); - learner.Predict( preds, data ); - if( !silent ) printf("writing prediction to %s\n", name_pred.c_str() ); - FILE *fo = utils::FopenCheck( name_pred.c_str(), "w" ); - for( size_t i = 0; i < preds.size(); i ++ ) { - fprintf( fo, "%f\n", preds[i] ); - } - fclose( fo ); - } -private: - /* \brief whether silent */ - int silent; - /* \brief whether use auto binary buffer */ - int use_buffer; - /* \brief number of boosting iterations */ - int num_round; - /* \brief the period to save the model, 0 means only save the final round model */ - int save_period; - /*! \brief interfact action */ - std::string interact_action; - /* \brief the path of training/test data set */ - std::string train_path, test_path; - /* \brief the path of test model file, or file to restart training */ - std::string model_in; - /* \brief the path of final model file, to be saved */ - std::string model_out; - /* \brief the path of directory containing the saved models */ - std::string model_dir_path; - /* \brief task to perform */ - std::string task; - /* \brief name of predict file */ - std::string name_pred; - /* \brief whether dump statistics along with model */ - int dump_model_stats; - /* \brief name of feature map */ - std::string name_fmap; - /* \brief name of dump file */ - std::string name_dump; - /* \brief name of dump path file */ - std::string name_dumppath; - /* \brief the paths of validation data sets */ - std::vector eval_data_paths; - /* \brief the names of the evaluation data used in output log */ - std::vector eval_data_names; - /*! \brief saves configurations */ - utils::ConfigSaver cfg; - /*! \brief batch configurations */ - utils::ConfigSaver cfg_batch; -private: - RMatrix data; - std::vector deval; - utils::FeatMap fmap; - RankBoostLearner learner; -}; -}; -}; - -int main( int argc, char *argv[] ) { - xgboost::random::Seed( 0 ); - xgboost::rank::RankBoostTask tsk; - return tsk.Run( argc, argv ); +int main(int argc, char *argv[]) { + + xgboost::random::Seed(0); + xgboost::base::BoostTask tsk; + xgboost::utils::ConfigIterator itr(argv[1]); + int learner_index = 0; + while (itr.Next()){ + if (!strcmp(itr.name(), "learning_task")){ + learner_index = atoi(itr.val()); + } + } + xgboost::rank::RankBoostLearner* rank_learner = new xgboost::rank::RankBoostLearner; + xgboost::base::BoostLearner *parent = static_cast(rank_learner); + tsk.SetLearner(parent); + return tsk.Run(argc, argv); } diff --git a/dev/rank/xgboost_sample.h b/dev/rank/xgboost_sample.h index 81f2fd454..85f429d56 100644 --- a/dev/rank/xgboost_sample.h +++ b/dev/rank/xgboost_sample.h @@ -1,7 +1,7 @@ #ifndef _XGBOOST_SAMPLE_H_ #define _XGBOOST_SAMPLE_H_ -#include +#include #include"../utils/xgboost_utils.h" namespace xgboost { @@ -21,7 +21,7 @@ namespace xgboost { */ Pairs(int start,int end):start_(start),end_(end_){ for(int i = start; i < end; i++){ - vector v; + std::vector v; pairs_.push_back(v); } } @@ -31,7 +31,7 @@ namespace xgboost { * \return the index of instances paired */ std::vector GetPairs(int index) { - utils::assert(index >= start_ && index < end_,"The query index out of sampling bound"); + utils::Assert(index >= start_ && index < end_,"The query index out of sampling bound"); return pairs_[index-start_]; } @@ -44,7 +44,7 @@ namespace xgboost { pairs_[index - start_].push_back(paired_index); } - std::vector> pairs_; + std::vector< std::vector > pairs_; int start_; int end_; }; @@ -115,7 +115,7 @@ namespace xgboost { Pairs GenPairs(const std::vector &preds, const std::vector &labels, int start,int end){ - return sampler_.GenPairs(preds,labels,start,end); + return sampler_->GenPairs(preds,labels,start,end); } private: BinaryLinearSampler binary_linear_sampler; @@ -124,4 +124,4 @@ namespace xgboost { } } } - +#endif \ No newline at end of file diff --git a/regression/xgboost_reg_main.cpp b/regression/xgboost_reg_main.cpp index 9d43c22fb..8da40328f 100644 --- a/regression/xgboost_reg_main.cpp +++ b/regression/xgboost_reg_main.cpp @@ -273,8 +273,3 @@ namespace xgboost{ }; }; -int main( int argc, char *argv[] ){ - xgboost::random::Seed( 0 ); - xgboost::regression::RegBoostTask tsk; - return tsk.Run( argc, argv ); -} diff --git a/utils/xgboost_omp.h b/utils/xgboost_omp.h index 34a0e041a..8fb80d302 100644 --- a/utils/xgboost_omp.h +++ b/utils/xgboost_omp.h @@ -10,7 +10,7 @@ #if defined(_OPENMP) #include #else -#warning "OpenMP is not available, compile to single thread code" +//#warning "OpenMP is not available, compile to single thread code" inline int omp_get_thread_num() { return 0; } inline int omp_get_num_threads() { return 1; } inline void omp_set_num_threads( int nthread ) {} diff --git a/utils/xgboost_utils.h b/utils/xgboost_utils.h index d50b9201f..9373e1076 100644 --- a/utils/xgboost_utils.h +++ b/utils/xgboost_utils.h @@ -61,7 +61,17 @@ namespace xgboost{ exit( -1 ); } return fp; - } + } + + /*! \brief replace fopen, */ + inline FILE *FopenTry( const char *fname , const char *flag ){ + FILE *fp = fopen64( fname , flag ); + if( fp == NULL ){ + fprintf( stderr, "can not open file \"%s\"\n",fname ); + exit( -1 ); + } + return fp; + } }; };