From fdba6e9c46dd348ce1a372e35c123df22b57959c Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 19 Aug 2014 08:02:29 -0700 Subject: [PATCH] add pratio --- src/learner/evaluation-inl.hpp | 35 ++++++++++++++++++++++++++++++++++ src/learner/evaluation.h | 1 + 2 files changed, 36 insertions(+) diff --git a/src/learner/evaluation-inl.hpp b/src/learner/evaluation-inl.hpp index 43fe48726..17a0d5589 100644 --- a/src/learner/evaluation-inl.hpp +++ b/src/learner/evaluation-inl.hpp @@ -155,6 +155,41 @@ struct EvalAMS : public IEvaluator { float ratio_; }; +/*! \brief precision with cut off at top percentile */ +struct EvalPrecisionRatio : public IEvaluator{ + public: + EvalPrecisionRatio( const char *name ) : name_(name) { + utils::Assert(sscanf( name, "apratio@%f", &ratio_) == 1, "BUG"); + } + virtual float Eval(const std::vector &preds, + const MetaInfo &info) const { + utils::Assert(preds.size() == info.labels.size(), "label size predict size not match"); + std::vector< std::pair > rec; + for (size_t j = 0; j < preds.size(); ++j) { + rec.push_back(std::make_pair(preds[j], j)); + } + std::sort(rec.begin(), rec.end(), CmpFirst); + double pratio = CalcPRatio( rec, info ); + return static_cast(pratio); + } + virtual const char *Name(void) const{ + return name_.c_str(); + } + protected: + inline double CalcPRatio(const std::vector< std::pair >& rec, const MetaInfo &info) const{ + size_t cutoff = static_cast(ratio_ * rec.size()); + double wt_hit = 0.0, wsum = 0.0; + for (size_t j = 0; j < cutoff; ++j) { + wt_hit += info.labels[rec[j].second]; + wsum += wt_hit / j; + } + return wsum / cutoff; + } + protected: + float ratio_; + std::string name_; +}; + /*! \brief Area under curve, for both classification and rank */ struct EvalAuc : public IEvaluator { virtual float Eval(const std::vector &preds, diff --git a/src/learner/evaluation.h b/src/learner/evaluation.h index fa25aa7d7..79ad4902e 100644 --- a/src/learner/evaluation.h +++ b/src/learner/evaluation.h @@ -41,6 +41,7 @@ inline IEvaluator* CreateEvaluator(const char *name) { if (!strcmp(name, "auc")) return new EvalAuc(); if (!strncmp(name, "ams@", 4)) return new EvalAMS(name); if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name); + if (!strncmp(name, "pratio@", 4)) return new EvalPrecisionRatio(name); if (!strncmp(name, "map", 3)) return new EvalMAP(name); if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name); utils::Error("unknown evaluation metric type: %s", name);