add apratio
This commit is contained in:
parent
762b360739
commit
3f5b5e1fdc
@ -159,7 +159,12 @@ struct EvalAMS : public IEvaluator {
|
|||||||
struct EvalPrecisionRatio : public IEvaluator{
|
struct EvalPrecisionRatio : public IEvaluator{
|
||||||
public:
|
public:
|
||||||
explicit EvalPrecisionRatio(const char *name) : name_(name) {
|
explicit EvalPrecisionRatio(const char *name) : name_(name) {
|
||||||
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
|
if (sscanf(name, "apratio@%f", &ratio_) == 1) {
|
||||||
|
use_ap = 1;
|
||||||
|
} else {
|
||||||
|
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
|
||||||
|
use_ap = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
virtual float Eval(const std::vector<float> &preds,
|
virtual float Eval(const std::vector<float> &preds,
|
||||||
const MetaInfo &info) const {
|
const MetaInfo &info) const {
|
||||||
@ -179,13 +184,20 @@ struct EvalPrecisionRatio : public IEvaluator{
|
|||||||
protected:
|
protected:
|
||||||
inline double CalcPRatio(const std::vector< std::pair<float, unsigned> >& rec, const MetaInfo &info) const {
|
inline double CalcPRatio(const std::vector< std::pair<float, unsigned> >& rec, const MetaInfo &info) const {
|
||||||
size_t cutoff = static_cast<size_t>(ratio_ * rec.size());
|
size_t cutoff = static_cast<size_t>(ratio_ * rec.size());
|
||||||
double wt_hit = 0.0, wsum = 0.0;
|
double wt_hit = 0.0, wsum = 0.0, wt_sum = 0.0;
|
||||||
for (size_t j = 0; j < cutoff; ++j) {
|
for (size_t j = 0; j < cutoff; ++j) {
|
||||||
wt_hit += info.labels[rec[j].second];
|
const float wt = info.GetWeight(j);
|
||||||
wsum += wt_hit / (j + 1);
|
wt_hit += info.labels[rec[j].second] * wt;
|
||||||
}
|
wt_sum += wt;
|
||||||
return wsum / cutoff;
|
wsum += wt_hit / wt_sum;
|
||||||
|
}
|
||||||
|
if (use_ap != 0) {
|
||||||
|
return wsum / cutoff;
|
||||||
|
} else {
|
||||||
|
return wt_hit / wt_sum;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
int use_ap;
|
||||||
float ratio_;
|
float ratio_;
|
||||||
std::string name_;
|
std::string name_;
|
||||||
};
|
};
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "../utils/utils.h"
|
#include "../utils/utils.h"
|
||||||
|
#include "./dmatrix.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace learner {
|
namespace learner {
|
||||||
@ -41,7 +42,8 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
|||||||
if (!strcmp(name, "auc")) return new EvalAuc();
|
if (!strcmp(name, "auc")) return new EvalAuc();
|
||||||
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
||||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||||
if (!strncmp(name, "pratio@", 4)) return new EvalPrecisionRatio(name);
|
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
|
||||||
|
if (!strncmp(name, "apratio@", 8)) return new EvalPrecisionRatio(name);
|
||||||
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
||||||
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
||||||
utils::Error("unknown evaluation metric type: %s", name);
|
utils::Error("unknown evaluation metric type: %s", name);
|
||||||
|
|||||||
@ -80,6 +80,9 @@ class BoostLearner {
|
|||||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||||
|
if (!strcmp(name, "nthread")) {
|
||||||
|
omp_set_num_threads(atoi(val));
|
||||||
|
}
|
||||||
if (gbm_ == NULL) {
|
if (gbm_ == NULL) {
|
||||||
if (!strcmp(name, "objective")) name_obj_ = val;
|
if (!strcmp(name, "objective")) name_obj_ = val;
|
||||||
if (!strcmp(name, "booster")) name_gbm_ = val;
|
if (!strcmp(name, "booster")) name_gbm_ = val;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user