diff --git a/regrank/xgboost_regrank_eval.h b/regrank/xgboost_regrank_eval.h index 9a37e50da..2e90f583a 100644 --- a/regrank/xgboost_regrank_eval.h +++ b/regrank/xgboost_regrank_eval.h @@ -24,24 +24,24 @@ namespace xgboost{ * \param info information, including label etc. */ virtual float Eval(const std::vector &preds, - const DMatrix::Info &info ) const = 0; + const DMatrix::Info &info) const = 0; /*! \return name of metric */ virtual const char *Name(void) const = 0; /*! \brief virtual destructor */ virtual ~IEvaluator(void){} }; - inline static bool CmpFirst( const std::pair &a, const std::pair &b ){ - return a.first > b.first; - } + inline static bool CmpFirst(const std::pair &a, const std::pair &b){ + return a.first > b.first; + } /*! \brief RMSE */ struct EvalRMSE : public IEvaluator{ virtual float Eval(const std::vector &preds, - const DMatrix::Info &info ) const { + const DMatrix::Info &info) const { const unsigned ndata = static_cast(preds.size()); float sum = 0.0, wsum = 0.0; - #pragma omp parallel for reduction(+:sum,wsum) schedule( static ) +#pragma omp parallel for reduction(+:sum,wsum) schedule( static ) for (unsigned i = 0; i < ndata; ++i){ const float wt = info.GetWeight(i); const float diff = info.labels[i] - preds[i]; @@ -58,16 +58,16 @@ namespace xgboost{ /*! \brief Error */ struct EvalLogLoss : public IEvaluator{ virtual float Eval(const std::vector &preds, - const DMatrix::Info &info ) const { + const DMatrix::Info &info) const { const unsigned ndata = static_cast(preds.size()); - float sum = 0.0f, wsum = 0.0f; - #pragma omp parallel for reduction(+:sum,wsum) schedule( static ) + float sum = 0.0f, wsum = 0.0f; +#pragma omp parallel for reduction(+:sum,wsum) schedule( static ) for (unsigned i = 0; i < ndata; ++i){ const float y = info.labels[i]; const float py = preds[i]; const float wt = info.GetWeight(i); - sum -= wt * ( y * std::log(py) + (1.0f - y)*std::log(1 - py) ); - wsum+= wt; + sum -= wt * (y * std::log(py) + (1.0f - y)*std::log(1 - py)); + wsum += wt; } return sum / wsum; } @@ -79,15 +79,15 @@ namespace xgboost{ /*! \brief Error */ struct EvalError : public IEvaluator{ virtual float Eval(const std::vector &preds, - const DMatrix::Info &info ) const { + const DMatrix::Info &info) const { const unsigned ndata = static_cast(preds.size()); - float sum = 0.0f, wsum = 0.0f; - #pragma omp parallel for reduction(+:sum,wsum) schedule( static ) + float sum = 0.0f, wsum = 0.0f; +#pragma omp parallel for reduction(+:sum,wsum) schedule( static ) for (unsigned i = 0; i < ndata; ++i){ const float wt = info.GetWeight(i); if (preds[i] > 0.5f){ - if (info.labels[i] < 0.5f) sum += wt; - } + if (info.labels[i] < 0.5f) sum += wt; + } else{ if (info.labels[i] >= 0.5f) sum += wt; } @@ -101,44 +101,44 @@ namespace xgboost{ }; /*! \brief Area under curve, for both classification and rank */ - struct EvalAuc : public IEvaluator{ - virtual float Eval( const std::vector &preds, - const DMatrix::Info &info ) const { - std::vector tgptr(2,0); tgptr[1] = preds.size(); + struct EvalAuc : public IEvaluator{ + virtual float Eval(const std::vector &preds, + const DMatrix::Info &info) const { + std::vector tgptr(2, 0); tgptr[1] = preds.size(); const std::vector &gptr = info.group_ptr.size() == 0 ? tgptr : info.group_ptr; - utils::Assert( gptr.back() == preds.size(), "EvalAuc: group structure must match number of prediction" ); - const unsigned ngroup = static_cast( gptr.size() - 1 ); + utils::Assert(gptr.back() == preds.size(), "EvalAuc: group structure must match number of prediction"); + const unsigned ngroup = static_cast(gptr.size() - 1); double sum_auc = 0.0f; - #pragma omp parallel reduction(+:sum_auc) - { +#pragma omp parallel reduction(+:sum_auc) + { // each thread takes a local rec - std::vector< std::pair > rec; - #pragma omp for schedule(static) - for( unsigned k = 0; k < ngroup; ++ k ){ + std::vector< std::pair > rec; +#pragma omp for schedule(static) + for (unsigned k = 0; k < ngroup; ++k){ rec.clear(); - for( unsigned j = gptr[k]; j < gptr[k+1]; ++ j ){ - rec.push_back( std::make_pair( preds[j], j ) ); + for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j){ + rec.push_back(std::make_pair(preds[j], j)); } - std::sort( rec.begin(), rec.end(), CmpFirst ); + std::sort(rec.begin(), rec.end(), CmpFirst); // calculate AUC double sum_pospair = 0.0; double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0; - for( size_t j = 0; j < rec.size(); ++ j ){ - const float wt = info.GetWeight( rec[j].second ); - const float ctr = info.labels[ rec[j].second ]; + for (size_t j = 0; j < rec.size(); ++j){ + const float wt = info.GetWeight(rec[j].second); + const float ctr = info.labels[rec[j].second]; // keep bucketing predictions in same bucket - if( j != 0 && rec[j].first != rec[j-1].first ){ + if (j != 0 && rec[j].first != rec[j - 1].first){ sum_pospair += buf_neg * (sum_npos + buf_pos *0.5); sum_npos += buf_pos; sum_nneg += buf_neg; buf_neg = buf_pos = 0.0f; } - buf_pos += ctr * wt; buf_neg += (1.0f-ctr) * wt; + buf_pos += ctr * wt; buf_neg += (1.0f - ctr) * wt; } - sum_pospair += buf_neg * (sum_npos + buf_pos *0.5); + sum_pospair += buf_neg * (sum_npos + buf_pos *0.5); sum_npos += buf_pos; sum_nneg += buf_neg; // - utils::Assert( sum_npos > 0.0 && sum_nneg > 0.0, "the dataset only contains pos or neg samples" ); + utils::Assert(sum_npos > 0.0 && sum_nneg > 0.0, "the dataset only contains pos or neg samples"); // this is the AUC sum_auc += sum_pospair / (sum_npos*sum_nneg); } @@ -146,40 +146,40 @@ namespace xgboost{ // return average AUC over list return static_cast(sum_auc) / ngroup; } - virtual const char *Name( void ) const{ + virtual const char *Name(void) const{ return "auc"; } }; /*! \brief Precison at N, for both classification and rank */ - struct EvalPrecision : public IEvaluator{ + struct EvalPrecision : public IEvaluator{ unsigned topn_; std::string name_; - EvalPrecision( const char *name ){ + EvalPrecision(const char *name){ name_ = name; - utils::Assert( sscanf( name, "pre@%u", &topn_ ) ); + utils::Assert(sscanf(name, "pre@%u", &topn_)); } - virtual float Eval( const std::vector &preds, - const DMatrix::Info &info ) const { + virtual float Eval(const std::vector &preds, + const DMatrix::Info &info) const { const std::vector &gptr = info.group_ptr; - utils::Assert( gptr.size()!=0 && gptr.back() == preds.size(), "EvalAuc: group structure must match number of prediction" ); - const unsigned ngroup = static_cast( gptr.size() - 1 ); + utils::Assert(gptr.size() != 0 && gptr.back() == preds.size(), "EvalAuc: group structure must match number of prediction"); + const unsigned ngroup = static_cast(gptr.size() - 1); double sum_pre = 0.0f; - #pragma omp parallel reduction(+:sum_pre) - { +#pragma omp parallel reduction(+:sum_pre) + { // each thread takes a local rec - std::vector< std::pair > rec; - #pragma omp for schedule(static) - for( unsigned k = 0; k < ngroup; ++ k ){ + std::vector< std::pair > rec; +#pragma omp for schedule(static) + for (unsigned k = 0; k < ngroup; ++k){ rec.clear(); - for( unsigned j = gptr[k]; j < gptr[k+1]; ++ j ){ - rec.push_back( std::make_pair( preds[j], (int)info.labels[j] ) ); + for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j){ + rec.push_back(std::make_pair(preds[j], (int)info.labels[j])); } - std::sort( rec.begin(), rec.end(), CmpFirst ); + std::sort(rec.begin(), rec.end(), CmpFirst); // calculate Preicsion unsigned nhit = 0; - for( size_t j = 0; j < rec.size() && j < topn_; ++ j ){ + for (size_t j = 0; j < rec.size() && j < topn_; ++j){ nhit += rec[j].second; } sum_pre += ((float)nhit) / topn_; @@ -187,9 +187,95 @@ namespace xgboost{ } return static_cast(sum_pre) / ngroup; } - virtual const char *Name( void ) const{ + virtual const char *Name(void) const{ return name_.c_str(); - } + } + }; + + /*! \brief Normalized DCG */ + class EvalNDCG : public IEvaluator { + public: + virtual float Eval(const std::vector &preds, + const DMatrix::Info &info) const{ + if (info.group_ptr.size() <= 1) return 0; + float acc = 0; + std::vector> pairs_sort; + for (int i = 0; i < info.group_ptr.size() - 1; i++){ + for (int j = info.group_ptr[i]; j < info.group_ptr[i + 1]; j++){ + pairs_sort.push_back(std::make_pair(preds[j], info.labels[j])); + } + acc += NDCG(pairs_sort); + } + return acc / (info.group_ptr.size() - 1); + } + + static float DCG(const std::vector &labels){ + float ans = 0.0; + for (int i = 0; i < labels.size(); i++){ + ans += (pow(2, labels[i]) - 1) / log(i + 2); + } + return ans; + } + + virtual const char *Name(void) const { + return "NDCG"; + } + + private: + float NDCG(std::vector> pairs_sort) const{ + std::sort(pairs_sort.begin(), pairs_sort.end(), std::greater()); + float dcg = DCG(pairs_sort); + std::sort(pairs_sort.begin(), pairs_sort.end(), std::greater()); + float IDCG = DCG(pairs_sort); + if (IDCG == 0) return 0; + return dcg / IDCG; + } + + float DCG(std::vector> pairs_sort) const{ + std::vector labels; + for (int i = 1; i < pairs_sort.size(); i++){ + labels.push_back(std::get<1>(pairs_sort[i])); + } + return DCG(labels); + } + }; + + + /*! \brief Mean Average Precision */ + class EvalMAP : public IEvaluator { + public: + virtual float Eval(const std::vector &preds, + const DMatrix::Info &info) const{ + if (info.group_ptr.size() <= 1) return 0; + float acc = 0; + std::vector> pairs_sort; + for (int i = 0; i < info.group_ptr.size() - 1; i++){ + for (int j = info.group_ptr[i]; j < info.group_ptr[i + 1]; j++){ + pairs_sort.push_back(std::make_pair(preds[j], info.labels[j])); + } + acc += average_precision(pairs_sort); + } + return acc / (info.group_ptr.size() - 1); + } + + virtual const char *Name(void) const { + return "MAP"; + } + + private: + float average_precision(std::vector> pairs_sort) const{ + std::sort(pairs_sort.begin(), pairs_sort.end(), std::greater()); + float hits = 0; + float average_precision = 0; + for (int j = 0; j < pairs_sort.size(); j++){ + if (std::get<1>(pairs_sort[j]) == 1){ + hits++; + average_precision += hits / (j + 1); + } + } + if (hits != 0) average_precision /= hits; + return average_precision; + } }; }; @@ -198,23 +284,23 @@ namespace xgboost{ struct EvalSet{ public: inline void AddEval(const char *name){ - for( size_t i = 0; i < evals_.size(); ++ i ){ - if(!strcmp(name, evals_[i]->Name())) return; + for (size_t i = 0; i < evals_.size(); ++i){ + if (!strcmp(name, evals_[i]->Name())) return; } - if (!strcmp(name, "rmse")) evals_.push_back( new EvalRMSE() ); - if (!strcmp(name, "error")) evals_.push_back( new EvalError() ); - if (!strcmp(name, "logloss")) evals_.push_back( new EvalLogLoss() ); - if (!strcmp( name, "auc")) evals_.push_back( new EvalAuc() ); - if (!strncmp( name, "pre@",4)) evals_.push_back( new EvalPrecision(name) ); + if (!strcmp(name, "rmse")) evals_.push_back(new EvalRMSE()); + if (!strcmp(name, "error")) evals_.push_back(new EvalError()); + if (!strcmp(name, "logloss")) evals_.push_back(new EvalLogLoss()); + if (!strcmp(name, "auc")) evals_.push_back(new EvalAuc()); + if (!strncmp(name, "pre@", 4)) evals_.push_back(new EvalPrecision(name)); } ~EvalSet(){ - for( size_t i = 0; i < evals_.size(); ++ i ){ + for (size_t i = 0; i < evals_.size(); ++i){ delete evals_[i]; } } inline void Eval(FILE *fo, const char *evname, - const std::vector &preds, - const DMatrix::Info &info ) const{ + const std::vector &preds, + const DMatrix::Info &info) const{ for (size_t i = 0; i < evals_.size(); ++i){ float res = evals_[i]->Eval(preds, info); fprintf(fo, "\t%s-%s:%f", evname, evals_[i]->Name(), res); diff --git a/regrank/xgboost_regrank_obj.hpp b/regrank/xgboost_regrank_obj.hpp index fb20f68a0..ab472eb3b 100644 --- a/regrank/xgboost_regrank_obj.hpp +++ b/regrank/xgboost_regrank_obj.hpp @@ -5,6 +5,10 @@ * \brief implementation of objective functions * \author Tianqi Chen, Kailong Chen */ +#include "xgboost_regrank_sample.h" +#include +#include +#include namespace xgboost{ namespace regrank{ class RegressionObj : public IObjFunction{ @@ -202,5 +206,208 @@ namespace xgboost{ LossType loss; }; }; + + namespace regrank{ + // simple pairwise rank + class LambdaRankObj : public IObjFunction{ + public: + LambdaRankObj(void){} + + virtual ~LambdaRankObj(){} + + virtual void SetParam(const char *name, const char *val){ + if (!strcmp("loss_type", name)) loss_.loss_type = atoi(val); + if (!strcmp("sampler", name)) sampler_.AssignSampler(atoi(val)); + if (!strcmp("lambda", name)) lambda_ = atoi(val); + } + + virtual void GetGradient(const std::vector& preds, + const DMatrix::Info &info, + int iter, + std::vector &grad, + std::vector &hess) { + grad.resize(preds.size()); hess.resize(preds.size()); + const std::vector &group_index = info.group_ptr; + utils::Assert(group_index.size() != 0 && group_index.back() == preds.size(), "rank loss must have group file"); + + for (int i = 0; i < group_index.size() - 1; i++){ + sample::Pairs pairs = sampler_.GenPairs(preds, info.labels, group_index[i], group_index[i + 1]); + //pairs.GetPairs() + std::vector< std::tuple > sorted_triple = GetSortedTuple(preds, info.labels, group_index, i); + std::vector index_remap = GetIndexMap(sorted_triple, group_index[i]); + GetGroupGradient(preds, info.labels, group_index, + grad, hess, sorted_triple, index_remap, pairs, i); + } + } + + virtual const char* DefaultEvalMetric(void) { + return "auc"; + } + + private: + /* \brief Sorted tuples of a group by the predictions, and + * the fields in the return tuples successively are predicions, + * labels, and the index of the instance + */ + inline std::vector< std::tuple > GetSortedTuple(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + int group){ + std::vector< std::tuple > sorted_triple; + for (int j = group_index[group]; j < group_index[group + 1]; j++){ + sorted_triple.push_back(std::tuple(preds[j], labels[j], j)); + } + std::sort(sorted_triple.begin(), sorted_triple.end(), + [](std::tuple a, std::tuple b){ + return std::get<0>(a) > std::get<0>(b); + }); + return sorted_triple; + } + + inline std::vector GetIndexMap(std::vector< std::tuple > sorted_triple, int start){ + std::vector index_remap; + index_remap.resize(sorted_triple.size()); + for (int i = 0; i < sorted_triple.size(); i++){ + index_remap[std::get<2>(sorted_triple[i]) - start] = i; + } + return index_remap; + } + + inline float GetLambdaMAP(const std::vector< std::tuple > sorted_triple, + int index1, int index2, + std::vector< std::tuple > map_acc){ + if (index1 > index2) std::swap(index1, index2); + float original = std::get<0>(map_acc[index2]); + if (index1 != 0) original -= std::get<0>(map_acc[index1 - 1]); + float changed = 0; + if (std::get<1>(sorted_triple[index1]) < std::get<1>(sorted_triple[index2])){ + changed += std::get<2>(map_acc[index2 - 1]) - std::get<2>(map_acc[index1]); + changed += (std::get<3>(map_acc[index1])+ 1.0f) / (index1 + 1); + } + else{ + changed += std::get<1>(map_acc[index2 - 1]) - std::get<1>(map_acc[index1]); + changed += std::get<3>(map_acc[index2]) / (index2 + 1); + } + float ans = (changed - original) / (std::get<3>(map_acc[map_acc.size() - 1])); + if (ans < 0) ans = -ans; + return ans; + } + + inline float GetLambdaNDCG(const std::vector< std::tuple > sorted_triple, + int index1, + int index2, float IDCG){ + float original = pow(2, std::get<1>(sorted_triple[index1])) / log(index1 + 2) + + pow(2, std::get<1>(sorted_triple[index2])) / log(index2 + 2); + float changed = pow(2, std::get<1>(sorted_triple[index2])) / log(index1 + 2) + + pow(2, std::get<1>(sorted_triple[index1])) / log(index2 + 2); + float ans = (original - changed) / IDCG; + if (ans < 0) ans = -ans; + return ans; + } + + + inline float GetIDCG(const std::vector< std::tuple > sorted_triple){ + std::vector labels; + for (int i = 0; i < sorted_triple.size(); i++){ + labels.push_back(std::get<1>(sorted_triple[i])); + } + + std::sort(labels.begin(), labels.end(), std::greater()); + return EvalNDCG::DCG(labels); + } + + inline std::vector< std::tuple > GetMAPAcc(const std::vector< std::tuple > sorted_triple){ + std::vector< std::tuple > map_acc; + float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0; + for (int i = 0; i < sorted_triple.size(); i++){ + if (std::get<1>(sorted_triple[i]) == 1) { + hit++; + acc1 += hit / (i + 1); + acc2 += (hit - 1) / (i + 1); + acc3 += (hit + 1) / (i + 1); + } + map_acc.push_back(std::make_tuple(acc1, acc2, acc3, hit)); + } + return map_acc; + + } + + inline void GetGroupGradient(const std::vector &preds, + const std::vector &labels, + const std::vector &group_index, + std::vector &grad, + std::vector &hess, + const std::vector< std::tuple > sorted_triple, + const std::vector index_remap, + const sample::Pairs& pairs, + int group){ + bool j_better; + float IDCG, pred_diff, pred_diff_exp, delta; + float first_order_gradient, second_order_gradient; + std::vector< std::tuple > map_acc; + + if (lambda_ == NDCG){ + IDCG = GetIDCG(sorted_triple); + } + else if (lambda_ == MAP){ + map_acc = GetMAPAcc(sorted_triple); + } + + for (int j = group_index[group]; j < group_index[group + 1]; j++){ + std::vector pair_instance = pairs.GetPairs(j); + for (int k = 0; k < pair_instance.size(); k++){ + j_better = labels[j] > labels[pair_instance[k]]; + if (j_better){ + switch (lambda_){ + case PAIRWISE: delta = 1.0; break; + case MAP: delta = GetLambdaMAP(sorted_triple, index_remap[j - group_index[group]], index_remap[pair_instance[k] - group_index[group]], map_acc); break; + case NDCG: delta = GetLambdaNDCG(sorted_triple, index_remap[j - group_index[group]], index_remap[pair_instance[k] - group_index[group]], IDCG); break; + default: utils::Error("Cannot find the specified loss type"); + } + + pred_diff = preds[preds[j] - pair_instance[k]]; + pred_diff_exp = j_better ? expf(-pred_diff) : expf(pred_diff); + first_order_gradient = delta * FirstOrderGradient(pred_diff_exp); + second_order_gradient = 2 * delta * SecondOrderGradient(pred_diff_exp); + hess[j] += second_order_gradient; + grad[j] += first_order_gradient; + hess[pair_instance[k]] += second_order_gradient; + grad[pair_instance[k]] += -first_order_gradient; + } + } + } + } + + /*! + * \brief calculate first order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return first order gradient + */ + inline float FirstOrderGradient(float pred_diff_exp) const { + return -pred_diff_exp / (1 + pred_diff_exp); + } + + /*! + * \brief calculate second order gradient of pairwise loss function(f(x) = ln(1+exp(-x)), + * given the exponential of the difference of intransformed pair predictions + * \param the intransformed prediction of positive instance + * \param the intransformed prediction of negative instance + * \return second order gradient + */ + inline float SecondOrderGradient(float pred_diff_exp) const { + return pred_diff_exp / pow(1 + pred_diff_exp, 2); + } + + private: + int lambda_; + const static int PAIRWISE = 0; + const static int MAP = 1; + const static int NDCG = 2; + sample::PairSamplerWrapper sampler_; + LossType loss_; + }; + }; }; #endif diff --git a/regrank/xgboost_regrank_sample.h b/regrank/xgboost_regrank_sample.h new file mode 100644 index 000000000..0413e51a3 --- /dev/null +++ b/regrank/xgboost_regrank_sample.h @@ -0,0 +1,129 @@ +#ifndef _XGBOOST_REGRANK_SAMPLE_H_ +#define _XGBOOST_REGRANK_SAMPLE_H_ +#include +#include"../utils/xgboost_utils.h" + +namespace xgboost { + namespace regrank { + namespace sample { + + /* + * \brief the data structure to maintain the sample pairs + * similar to the adjacency list of a graph + */ + struct Pairs { + + /* + * \brief constructor given the start and end offset of the sampling group + * in overall instances + * \param start the begin index of the group + * \param end the end index of the group + */ + Pairs(int start, int end) :start_(start), end_(end){ + for (int i = start; i < end; i++){ + std::vector v; + pairs_.push_back(v); + } + } + /* + * \brief retrieve the related pair information of an data instances + * \param index, the index of retrieved instance + * \return the index of instances paired + */ + std::vector GetPairs(int index) const{ + utils::Assert(index >= start_ && index < end_, "The query index out of sampling bound"); + return pairs_[index - start_]; + } + + /* + * \brief add in a sampled pair + * \param index the index of the instance to sample a friend + * \param paired_index the index of the instance sampled as a friend + */ + void push(int index, int paired_index){ + pairs_[index - start_].push_back(paired_index); + } + + std::vector< std::vector > pairs_; + int start_; + int end_; + }; + + /* + * \brief the interface of pair sampler + */ + struct IPairSampler { + /* + * \brief Generate sample pairs given the predcions, labels, the start and the end index + * of a specified group + * \param preds, the predictions of all data instances + * \param labels, the labels of all data instances + * \param start, the start index of a specified group + * \param end, the end index of a specified group + * \return the generated pairs + */ + virtual Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end) = 0; + + }; + + enum{ + BINARY_LINEAR_SAMPLER + }; + + /*! \brief A simple pair sampler when the rank relevence scale is binary + * for each positive instance, we will pick a negative + * instance and add in a pair. When using binary linear sampler, + * we should guarantee the labels are 0 or 1 + */ + struct BinaryLinearSampler :public IPairSampler{ + virtual Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end) { + Pairs pairs(start, end); + int pointer = 0, last_pointer = 0, index = start, interval = end - start; + for (int i = start; i < end; i++){ + if (labels[i] == 1){ + while (true){ + index = (++pointer) % interval + start; + if (labels[index] == 0) break; + if (pointer - last_pointer > interval) return pairs; + } + pairs.push(i, index); + pairs.push(index, i); + last_pointer = pointer; + } + } + return pairs; + } + }; + + + /*! \brief Pair Sampler Wrapper*/ + struct PairSamplerWrapper{ + public: + inline void AssignSampler(int sampler_index){ + + switch (sampler_index){ + case BINARY_LINEAR_SAMPLER:sampler_ = &binary_linear_sampler; break; + default:utils::Error("Cannot find the specified sampler"); + } + } + + ~PairSamplerWrapper(){ delete sampler_; } + + Pairs GenPairs(const std::vector &preds, + const std::vector &labels, + int start, int end){ + utils::Assert(sampler_ != NULL, "Not config the sampler yet. Add rank:sampler in the config file\n"); + return sampler_->GenPairs(preds, labels, start, end); + } + private: + BinaryLinearSampler binary_linear_sampler; + IPairSampler *sampler_; + }; + } + } +} +#endif \ No newline at end of file