beta version, do a review
This commit is contained in:
parent
ce97f2fdf8
commit
f62b4a02f9
@ -100,6 +100,45 @@ struct EvalMatchError : public EvalEWiseBase<EvalMatchError> {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief ctest */
|
||||
struct EvalCTest: public IEvaluator {
|
||||
EvalCTest(IEvaluator *base, const char *name)
|
||||
: base_(base), name_(name) {}
|
||||
virtual ~EvalCTest(void) {
|
||||
delete base_;
|
||||
}
|
||||
virtual const char *Name(void) const {
|
||||
return name_.c_str();
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
utils::Check(preds.size() % info.labels.size() == 0,
|
||||
"label and prediction size not match");
|
||||
size_t ngroup = preds.size() / info.labels.size() - 1;
|
||||
const unsigned ndata = static_cast<unsigned>(info.labels.size());
|
||||
utils::Check(ngroup > 1, "pred size does not meet requirement");
|
||||
utils::Check(ndata == info.info.fold_index.size(), "need fold index");
|
||||
double wsum = 0.0;
|
||||
for (size_t k = 0; k < ngroup; ++k) {
|
||||
std::vector<float> tpred;
|
||||
MetaInfo tinfo;
|
||||
for (unsigned i = 0; i < ndata; ++i) {
|
||||
if (info.info.fold_index[i] == k) {
|
||||
tpred.push_back(preds[i + (k + 1) * ndata]);
|
||||
tinfo.labels.push_back(info.labels[i]);
|
||||
tinfo.weights.push_back(info.GetWeight(i));
|
||||
}
|
||||
}
|
||||
wsum += base_->Eval(tpred, tinfo);
|
||||
}
|
||||
return wsum / ngroup;
|
||||
}
|
||||
|
||||
private:
|
||||
IEvaluator *base_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
/*! \brief AMS: also records best threshold */
|
||||
struct EvalAMS : public IEvaluator {
|
||||
public:
|
||||
|
||||
@ -44,7 +44,9 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
|
||||
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
||||
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
||||
if (!strncmp(name, "ndcg", 4)) return new EvalNDCG(name);
|
||||
if (!strncmp(name, "ct-", 3)) return new EvalCTest(CreateEvaluator(name+3), name);
|
||||
|
||||
utils::Error("unknown evaluation metric type: %s", name);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user