From c8b2f46b893327a87ad535a0c6ee197e43f99b17 Mon Sep 17 00:00:00 2001 From: kalenhaha Date: Thu, 10 Apr 2014 22:09:19 +0800 Subject: [PATCH] lambda rank added --- Makefile | 4 +- base/xgboost_boost_task.h | 329 +++++++++++++++++++++++++++++++ base/xgboost_data_instance.h | 214 ++++++++++++++++++++ base/xgboost_learner.h | 283 ++++++++++++++++++++++++++ demo/rank/toy.train | 16 +- demo/rank/toy.train.group | 4 +- dev/base/xgboost_boost_task.h | 2 +- dev/base/xgboost_data_instance.h | 18 +- dev/base/xgboost_learner.h | 26 ++- dev/rank/xgboost_rank.h | 184 ++++++++++++++--- dev/rank/xgboost_rank_eval.h | 95 +++++++-- dev/rank/xgboost_rank_main.cpp | 3 +- dev/rank/xgboost_sample.h | 4 +- rank/xgboost_rank.h | 295 +++++++++++++++++++++++++++ rank/xgboost_rank_eval.h | 237 ++++++++++++++++++++++ rank/xgboost_rank_main.cpp | 22 +++ rank/xgboost_sample.h | 128 ++++++++++++ regression/xgboost_reg.h | 4 +- 18 files changed, 1792 insertions(+), 76 deletions(-) create mode 100644 base/xgboost_boost_task.h create mode 100644 base/xgboost_data_instance.h create mode 100644 base/xgboost_learner.h create mode 100644 rank/xgboost_rank.h create mode 100644 rank/xgboost_rank_eval.h create mode 100644 rank/xgboost_rank_main.cpp create mode 100644 rank/xgboost_sample.h diff --git a/Makefile b/Makefile index 75baf5662..845c50e6f 100644 --- a/Makefile +++ b/Makefile @@ -10,9 +10,9 @@ OBJ = all: $(BIN) $(OBJ) export LDFLAGS= -pthread -lm -xgboost: regression/xgboost_reg_main.cpp regression/*.h booster/*.h booster/*/*.hpp booster/*.hpp +#xgboost: regression/xgboost_reg_main.cpp regression/*.h booster/*.h booster/*/*.hpp booster/*.hpp -#xgboost: rank/xgboost_rank_main.cpp base/*.h rank/*.h booster/*.h booster/*/*.hpp booster/*.hpp +xgboost: rank/xgboost_rank_main.cpp base/*.h rank/*.h booster/*.h booster/*/*.hpp booster/*.hpp $(BIN) : $(CXX) $(CFLAGS) $(LDFLAGS) -o $@ $(filter %.cpp %.o %.c, $^) diff --git a/base/xgboost_boost_task.h b/base/xgboost_boost_task.h new file mode 100644 index 000000000..b79af31e4 --- /dev/null +++ b/base/xgboost_boost_task.h @@ -0,0 +1,329 @@ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE + +#include +#include +#include +#include "xgboost_data_instance.h" +#include "xgboost_learner.h" +#include "../utils/xgboost_fmap.h" +#include "../utils/xgboost_random.h" +#include "../utils/xgboost_config.h" + +namespace xgboost{ + namespace base{ + /*! + * \brief wrapping the training process of the gradient boosting model, + * given the configuation + * \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.chen@gmail.com + */ + class BoostTask{ + public: + inline int Run(int argc, char *argv[]){ + + if (argc < 2){ + printf("Usage: \n"); + return 0; + } + utils::ConfigIterator itr(argv[1]); + while (itr.Next()){ + this->SetParam(itr.name(), itr.val()); + } + for (int i = 2; i < argc; i++){ + char name[256], val[256]; + if (sscanf(argv[i], "%[^=]=%s", name, val) == 2){ + this->SetParam(name, val); + } + } + + this->InitData(); + this->InitLearner(); + if (task == "dump"){ + this->TaskDump(); + return 0; + } + if (task == "interact"){ + this->TaskInteractive(); return 0; + } + if (task == "dumppath"){ + this->TaskDumpPath(); return 0; + } + if (task == "eval"){ + this->TaskEval(); return 0; + } + if (task == "pred"){ + this->TaskPred(); + } + else{ + this->TaskTrain(); + } + return 0; + } + + enum learning_tasks{ + REGRESSION = 0, + BINARY_CLASSIFICATION = 1, + RANKING = 2 + }; + + /* \brief set learner + * \param learner the passed in learner + */ + inline void SetLearner(BoostLearner* learner){ + learner_ = learner; + } + + inline void SetParam(const char *name, const char *val){ + if (!strcmp("learning_task", name)) learning_task = atoi(val); + if (!strcmp("silent", name)) silent = atoi(val); + if (!strcmp("use_buffer", name)) use_buffer = atoi(val); + if (!strcmp("seed", name)) random::Seed(atoi(val)); + if (!strcmp("num_round", name)) num_round = atoi(val); + if (!strcmp("save_period", name)) save_period = atoi(val); + if (!strcmp("task", name)) task = val; + if (!strcmp("data", name)) train_path = val; + if (!strcmp("test:data", name)) test_path = val; + if (!strcmp("model_in", name)) model_in = val; + if (!strcmp("model_out", name)) model_out = val; + if (!strcmp("model_dir", name)) model_dir_path = val; + if (!strcmp("fmap", name)) name_fmap = val; + if (!strcmp("name_dump", name)) name_dump = val; + if (!strcmp("name_dumppath", name)) name_dumppath = val; + if (!strcmp("name_pred", name)) name_pred = val; + if (!strcmp("dump_stats", name)) dump_model_stats = atoi(val); + if (!strcmp("interact:action", name)) interact_action = val; + if (!strncmp("batch:", name, 6)){ + cfg_batch.PushBack(name + 6, val); + } + if (!strncmp("eval[", name, 5)) { + char evname[256]; + utils::Assert(sscanf(name, "eval[%[^]]", evname) == 1, "must specify evaluation name for display"); + eval_data_names.push_back(std::string(evname)); + eval_data_paths.push_back(std::string(val)); + } + cfg.PushBack(name, val); + } + public: + BoostTask(void){ + // default parameters + silent = 0; + use_buffer = 1; + num_round = 10; + save_period = 0; + dump_model_stats = 0; + task = "train"; + model_in = "NULL"; + model_out = "NULL"; + name_fmap = "NULL"; + name_pred = "pred.txt"; + name_dump = "dump.txt"; + name_dumppath = "dump.path.txt"; + model_dir_path = "./"; + interact_action = "update"; + } + ~BoostTask(void){ + for (size_t i = 0; i < deval.size(); i++){ + delete deval[i]; + } + } + private: + + + inline void InitData(void){ + + if (name_fmap != "NULL") fmap.LoadText(name_fmap.c_str()); + if (task == "dump") return; + if (learning_task == RANKING){ + char instance_path[256], group_path[256]; + if (task == "pred" || task == "dumppath"){ + sscanf(test_path.c_str(), "%[^;];%s", instance_path, group_path); + data.CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + } + else{ + // training + sscanf(train_path.c_str(), "%[^;];%s", instance_path, group_path); + data.CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + + utils::Assert(eval_data_names.size() == eval_data_paths.size()); + for (size_t i = 0; i < eval_data_names.size(); ++i){ + deval.push_back(new DMatrix()); + sscanf(eval_data_paths[i].c_str(), "%[^;];%s", instance_path, group_path); + deval.back()->CacheLoad(instance_path, group_path, silent != 0, use_buffer != 0); + } + } + } + else{ + if (task == "pred" || task == "dumppath"){ + data.CacheLoad(test_path.c_str(), "", silent != 0, use_buffer != 0); + } + else{ + // training + data.CacheLoad(train_path.c_str(), "", silent != 0, use_buffer != 0); + utils::Assert(eval_data_names.size() == eval_data_paths.size()); + for (size_t i = 0; i < eval_data_names.size(); ++i){ + deval.push_back(new DMatrix()); + deval.back()->CacheLoad(eval_data_paths[i].c_str(), "", silent != 0, use_buffer != 0); + } + } + } + + learner_->SetData(&data, deval, eval_data_names); + if(!silent) printf("BoostTask:Data Initiation Done!\n"); + } + + inline void InitLearner(void){ + cfg.BeforeFirst(); + while (cfg.Next()){ + learner_->SetParam(cfg.name(), cfg.val()); + } + if (model_in != "NULL"){ + utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb")); + learner_->LoadModel(fi); + fi.Close(); + } + else{ + utils::Assert(task == "train", "model_in not specified"); + learner_->InitModel(); + } + learner_->InitTrainer(); + if(!silent) printf("BoostTask:InitLearner Done!\n"); + } + + inline void TaskTrain(void){ + const time_t start = time(NULL); + unsigned long elapsed = 0; + for (int i = 0; i < num_round; ++i){ + elapsed = (unsigned long)(time(NULL) - start); + if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed); + learner_->UpdateOneIter(i); + learner_->EvalOneIter(i); + if (save_period != 0 && (i + 1) % save_period == 0){ + this->SaveModel(i); + } + elapsed = (unsigned long)(time(NULL) - start); + } + // always save final round + if (save_period == 0 || num_round % save_period != 0){ + if (model_out == "NULL"){ + this->SaveModel(num_round - 1); + } + else{ + this->SaveModel(model_out.c_str()); + } + } + if (!silent){ + printf("\nupdating end, %lu sec in all\n", elapsed); + } + } + inline void TaskEval(void){ + learner_->EvalOneIter(0); + } + inline void TaskInteractive(void){ + const time_t start = time(NULL); + unsigned long elapsed = 0; + int batch_action = 0; + + cfg_batch.BeforeFirst(); + while (cfg_batch.Next()){ + if (!strcmp(cfg_batch.name(), "run")){ + learner_->UpdateInteract(interact_action); + batch_action += 1; + } + else{ + learner_->SetParam(cfg_batch.name(), cfg_batch.val()); + } + } + + if (batch_action == 0){ + learner_->UpdateInteract(interact_action); + } + utils::Assert(model_out != "NULL", "interactive mode must specify model_out"); + this->SaveModel(model_out.c_str()); + elapsed = (unsigned long)(time(NULL) - start); + + if (!silent){ + printf("\ninteractive update, %d batch actions, %lu sec in all\n", batch_action, elapsed); + } + } + + inline void TaskDump(void){ + FILE *fo = utils::FopenCheck(name_dump.c_str(), "w"); + learner_->DumpModel(fo, fmap, dump_model_stats != 0); + fclose(fo); + } + inline void TaskDumpPath(void){ + FILE *fo = utils::FopenCheck(name_dumppath.c_str(), "w"); + learner_->DumpPath(fo, data); + fclose(fo); + } + inline void SaveModel(const char *fname) const{ + utils::FileStream fo(utils::FopenCheck(fname, "wb")); + learner_->SaveModel(fo); + fo.Close(); + } + inline void SaveModel(int i) const{ + char fname[256]; + sprintf(fname, "%s/%04d.model", model_dir_path.c_str(), i + 1); + this->SaveModel(fname); + } + inline void TaskPred(void){ + std::vector preds; + if (!silent) printf("start prediction...\n"); + learner_->Predict(preds, data); + if (!silent) printf("writing prediction to %s\n", name_pred.c_str()); + FILE *fo = utils::FopenCheck(name_pred.c_str(), "w"); + for (size_t i = 0; i < preds.size(); i++){ + fprintf(fo, "%f\n", preds[i]); + } + fclose(fo); + } + private: + /* \brief specify the learning task*/ + int learning_task; + /* \brief whether silent */ + int silent; + /* \brief whether use auto binary buffer */ + int use_buffer; + /* \brief number of boosting iterations */ + int num_round; + /* \brief the period to save the model, 0 means only save the final round model */ + int save_period; + /*! \brief interfact action */ + std::string interact_action; + /* \brief the path of training/test data set */ + std::string train_path, test_path; + /* \brief the path of test model file, or file to restart training */ + std::string model_in; + /* \brief the path of final model file, to be saved */ + std::string model_out; + /* \brief the path of directory containing the saved models */ + std::string model_dir_path; + /* \brief task to perform, choosing training or testing */ + std::string task; + /* \brief name of predict file */ + std::string name_pred; + /* \brief whether dump statistics along with model */ + int dump_model_stats; + /* \brief name of feature map */ + std::string name_fmap; + /* \brief name of dump file */ + std::string name_dump; + /* \brief name of dump path file */ + std::string name_dumppath; + /* \brief the paths of validation data sets */ + std::vector eval_data_paths; + /* \brief the names of the evaluation data used in output log */ + std::vector eval_data_names; + /*! \brief saves configurations */ + utils::ConfigSaver cfg; + /*! \brief batch configurations */ + utils::ConfigSaver cfg_batch; + private: + DMatrix data; + std::vector deval; + utils::FeatMap fmap; + BoostLearner* learner_; + + }; + }; +}; diff --git a/base/xgboost_data_instance.h b/base/xgboost_data_instance.h new file mode 100644 index 000000000..6ac5c5d13 --- /dev/null +++ b/base/xgboost_data_instance.h @@ -0,0 +1,214 @@ +#ifndef XGBOOST_DATA_INSTANCE_H +#define XGBOOST_DATA_INSTANCE_H + +#include +#include +#include "../booster/xgboost_data.h" +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_stream.h" + + +namespace xgboost{ + namespace base{ + /*! \brief data matrix for regression, classification, rank content */ + struct DMatrix{ + public: + /*! \brief maximum feature dimension */ + unsigned num_feature; + /*! \brief feature data content */ + booster::FMatrixS data; + /*! \brief label of each instance */ + std::vector labels; + /*! \brief the index of begin and end of a group, + * needed when the learning task is ranking*/ + std::vector group_index; + public: + /*! \brief default constructor */ + DMatrix(void){} + + /*! \brief get the number of instances */ + inline size_t Size() const{ + return labels.size(); + } + /*! + * \brief load from text file + * \param fname file of instances data + * \param fgroup file of the group data + * \param silent whether print information or not + */ + inline void LoadText(const char* fname, const char* fgroup, bool silent = false){ + data.Clear(); + FILE* file = utils::FopenCheck(fname, "r"); + float label; bool init = true; + char tmp[1024]; + std::vector findex; + std::vector fvalue; + + while (fscanf(file, "%s", tmp) == 1){ + unsigned index; float value; + if (sscanf(tmp, "%u:%f", &index, &value) == 2){ + findex.push_back(index); fvalue.push_back(value); + } + else{ + if (!init){ + labels.push_back(label); + data.AddRow(findex, fvalue); + } + findex.clear(); fvalue.clear(); + utils::Assert(sscanf(tmp, "%f", &label) == 1, "invalid format"); + init = false; + } + } + + labels.push_back(label); + data.AddRow(findex, fvalue); + // initialize column support as well + data.InitData(); + + if (!silent){ + printf("%ux%u matrix with %lu entries is loaded from %s\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + fclose(file); + LoadGroup(fgroup,silent); + } + + inline void LoadGroup(const char* fgroup, bool silent = false){ + //if exists group data load it in + FILE *file_group = fopen64(fgroup, "r"); + + if (file_group != NULL){ + group_index.push_back(0); + int tmp = 0, acc = 0,cnt = 0; + while (fscanf(file_group, "%d", &tmp) == 1){ + acc += tmp; + group_index.push_back(acc); + cnt++; + } + if(!silent) printf("%d groups are loaded from %s\n",cnt,fgroup); + fclose(file_group); + }else{ + if(!silent) printf("There is no group file\n"); + } + + } + /*! + * \brief load from binary file + * \param fname name of binary data + * \param silent whether print information or not + * \return whether loading is success + */ + inline bool LoadBinary(const char* fname, const char* fgroup, bool silent = false){ + FILE *fp = fopen64(fname, "rb"); + if (fp == NULL) return false; + utils::FileStream fs(fp); + data.LoadBinary(fs); + labels.resize(data.NumRow()); + utils::Assert(fs.Read(&labels[0], sizeof(float) * data.NumRow()) != 0, "DMatrix LoadBinary"); + fs.Close(); + // initialize column support as well + data.InitData(); + + if (!silent){ + printf("%ux%u matrix with %lu entries is loaded from %s as binary\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + + LoadGroupBinary(fgroup,silent); + return true; + } + + /*! + * \brief save to binary file + * \param fname name of binary data + * \param silent whether print information or not + */ + inline void SaveBinary(const char* fname, const char* fgroup, bool silent = false){ + // initialize column support as well + data.InitData(); + + utils::FileStream fs(utils::FopenCheck(fname, "wb")); + data.SaveBinary(fs); + fs.Write(&labels[0], sizeof(float)* data.NumRow()); + fs.Close(); + if (!silent){ + printf("%ux%u matrix with %lu entries is saved to %s as binary\n", + (unsigned)data.NumRow(), (unsigned)data.NumCol(), (unsigned long)data.NumEntry(), fname); + } + + SaveGroupBinary(fgroup,silent); + } + + inline void SaveGroupBinary(const char* fgroup, bool silent = false){ + //save group data + if (group_index.size() > 0){ + utils::FileStream file_group(utils::FopenCheck(fgroup, "wb")); + int group_index_size = group_index.size(); + file_group.Write(&(group_index_size), sizeof(int)); + file_group.Write(&group_index[0], sizeof(int) * group_index_size); + file_group.Close(); + if(!silent){printf("Index info of %d groups is saved to %s as binary\n",group_index_size-1,fgroup);} + } + } + + inline void LoadGroupBinary(const char* fgroup, bool silent = false){ + //if group data exists load it in + FILE *file_group = fopen64(fgroup, "r"); + if (file_group != NULL){ + int group_index_size = 0; + utils::FileStream group_stream(file_group); + utils::Assert(group_stream.Read(&group_index_size, sizeof(int)) != 0, "Load group indice size"); + group_index.resize(group_index_size); + utils::Assert(group_stream.Read(&group_index[0], sizeof(int) * group_index_size) != 0, "Load group indice"); + + if (!silent){ + printf("Index info of %d groups is loaded from %s as binary\n", + group_index.size() - 1, fgroup); + } + fclose(file_group); + }else{ + if(!silent){printf("The binary file of group info not exists");} + } + } + + /*! + * \brief cache load data given a file name, if filename ends with .buffer, direct load binary + * otherwise the function will first check if fname + '.buffer' exists, + * if binary buffer exists, it will reads from binary buffer, otherwise, it will load from text file, + * and try to create a buffer file + * \param fname name of binary data + * \param silent whether print information or not + * \param savebuffer whether do save binary buffer if it is text + */ + inline void CacheLoad(const char *fname, const char *fgroup, bool silent = false, bool savebuffer = true){ + int len = strlen(fname); + if (len > 8 && !strcmp(fname + len - 7, ".buffer")){ + this->LoadBinary(fname, fgroup, silent); return; + } + char bname[1024],bgroup[1024]; + sprintf(bname, "%s.buffer", fname); + sprintf(bgroup, "%s.buffer", fgroup); + if (!this->LoadBinary(bname, bgroup, silent)) + { + this->LoadText(fname, fgroup, silent); + if (savebuffer) this->SaveBinary(bname, bgroup, silent); + } + } + private: + /*! \brief update num_feature info */ + inline void UpdateInfo(void){ + this->num_feature = 0; + for (size_t i = 0; i < data.NumRow(); i++){ + booster::FMatrixS::Line sp = data[i]; + for (unsigned j = 0; j < sp.len; j++){ + if (num_feature <= sp[j].findex){ + num_feature = sp[j].findex + 1; + } + } + } + } + }; + } +}; + +#endif \ No newline at end of file diff --git a/base/xgboost_learner.h b/base/xgboost_learner.h new file mode 100644 index 000000000..0fc15e7f8 --- /dev/null +++ b/base/xgboost_learner.h @@ -0,0 +1,283 @@ +#ifndef XGBOOST_LEARNER_H +#define XGBOOST_LEARNER_H +/*! +* \file xgboost_learner.h +* \brief class for gradient boosting learner +* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com +*/ +#include +#include +#include +#include "xgboost_data_instance.h" +#include "../utils/xgboost_omp.h" +#include "../booster/xgboost_gbmbase.h" +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_stream.h" + +namespace xgboost { + namespace base { + /*! \brief class for gradient boosting learner */ + class BoostLearner { + public: + /*! \brief constructor */ + BoostLearner(void) { + silent = 0; + } + /*! + * \brief booster associated with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + BoostLearner(const DMatrix *train, + const std::vector &evals, + const std::vector &evname) { + silent = 0; + this->SetData(train, evals, evname); + } + + /*! + * \brief associate booster with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + inline void SetData(const DMatrix *train, + const std::vector &evals, + const std::vector &evname) { + this->train_ = train; + this->evals_ = evals; + this->evname_ = evname; + // estimate feature bound + int num_feature = (int)(train->data.NumCol()); + // assign buffer index + unsigned buffer_size = static_cast(train->Size()); + + for (size_t i = 0; i < evals.size(); ++i) { + buffer_size += static_cast(evals[i]->Size()); + num_feature = std::max(num_feature, (int)(evals[i]->data.NumCol())); + } + + char str_temp[25]; + if (num_feature > mparam.num_feature) { + mparam.num_feature = num_feature; + sprintf(str_temp, "%d", num_feature); + base_gbm.SetParam("bst:num_feature", str_temp); + } + + sprintf(str_temp, "%u", buffer_size); + base_gbm.SetParam("num_pbuffer", str_temp); + if (!silent) { + printf("buffer_size=%u\n", buffer_size); + } + + // set eval_preds tmp sapce + this->eval_preds_.resize(evals.size(), std::vector()); + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + virtual inline void SetParam(const char *name, const char *val) { + if (!strcmp(name, "silent")) silent = atoi(val); + mparam.SetParam(name, val); + base_gbm.SetParam(name, val); + } + /*! + * \brief initialize solver before training, called before training + * this function is reserved for solver to allocate necessary space and do other preparation + */ + inline void InitTrainer(void) { + base_gbm.InitTrainer(); + } + /*! + * \brief initialize the current data storage for model, if the model is used first time, call this function + */ + inline void InitModel(void) { + base_gbm.InitModel(); + if(!silent) printf("BoostLearner:InitModel Done!\n"); + } + /*! + * \brief load model from stream + * \param fi input stream + */ + inline void LoadModel(utils::IStream &fi) { + base_gbm.LoadModel(fi); + utils::Assert(fi.Read(&mparam, sizeof(ModelParam)) != 0); + } + /*! + * \brief DumpModel + * \param fo text file + * \param fmap feature map that may help give interpretations of feature + * \param with_stats whether print statistics as well + */ + inline void DumpModel(FILE *fo, const utils::FeatMap& fmap, bool with_stats) { + base_gbm.DumpModel(fo, fmap, with_stats); + } + /*! + * \brief Dump path of all trees + * \param fo text file + * \param data input data + */ + inline void DumpPath(FILE *fo, const DMatrix &data) { + base_gbm.DumpPath(fo, data.data); + } + + /*! + * \brief save model to stream + * \param fo output stream + */ + inline void SaveModel(utils::IStream &fo) const { + base_gbm.SaveModel(fo); + fo.Write(&mparam, sizeof(ModelParam)); + } + + virtual void EvalOneIter(int iter, FILE *fo = stderr) {} + + /*! + * \brief update the model for one iteration + * \param iteration iteration number + */ + inline void UpdateOneIter(int iter) { + this->PredictBuffer(preds_, *train_, 0); + this->GetGradient(preds_, train_->labels, train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + +// printf("xgboost_learner.h:UpdateOneIter\n"); +// const unsigned ndata = static_cast(train_->Size()); +// #pragma omp parallel for schedule( static ) +// for (unsigned j = 0; j < ndata; ++j) { +// printf("haha:%d %f\n",j,base_gbm.Predict(train_->data, j, j)); +// } + } + + /*! \brief get intransformed prediction, without buffering */ + inline void Predict(std::vector &preds, const DMatrix &data) { + preds.resize(data.Size()); + const unsigned ndata = static_cast(data.Size()); + #pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.Predict(data.data, j, -1); + + } + } + + public: + /*! + * \brief update the model for one iteration + * \param iteration iteration number + */ + virtual inline void UpdateInteract(std::string action){ + this->InteractPredict(preds_, *train_, 0); + + int buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i) { + std::vector &preds = this->eval_preds_[i]; + this->InteractPredict(preds, *evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + + if (action == "remove") { + base_gbm.DelteBooster(); + return; + } + + this->GetGradient(preds_, train_->labels, train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + + this->InteractRePredict(*train_, 0); + buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i) { + this->InteractRePredict(*evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + }; + + protected: + /*! \brief get the intransformed predictions, given data */ + inline void InteractPredict(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { + preds.resize(data.Size()); + const unsigned ndata = static_cast(data.Size()); + #pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.InteractPredict(data.data, j, buffer_offset + j); + } + } + /*! \brief repredict trial */ + inline void InteractRePredict(const xgboost::base::DMatrix &data, unsigned buffer_offset) { + const unsigned ndata = static_cast(data.Size()); + #pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + base_gbm.InteractRePredict(data.data, j, buffer_offset + j); + } + } + + /*! \brief get intransformed predictions, given data */ + virtual inline void PredictBuffer(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { + preds.resize(data.Size()); + const unsigned ndata = static_cast(data.Size()); + + #pragma omp parallel for schedule( static ) + for (unsigned j = 0; j < ndata; ++j) { + preds[j] = base_gbm.Predict(data.data, j, buffer_offset + j); + } + } + + /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ + virtual inline void GetGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess) {}; + + + protected: + + /*! \brief training parameter for regression */ + struct ModelParam { + /* \brief type of loss function */ + int loss_type; + /* \brief number of features */ + int num_feature; + /*! \brief reserved field */ + int reserved[16]; + /*! \brief constructor */ + ModelParam(void) { + loss_type = 0; + num_feature = 0; + memset(reserved, 0, sizeof(reserved)); + } + /*! + * \brief set parameters from outside + * \param name name of the parameter + * \param val value of the parameter + */ + inline void SetParam(const char *name, const char *val) { + if (!strcmp("loss_type", name)) loss_type = atoi(val); + if (!strcmp("bst:num_feature", name)) num_feature = atoi(val); + } + + }; + + int silent; + booster::GBMBase base_gbm; + ModelParam mparam; + const DMatrix *train_; + std::vector evals_; + std::vector evname_; + std::vector buffer_index_; + std::vector grad_, hess_, preds_; + std::vector< std::vector > eval_preds_; + }; + } +}; + +#endif + + + + + diff --git a/demo/rank/toy.train b/demo/rank/toy.train index 8dad91eb1..cd8b6d628 100644 --- a/demo/rank/toy.train +++ b/demo/rank/toy.train @@ -1,5 +1,11 @@ -1 0:2 1:3 2:2 -0 0:2 1:3 2:2 -0 0:2 1:3 2:2 -0 0:2 1:3 2:2 -1 0:2 1:3 2:2 +1 0:1.2 1:3 2:5.6 +0 0:2.0 1:2.3 2:5.1 +0 0:3.9 1:3 2:3.1 +0 0:2 1:3.2 2:3.4 +1 0:2.1 1:4.5 2:4.2 +0 0:1.9 1:2.8 2:3.1 +1 0:3.0 1:2.0 2:1.1 +0 0:1.9 1:1.8 2:2.1 +0 0:1.1 1:2.2 2:1.4 +1 0:2.1 1:4.1 2:4.0 +0 0:1.9 1:2.2 2:1.1 diff --git a/demo/rank/toy.train.group b/demo/rank/toy.train.group index 4792e70f3..ec385ae9f 100644 --- a/demo/rank/toy.train.group +++ b/demo/rank/toy.train.group @@ -1,2 +1,2 @@ -2 -3 +6 +5 \ No newline at end of file diff --git a/dev/base/xgboost_boost_task.h b/dev/base/xgboost_boost_task.h index 6519cd6ac..b79af31e4 100644 --- a/dev/base/xgboost_boost_task.h +++ b/dev/base/xgboost_boost_task.h @@ -174,7 +174,7 @@ namespace xgboost{ inline void InitLearner(void){ cfg.BeforeFirst(); while (cfg.Next()){ - learner_->SetParam(cfg.name(), cfg.val()); + learner_->SetParam(cfg.name(), cfg.val()); } if (model_in != "NULL"){ utils::FileStream fi(utils::FopenCheck(model_in.c_str(), "rb")); diff --git a/dev/base/xgboost_data_instance.h b/dev/base/xgboost_data_instance.h index e3fdb7b2f..6ac5c5d13 100644 --- a/dev/base/xgboost_data_instance.h +++ b/dev/base/xgboost_data_instance.h @@ -10,8 +10,8 @@ namespace xgboost{ namespace base{ - /*! \brief data matrix for regression,classification,rank content */ - struct DMatrix{ + /*! \brief data matrix for regression, classification, rank content */ + struct DMatrix{ public: /*! \brief maximum feature dimension */ unsigned num_feature; @@ -74,7 +74,7 @@ namespace xgboost{ } inline void LoadGroup(const char* fgroup, bool silent = false){ - //if exists group data load it in + //if exists group data load it in FILE *file_group = fopen64(fgroup, "r"); if (file_group != NULL){ @@ -117,6 +117,7 @@ namespace xgboost{ LoadGroupBinary(fgroup,silent); return true; } + /*! * \brief save to binary file * \param fname name of binary data @@ -139,7 +140,7 @@ namespace xgboost{ } inline void SaveGroupBinary(const char* fgroup, bool silent = false){ - //save group data + //save group data if (group_index.size() > 0){ utils::FileStream file_group(utils::FopenCheck(fgroup, "wb")); int group_index_size = group_index.size(); @@ -151,7 +152,7 @@ namespace xgboost{ } inline void LoadGroupBinary(const char* fgroup, bool silent = false){ - //if group data exists load it in + //if group data exists load it in FILE *file_group = fopen64(fgroup, "r"); if (file_group != NULL){ int group_index_size = 0; @@ -168,8 +169,8 @@ namespace xgboost{ }else{ if(!silent){printf("The binary file of group info not exists");} } - - } + } + /*! * \brief cache load data given a file name, if filename ends with .buffer, direct load binary * otherwise the function will first check if fname + '.buffer' exists, @@ -207,9 +208,6 @@ namespace xgboost{ } } }; - - - } }; diff --git a/dev/base/xgboost_learner.h b/dev/base/xgboost_learner.h index 90696659d..0fc15e7f8 100644 --- a/dev/base/xgboost_learner.h +++ b/dev/base/xgboost_learner.h @@ -144,17 +144,24 @@ namespace xgboost { this->GetGradient(preds_, train_->labels, train_->group_index, grad_, hess_); std::vector root_index; base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + +// printf("xgboost_learner.h:UpdateOneIter\n"); +// const unsigned ndata = static_cast(train_->Size()); +// #pragma omp parallel for schedule( static ) +// for (unsigned j = 0; j < ndata; ++j) { +// printf("haha:%d %f\n",j,base_gbm.Predict(train_->data, j, j)); +// } } /*! \brief get intransformed prediction, without buffering */ inline void Predict(std::vector &preds, const DMatrix &data) { preds.resize(data.Size()); - const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j) { - preds[j] = base_gbm.Predict(data.data, j, -1); - } + preds[j] = base_gbm.Predict(data.data, j, -1); + + } } public: @@ -194,7 +201,7 @@ namespace xgboost { inline void InteractPredict(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { preds.resize(data.Size()); const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j) { preds[j] = base_gbm.InteractPredict(data.data, j, buffer_offset + j); } @@ -202,7 +209,7 @@ namespace xgboost { /*! \brief repredict trial */ inline void InteractRePredict(const xgboost::base::DMatrix &data, unsigned buffer_offset) { const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j) { base_gbm.InteractRePredict(data.data, j, buffer_offset + j); } @@ -212,10 +219,11 @@ namespace xgboost { virtual inline void PredictBuffer(std::vector &preds, const DMatrix &data, unsigned buffer_offset) { preds.resize(data.Size()); const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j) { preds[j] = base_gbm.Predict(data.data, j, buffer_offset + j); - } + } } /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ @@ -248,7 +256,7 @@ namespace xgboost { * \param val value of the parameter */ inline void SetParam(const char *name, const char *val) { - if (!strcmp("loss_type", name)) loss_type = atoi(val); + if (!strcmp("loss_type", name)) loss_type = atoi(val); if (!strcmp("bst:num_feature", name)) num_feature = atoi(val); } diff --git a/dev/rank/xgboost_rank.h b/dev/rank/xgboost_rank.h index 7af199b08..0758e9366 100644 --- a/dev/rank/xgboost_rank.h +++ b/dev/rank/xgboost_rank.h @@ -7,7 +7,7 @@ */ #include #include -#include +#include #include "xgboost_sample.h" #include "xgboost_rank_eval.h" #include "../base/xgboost_data_instance.h" @@ -71,11 +71,139 @@ namespace xgboost { fprintf(fo, "\n"); } - inline void SetParam(const char *name, const char *val){ - if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val); + virtual inline void SetParam(const char *name, const char *val){ + BoostLearner::SetParam(name,val); + if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val); if (!strcmp(name, "rank:sampler")) sampler.AssignSampler(atoi(val)); } - /*! \brief get the first order and second order gradient, given the transformed predictions and labels */ + + private: + inline std::vector< Triple > GetSortedTuple(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + int group){ + std::vector< Triple > sorted_triple; + for(int j = group_index[group]; j < group_index[group+1]; j++){ + sorted_triple.push_back(Triple(preds[j],labels[j],j)); + } + std::sort(sorted_triple.begin(),sorted_triple.end(),Triplef1Comparer); + return sorted_triple; + } + + inline std::vector GetIndexMap(std::vector< Triple > sorted_triple,int start){ + std::vector index_remap; + index_remap.resize(sorted_triple.size()); + for(int i = 0; i < sorted_triple.size(); i++){ + index_remap[sorted_triple[i].f3_-start] = i; + } + return index_remap; + } + + inline float GetLambdaMAP(const std::vector< Triple > sorted_triple, + int index1,int index2, + std::vector< Quadruple > map_acc){ + if(index1 > index2) std::swap(index1,index2); + float original = map_acc[index2].f1_; + if(index1 != 0) original -= map_acc[index1 - 1].f1_; + float changed = 0; + if(sorted_triple[index1].f2_ < sorted_triple[index2].f2_){ + changed += map_acc[index2 - 1].f3_ - map_acc[index1].f3_; + changed += (map_acc[index1].f4_ + 1.0f)/(index1 + 1); + }else{ + changed += map_acc[index2 - 1].f2_ - map_acc[index1].f2_; + changed += map_acc[index2].f4_/(index2 + 1); + } + float ans = (changed - original)/(map_acc[map_acc.size() - 1].f4_); + if(ans < 0) ans = -ans; + return ans; + } + + inline float GetLambdaNDCG(const std::vector< Triple > sorted_triple, + int index1, + int index2,float IDCG){ + float original = pow(2,sorted_triple[index1].f2_)/log(index1+2) + + pow(2,sorted_triple[index2].f2_)/log(index2+2); + float changed = pow(2,sorted_triple[index2].f2_)/log(index1+2) + + pow(2,sorted_triple[index1].f2_)/log(index2+2); + float ans = (original - changed)/IDCG; + if(ans < 0) ans = -ans; + return ans; + } + + + inline float GetIDCG(const std::vector< Triple > sorted_triple){ + std::vector labels; + for(int i = 0; i < sorted_triple.size(); i++){ + labels.push_back(sorted_triple[i].f2_); + } + + std::sort(labels.begin(),labels.end(),std::greater()); + return EvalNDCG::DCG(labels); + } + + inline std::vector< Quadruple > GetMAPAcc(const std::vector< Triple > sorted_triple){ + std::vector< Quadruple > map_acc; + float hit = 0,acc1 = 0,acc2 = 0,acc3 = 0; + for(int i = 0; i < sorted_triple.size(); i++){ + if(sorted_triple[i].f2_ == 1) { + hit++; + acc1 += hit /( i + 1 ); + acc2 += (hit - 1)/(i+1); + acc3 += (hit + 1)/(i+1); + } + map_acc.push_back(Quadruple(acc1,acc2,acc3,hit)); + } + return map_acc; + + } + + inline void GetGroupGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess, + const std::vector< Triple > sorted_triple, + const std::vector index_remap, + const sample::Pairs& pairs, + int group){ + bool j_better; + float IDCG, pred_diff, pred_diff_exp, delta; + float first_order_gradient, second_order_gradient; + std::vector< Quadruple > map_acc; + + if(mparam.loss_type == NDCG){ + IDCG = GetIDCG(sorted_triple); + }else if(mparam.loss_type == MAP){ + map_acc = GetMAPAcc(sorted_triple); + } + + for (int j = group_index[group]; j < group_index[group + 1]; j++){ + std::vector pair_instance = pairs.GetPairs(j); + for (int k = 0; k < pair_instance.size(); k++){ + j_better = labels[j] > labels[pair_instance[k]]; + if (j_better){ + switch(mparam.loss_type){ + case PAIRWISE: delta = 1.0;break; + case MAP: delta = GetLambdaMAP(sorted_triple,index_remap[j - group_index[group]],index_remap[pair_instance[k]-group_index[group]],map_acc);break; + case NDCG: delta = GetLambdaNDCG(sorted_triple,index_remap[j - group_index[group]],index_remap[pair_instance[k]-group_index[group]],IDCG);break; + default: utils::Error("Cannot find the specified loss type"); + } + + pred_diff = preds[preds[j] - pair_instance[k]]; + pred_diff_exp = j_better ? expf(-pred_diff) : expf(pred_diff); + first_order_gradient = delta * FirstOrderGradient(pred_diff_exp); + second_order_gradient = 2 * delta * SecondOrderGradient(pred_diff_exp); + hess[j] += second_order_gradient; + grad[j] += first_order_gradient; + hess[pair_instance[k]] += second_order_gradient; + grad[pair_instance[k]] += -first_order_gradient; + } + } + } + } + public: + /*! \brief get the first order and second order gradient, given the + * intransformed predictions and labels */ inline void GetGradient(const std::vector &preds, const std::vector &labels, const std::vector &group_index, @@ -83,32 +211,44 @@ namespace xgboost { std::vector &hess) { grad.resize(preds.size()); hess.resize(preds.size()); - bool j_better; - float pred_diff, pred_diff_exp, first_order_gradient, second_order_gradient; for (int i = 0; i < group_index.size() - 1; i++){ sample::Pairs pairs = sampler.GenPairs(preds, labels, group_index[i], group_index[i + 1]); - for (int j = group_index[i]; j < group_index[i + 1]; j++){ - std::vector pair_instance = pairs.GetPairs(j); - for (int k = 0; k < pair_instance.size(); k++){ - j_better = labels[j] > labels[pair_instance[k]]; - if (j_better){ - pred_diff = preds[preds[j] - pair_instance[k]]; - pred_diff_exp = j_better ? expf(-pred_diff) : expf(pred_diff); - first_order_gradient = FirstOrderGradient(pred_diff_exp); - second_order_gradient = 2 * SecondOrderGradient(pred_diff_exp); - hess[j] += second_order_gradient; - grad[j] += first_order_gradient; - hess[pair_instance[k]] += second_order_gradient; - grad[pair_instance[k]] += -first_order_gradient; - } - } - } + //pairs.GetPairs() + std::vector< Triple > sorted_triple = GetSortedTuple(preds,labels,group_index,i); + std::vector index_remap = GetIndexMap(sorted_triple,group_index[i]); + GetGroupGradient(preds,labels,group_index, + grad,hess,sorted_triple,index_remap,pairs,i); } } inline void UpdateInteract(std::string action) { + this->InteractPredict(preds_, *train_, 0); + int buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i){ + std::vector &preds = this->eval_preds_[i]; + this->InteractPredict(preds, *evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + + if (action == "remove"){ + base_gbm.DelteBooster(); return; + } + + this->GetGradient(preds_, train_->labels,train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + + this->InteractRePredict(*train_, 0); + buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i){ + this->InteractRePredict(*evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } } + + + private: enum LossType { PAIRWISE = 0, diff --git a/dev/rank/xgboost_rank_eval.h b/dev/rank/xgboost_rank_eval.h index ede770ed9..f03d3bf8f 100644 --- a/dev/rank/xgboost_rank_eval.h +++ b/dev/rank/xgboost_rank_eval.h @@ -34,27 +34,52 @@ namespace xgboost { float key_; float value_; - Pair(float key, float value){ - key_ = key; - value_ = value_; + Pair(float key, float value):key_(key),value_(value){ } }; - bool PairKeyComparer(const Pair &a, const Pair &b){ - return a.key_ < b.key_; + bool PairKeyComparer(const Pair &a, const Pair &b){ + return a.key_ < b.key_; } bool PairValueComparer(const Pair &a, const Pair &b){ return a.value_ < b.value_; } - + template + class Triple{ + public: + T1 f1_; + T2 f2_; + T3 f3_; + Triple(T1 f1,T2 f2,T3 f3):f1_(f1),f2_(f2),f3_(f3){ + + } + }; + + template + class Quadruple{ + public: + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + Quadruple(T1 f1,T2 f2,T3 f3,T4 f4):f1_(f1),f2_(f2),f3_(f3),f4_(f4){ + + } + }; + + bool Triplef1Comparer(const Triple &a, const Triple &b){ + return a.f1_< b.f1_; + } + /*! \brief Mean Average Precision */ class EvalMAP : public IRankEvaluator { public: float Eval(const std::vector &preds, const std::vector &labels, const std::vector &group_index) const { + if (group_index.size() <= 1) return 0; float acc = 0; std::vector pairs_sort; for (int i = 0; i < group_index.size() - 1; i++){ @@ -66,12 +91,13 @@ namespace xgboost { } return acc / (group_index.size() - 1); } - + + virtual const char *Name(void) const { return "MAP"; } - + private: float average_precision(std::vector pairs_sort) const{ std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); @@ -94,12 +120,31 @@ namespace xgboost { float Eval(const std::vector &preds, const std::vector &labels, const std::vector &group_index) const { - return 0; - } + if (group_index.size() <= 1) return 0; + float acc = 0; + for (int i = 0; i < group_index.size() - 1; i++){ + acc += Count_Inversion(preds,labels, + group_index[i],group_index[i+1]); + } + return acc / (group_index.size() - 1); + } const char *Name(void) const { return "PAIR"; } + private: + float Count_Inversion(const std::vector &preds, + const std::vector &labels,int begin,int end + ) const{ + float ans = 0; + for(int i = begin; i < end; i++){ + for(int j = i + 1; j < end; j++){ + if(preds[i] > preds[j] && labels[i] < labels[j]) + ans++; + } + } + return ans; + } }; /*! \brief Normalized DCG */ @@ -120,7 +165,20 @@ namespace xgboost { } return acc / (group_index.size() - 1); } - + + static float DCG(const std::vector &labels){ + float ans = 0.0; + for (int i = 0; i < labels.size(); i++){ + ans += (pow(2,labels[i]) - 1 ) / log(i + 2); + } + return ans; + } + + virtual const char *Name(void) const { + return "NDCG"; + } + + private: float NDCG(std::vector pairs_sort) const{ std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); float dcg = DCG(pairs_sort); @@ -131,17 +189,14 @@ namespace xgboost { } float DCG(std::vector pairs_sort) const{ - float ans = 0.0; - ans += pairs_sort[0].value_; - for (int i = 1; i < pairs_sort.size(); i++){ - ans += pairs_sort[i].value_ / log(i + 1); - } - return ans; + std::vector labels; + for (int i = 1; i < pairs_sort.size(); i++){ + labels.push_back(pairs_sort[i].value_); + } + return DCG(labels); } - virtual const char *Name(void) const { - return "NDCG"; - } + }; }; diff --git a/dev/rank/xgboost_rank_main.cpp b/dev/rank/xgboost_rank_main.cpp index c68060f20..2ad6d98a4 100644 --- a/dev/rank/xgboost_rank_main.cpp +++ b/dev/rank/xgboost_rank_main.cpp @@ -13,7 +13,8 @@ #include "../regression/xgboost_reg.h" #include "../regression/xgboost_reg_main.cpp" #include "../base/xgboost_data_instance.h" -int main(int argc, char *argv[]) { + +int main(int argc, char *argv[]) { xgboost::random::Seed(0); xgboost::base::BoostTask rank_tsk; rank_tsk.SetLearner(new xgboost::rank::RankBoostLearner); diff --git a/dev/rank/xgboost_sample.h b/dev/rank/xgboost_sample.h index 2b2b255e4..6719390a8 100644 --- a/dev/rank/xgboost_sample.h +++ b/dev/rank/xgboost_sample.h @@ -19,7 +19,7 @@ namespace xgboost { * \param start the begin index of the group * \param end the end index of the group */ - Pairs(int start, int end) :start_(start), end_(end_){ + Pairs(int start, int end) :start_(start), end_(end){ for (int i = start; i < end; i++){ std::vector v; pairs_.push_back(v); @@ -30,7 +30,7 @@ namespace xgboost { * \param index, the index of retrieved instance * \return the index of instances paired */ - std::vector GetPairs(int index) { + std::vector GetPairs(int index) const{ utils::Assert(index >= start_ && index < end_, "The query index out of sampling bound"); return pairs_[index - start_]; } diff --git a/rank/xgboost_rank.h b/rank/xgboost_rank.h new file mode 100644 index 000000000..0758e9366 --- /dev/null +++ b/rank/xgboost_rank.h @@ -0,0 +1,295 @@ +#ifndef XGBOOST_RANK_H +#define XGBOOST_RANK_H +/*! +* \file xgboost_rank.h +* \brief class for gradient boosting ranking +* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com +*/ +#include +#include +#include +#include "xgboost_sample.h" +#include "xgboost_rank_eval.h" +#include "../base/xgboost_data_instance.h" +#include "../utils/xgboost_omp.h" +#include "../booster/xgboost_gbmbase.h" +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_stream.h" +#include "../base/xgboost_learner.h" + +namespace xgboost { + namespace rank { + /*! \brief class for gradient boosted regression */ + class RankBoostLearner :public base::BoostLearner{ + public: + /*! \brief constructor */ + RankBoostLearner(void) { + BoostLearner(); + } + /*! + * \brief a rank booster associated with training and evaluating data + * \param train pointer to the training data + * \param evals array of evaluating data + * \param evname name of evaluation data, used print statistics + */ + RankBoostLearner(const base::DMatrix *train, + const std::vector &evals, + const std::vector &evname) { + + BoostLearner(train, evals, evname); + } + + /*! + * \brief initialize solver before training, called before training + * this function is reserved for solver to allocate necessary space + * and do other preparation + */ + inline void InitTrainer(void) { + BoostLearner::InitTrainer(); + if (mparam.loss_type == PAIRWISE) { + evaluator_.AddEval("PAIR"); + } + else if (mparam.loss_type == MAP) { + evaluator_.AddEval("MAP"); + } + else { + evaluator_.AddEval("NDCG"); + } + evaluator_.Init(); + } + + void EvalOneIter(int iter, FILE *fo = stderr) { + fprintf(fo, "[%d]", iter); + int buffer_offset = static_cast(train_->Size()); + + for (size_t i = 0; i < evals_.size(); ++i) { + std::vector &preds = this->eval_preds_[i]; + this->PredictBuffer(preds, *evals_[i], buffer_offset); + evaluator_.Eval(fo, evname_[i].c_str(), preds, (*evals_[i]).labels, (*evals_[i]).group_index); + buffer_offset += static_cast(evals_[i]->Size()); + } + fprintf(fo, "\n"); + } + + virtual inline void SetParam(const char *name, const char *val){ + BoostLearner::SetParam(name,val); + if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val); + if (!strcmp(name, "rank:sampler")) sampler.AssignSampler(atoi(val)); + } + + private: + inline std::vector< Triple > GetSortedTuple(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + int group){ + std::vector< Triple > sorted_triple; + for(int j = group_index[group]; j < group_index[group+1]; j++){ + sorted_triple.push_back(Triple(preds[j],labels[j],j)); + } + std::sort(sorted_triple.begin(),sorted_triple.end(),Triplef1Comparer); + return sorted_triple; + } + + inline std::vector GetIndexMap(std::vector< Triple > sorted_triple,int start){ + std::vector index_remap; + index_remap.resize(sorted_triple.size()); + for(int i = 0; i < sorted_triple.size(); i++){ + index_remap[sorted_triple[i].f3_-start] = i; + } + return index_remap; + } + + inline float GetLambdaMAP(const std::vector< Triple > sorted_triple, + int index1,int index2, + std::vector< Quadruple > map_acc){ + if(index1 > index2) std::swap(index1,index2); + float original = map_acc[index2].f1_; + if(index1 != 0) original -= map_acc[index1 - 1].f1_; + float changed = 0; + if(sorted_triple[index1].f2_ < sorted_triple[index2].f2_){ + changed += map_acc[index2 - 1].f3_ - map_acc[index1].f3_; + changed += (map_acc[index1].f4_ + 1.0f)/(index1 + 1); + }else{ + changed += map_acc[index2 - 1].f2_ - map_acc[index1].f2_; + changed += map_acc[index2].f4_/(index2 + 1); + } + float ans = (changed - original)/(map_acc[map_acc.size() - 1].f4_); + if(ans < 0) ans = -ans; + return ans; + } + + inline float GetLambdaNDCG(const std::vector< Triple > sorted_triple, + int index1, + int index2,float IDCG){ + float original = pow(2,sorted_triple[index1].f2_)/log(index1+2) + + pow(2,sorted_triple[index2].f2_)/log(index2+2); + float changed = pow(2,sorted_triple[index2].f2_)/log(index1+2) + + pow(2,sorted_triple[index1].f2_)/log(index2+2); + float ans = (original - changed)/IDCG; + if(ans < 0) ans = -ans; + return ans; + } + + + inline float GetIDCG(const std::vector< Triple > sorted_triple){ + std::vector labels; + for(int i = 0; i < sorted_triple.size(); i++){ + labels.push_back(sorted_triple[i].f2_); + } + + std::sort(labels.begin(),labels.end(),std::greater()); + return EvalNDCG::DCG(labels); + } + + inline std::vector< Quadruple > GetMAPAcc(const std::vector< Triple > sorted_triple){ + std::vector< Quadruple > map_acc; + float hit = 0,acc1 = 0,acc2 = 0,acc3 = 0; + for(int i = 0; i < sorted_triple.size(); i++){ + if(sorted_triple[i].f2_ == 1) { + hit++; + acc1 += hit /( i + 1 ); + acc2 += (hit - 1)/(i+1); + acc3 += (hit + 1)/(i+1); + } + map_acc.push_back(Quadruple(acc1,acc2,acc3,hit)); + } + return map_acc; + + } + + inline void GetGroupGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess, + const std::vector< Triple > sorted_triple, + const std::vector index_remap, + const sample::Pairs& pairs, + int group){ + bool j_better; + float IDCG, pred_diff, pred_diff_exp, delta; + float first_order_gradient, second_order_gradient; + std::vector< Quadruple > map_acc; + + if(mparam.loss_type == NDCG){ + IDCG = GetIDCG(sorted_triple); + }else if(mparam.loss_type == MAP){ + map_acc = GetMAPAcc(sorted_triple); + } + + for (int j = group_index[group]; j < group_index[group + 1]; j++){ + std::vector pair_instance = pairs.GetPairs(j); + for (int k = 0; k < pair_instance.size(); k++){ + j_better = labels[j] > labels[pair_instance[k]]; + if (j_better){ + switch(mparam.loss_type){ + case PAIRWISE: delta = 1.0;break; + case MAP: delta = GetLambdaMAP(sorted_triple,index_remap[j - group_index[group]],index_remap[pair_instance[k]-group_index[group]],map_acc);break; + case NDCG: delta = GetLambdaNDCG(sorted_triple,index_remap[j - group_index[group]],index_remap[pair_instance[k]-group_index[group]],IDCG);break; + default: utils::Error("Cannot find the specified loss type"); + } + + pred_diff = preds[preds[j] - pair_instance[k]]; + pred_diff_exp = j_better ? expf(-pred_diff) : expf(pred_diff); + first_order_gradient = delta * FirstOrderGradient(pred_diff_exp); + second_order_gradient = 2 * delta * SecondOrderGradient(pred_diff_exp); + hess[j] += second_order_gradient; + grad[j] += first_order_gradient; + hess[pair_instance[k]] += second_order_gradient; + grad[pair_instance[k]] += -first_order_gradient; + } + } + } + } + public: + /*! \brief get the first order and second order gradient, given the + * intransformed predictions and labels */ + inline void GetGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess) { + grad.resize(preds.size()); + hess.resize(preds.size()); + for (int i = 0; i < group_index.size() - 1; i++){ + sample::Pairs pairs = sampler.GenPairs(preds, labels, group_index[i], group_index[i + 1]); + //pairs.GetPairs() + std::vector< Triple > sorted_triple = GetSortedTuple(preds,labels,group_index,i); + std::vector index_remap = GetIndexMap(sorted_triple,group_index[i]); + GetGroupGradient(preds,labels,group_index, + grad,hess,sorted_triple,index_remap,pairs,i); + } + } + + inline void UpdateInteract(std::string action) { + this->InteractPredict(preds_, *train_, 0); + + int buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i){ + std::vector &preds = this->eval_preds_[i]; + this->InteractPredict(preds, *evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + + if (action == "remove"){ + base_gbm.DelteBooster(); return; + } + + this->GetGradient(preds_, train_->labels,train_->group_index, grad_, hess_); + std::vector root_index; + base_gbm.DoBoost(grad_, hess_, train_->data, root_index); + + this->InteractRePredict(*train_, 0); + buffer_offset = static_cast(train_->Size()); + for (size_t i = 0; i < evals_.size(); ++i){ + this->InteractRePredict(*evals_[i], buffer_offset); + buffer_offset += static_cast(evals_[i]->Size()); + } + } + + + + private: + enum LossType { + PAIRWISE = 0, + MAP = 1, + NDCG = 2 + }; + + + + /*! + * \brief calculate first order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return first order gradient + */ + inline float FirstOrderGradient(float pred_diff_exp) const { + return -pred_diff_exp / (1 + pred_diff_exp); + } + + /*! + * \brief calculate second order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return second order gradient + */ + inline float SecondOrderGradient(float pred_diff_exp) const { + return pred_diff_exp / pow(1 + pred_diff_exp, 2); + } + + private: + RankEvalSet evaluator_; + sample::PairSamplerWrapper sampler; + }; + }; +}; + +#endif + + + + + diff --git a/rank/xgboost_rank_eval.h b/rank/xgboost_rank_eval.h new file mode 100644 index 000000000..f03d3bf8f --- /dev/null +++ b/rank/xgboost_rank_eval.h @@ -0,0 +1,237 @@ +#ifndef XGBOOST_RANK_EVAL_H +#define XGBOOST_RANK_EVAL_H +/*! +* \file xgboost_rank_eval.h +* \brief evaluation metrics for ranking +* \author Kailong Chen: chenkl198812@gmail.com, Tianqi Chen: tianqi.tchen@gmail.com +*/ + +#include +#include +#include +#include "../utils/xgboost_utils.h" +#include "../utils/xgboost_omp.h" + +namespace xgboost { + namespace rank { + /*! \brief evaluator that evaluates the loss metrics */ + class IRankEvaluator { + public: + /*! + * \brief evaluate a specific metric + * \param preds prediction + * \param labels label + */ + virtual float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const = 0; + /*! \return name of metric */ + virtual const char *Name(void) const = 0; + }; + + class Pair{ + public: + float key_; + float value_; + + Pair(float key, float value):key_(key),value_(value){ + } + }; + + bool PairKeyComparer(const Pair &a, const Pair &b){ + return a.key_ < b.key_; + } + + bool PairValueComparer(const Pair &a, const Pair &b){ + return a.value_ < b.value_; + } + + template + class Triple{ + public: + T1 f1_; + T2 f2_; + T3 f3_; + Triple(T1 f1,T2 f2,T3 f3):f1_(f1),f2_(f2),f3_(f3){ + + } + }; + + template + class Quadruple{ + public: + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + Quadruple(T1 f1,T2 f2,T3 f3,T4 f4):f1_(f1),f2_(f2),f3_(f3),f4_(f4){ + + } + }; + + bool Triplef1Comparer(const Triple &a, const Triple &b){ + return a.f1_< b.f1_; + } + + /*! \brief Mean Average Precision */ + class EvalMAP : public IRankEvaluator { + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + if (group_index.size() <= 1) return 0; + float acc = 0; + std::vector pairs_sort; + for (int i = 0; i < group_index.size() - 1; i++){ + for (int j = group_index[i]; j < group_index[i + 1]; j++){ + Pair pair(preds[j], labels[j]); + pairs_sort.push_back(pair); + } + acc += average_precision(pairs_sort); + } + return acc / (group_index.size() - 1); + } + + + + virtual const char *Name(void) const { + return "MAP"; + } + private: + float average_precision(std::vector pairs_sort) const{ + + std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); + float hits = 0; + float average_precision = 0; + for (int j = 0; j < pairs_sort.size(); j++){ + if (pairs_sort[j].value_ == 1){ + hits++; + average_precision += hits / (j + 1); + } + } + if (hits != 0) average_precision /= hits; + return average_precision; + } + }; + + + class EvalPair : public IRankEvaluator{ + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + if (group_index.size() <= 1) return 0; + float acc = 0; + for (int i = 0; i < group_index.size() - 1; i++){ + acc += Count_Inversion(preds,labels, + group_index[i],group_index[i+1]); + } + return acc / (group_index.size() - 1); + } + + const char *Name(void) const { + return "PAIR"; + } + private: + float Count_Inversion(const std::vector &preds, + const std::vector &labels,int begin,int end + ) const{ + float ans = 0; + for(int i = begin; i < end; i++){ + for(int j = i + 1; j < end; j++){ + if(preds[i] > preds[j] && labels[i] < labels[j]) + ans++; + } + } + return ans; + } + }; + + /*! \brief Normalized DCG */ + class EvalNDCG : public IRankEvaluator { + public: + float Eval(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + if (group_index.size() <= 1) return 0; + float acc = 0; + std::vector pairs_sort; + for (int i = 0; i < group_index.size() - 1; i++){ + for (int j = group_index[i]; j < group_index[i + 1]; j++){ + Pair pair(preds[j], labels[j]); + pairs_sort.push_back(pair); + } + acc += NDCG(pairs_sort); + } + return acc / (group_index.size() - 1); + } + + static float DCG(const std::vector &labels){ + float ans = 0.0; + for (int i = 0; i < labels.size(); i++){ + ans += (pow(2,labels[i]) - 1 ) / log(i + 2); + } + return ans; + } + + virtual const char *Name(void) const { + return "NDCG"; + } + + private: + float NDCG(std::vector pairs_sort) const{ + std::sort(pairs_sort.begin(), pairs_sort.end(), PairKeyComparer); + float dcg = DCG(pairs_sort); + std::sort(pairs_sort.begin(), pairs_sort.end(), PairValueComparer); + float IDCG = DCG(pairs_sort); + if (IDCG == 0) return 0; + return dcg / IDCG; + } + + float DCG(std::vector pairs_sort) const{ + std::vector labels; + for (int i = 1; i < pairs_sort.size(); i++){ + labels.push_back(pairs_sort[i].value_); + } + return DCG(labels); + } + + + }; + + }; + + namespace rank { + /*! \brief a set of evaluators */ + class RankEvalSet { + public: + inline void AddEval(const char *name) { + if (!strcmp(name, "PAIR")) evals_.push_back(&pair_); + if (!strcmp(name, "MAP")) evals_.push_back(&map_); + if (!strcmp(name, "NDCG")) evals_.push_back(&ndcg_); + } + + inline void Init(void) { + std::sort(evals_.begin(), evals_.end()); + evals_.resize(std::unique(evals_.begin(), evals_.end()) - evals_.begin()); + } + + inline void Eval(FILE *fo, const char *evname, + const std::vector &preds, + const std::vector &labels, + const std::vector &group_index) const { + for (size_t i = 0; i < evals_.size(); ++i) { + float res = evals_[i]->Eval(preds, labels, group_index); + fprintf(fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res); + } + } + + private: + EvalPair pair_; + EvalMAP map_; + EvalNDCG ndcg_; + std::vector evals_; + }; + }; +}; +#endif diff --git a/rank/xgboost_rank_main.cpp b/rank/xgboost_rank_main.cpp new file mode 100644 index 000000000..2ad6d98a4 --- /dev/null +++ b/rank/xgboost_rank_main.cpp @@ -0,0 +1,22 @@ +#define _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_DEPRECATE +#include +#include +#include +#include "../base/xgboost_learner.h" +#include "../utils/xgboost_fmap.h" +#include "../utils/xgboost_random.h" +#include "../utils/xgboost_config.h" +#include "../base/xgboost_learner.h" +#include "../base/xgboost_boost_task.h" +#include "xgboost_rank.h" +#include "../regression/xgboost_reg.h" +#include "../regression/xgboost_reg_main.cpp" +#include "../base/xgboost_data_instance.h" + +int main(int argc, char *argv[]) { + xgboost::random::Seed(0); + xgboost::base::BoostTask rank_tsk; + rank_tsk.SetLearner(new xgboost::rank::RankBoostLearner); + return rank_tsk.Run(argc, argv); +} diff --git a/rank/xgboost_sample.h b/rank/xgboost_sample.h new file mode 100644 index 000000000..6719390a8 --- /dev/null +++ b/rank/xgboost_sample.h @@ -0,0 +1,128 @@ +#ifndef _XGBOOST_SAMPLE_H_ +#define _XGBOOST_SAMPLE_H_ + +#include +#include"../utils/xgboost_utils.h" + +namespace xgboost { + namespace rank { + namespace sample { + + /* + * \brief the data structure to maintain the sample pairs + */ + struct Pairs { + + /* + * \brief constructor given the start and end offset of the sampling group + * in overall instances + * \param start the begin index of the group + * \param end the end index of the group + */ + Pairs(int start, int end) :start_(start), end_(end){ + for (int i = start; i < end; i++){ + std::vector v; + pairs_.push_back(v); + } + } + /* + * \brief retrieve the related pair information of an data instances + * \param index, the index of retrieved instance + * \return the index of instances paired + */ + std::vector GetPairs(int index) const{ + utils::Assert(index >= start_ && index < end_, "The query index out of sampling bound"); + return pairs_[index - start_]; + } + + /* + * \brief add in a sampled pair + * \param index the index of the instance to sample a friend + * \param paired_index the index of the instance sampled as a friend + */ + void push(int index, int paired_index){ + pairs_[index - start_].push_back(paired_index); + } + + std::vector< std::vector > pairs_; + int start_; + int end_; + }; + + /* + * \brief the interface of pair sampler + */ + struct IPairSampler { + /* + * \brief Generate sample pairs given the predcions, labels, the start and the end index + * of a specified group + * \param preds, the predictions of all data instances + * \param labels, the labels of all data instances + * \param start, the start index of a specified group + * \param end, the end index of a specified group + * \return the generated pairs + */ + virtual Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end) = 0; + + }; + + enum{ + BINARY_LINEAR_SAMPLER + }; + + /*! \brief A simple pair sampler when the rank relevence scale is binary + * for each positive instance, we will pick a negative + * instance and add in a pair. When using binary linear sampler, + * we should guarantee the labels are 0 or 1 + */ + struct BinaryLinearSampler :public IPairSampler{ + virtual Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end) { + Pairs pairs(start, end); + int pointer = 0, last_pointer = 0, index = start, interval = end - start; + for (int i = start; i < end; i++){ + if (labels[i] == 1){ + while (true){ + index = (++pointer) % interval + start; + if (labels[index] == 0) break; + if (pointer - last_pointer > interval) return pairs; + } + pairs.push(i, index); + pairs.push(index, i); + last_pointer = pointer; + } + } + return pairs; + } + }; + + + /*! \brief Pair Sampler Wrapper*/ + struct PairSamplerWrapper{ + public: + inline void AssignSampler(int sampler_index){ + + switch (sampler_index){ + case BINARY_LINEAR_SAMPLER:sampler_ = &binary_linear_sampler; break; + + default:utils::Error("Cannot find the specified sampler"); + } + } + + Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end){ + utils::Assert(sampler_ != NULL,"Not config the sampler yet. Add rank:sampler in the config file\n"); + return sampler_->GenPairs(preds, labels, start, end); + } + private: + BinaryLinearSampler binary_linear_sampler; + IPairSampler *sampler_; + }; + } + } +} +#endif \ No newline at end of file diff --git a/regression/xgboost_reg.h b/regression/xgboost_reg.h index 01416f7f6..01cf0d2f3 100644 --- a/regression/xgboost_reg.h +++ b/regression/xgboost_reg.h @@ -213,7 +213,7 @@ namespace xgboost{ inline void InteractPredict(std::vector &preds, const DMatrix &data, unsigned buffer_offset){ preds.resize(data.Size()); const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j){ preds[j] = mparam.PredTransform (mparam.base_score + base_gbm.InteractPredict(data.data, j, buffer_offset + j)); @@ -222,7 +222,7 @@ namespace xgboost{ /*! \brief repredict trial */ inline void InteractRePredict(const DMatrix &data, unsigned buffer_offset){ const unsigned ndata = static_cast(data.Size()); -#pragma omp parallel for schedule( static ) + #pragma omp parallel for schedule( static ) for (unsigned j = 0; j < ndata; ++j){ base_gbm.InteractRePredict(data.data, j, buffer_offset + j); }