Merge branch 'unity' of ssh://github.com/tqchen/xgboost into unity

This commit is contained in:
tqchen 2014-08-22 16:26:45 -07:00
commit 4ed67b9c27
3 changed files with 24 additions and 7 deletions

View File

@ -159,7 +159,12 @@ struct EvalAMS : public IEvaluator {
struct EvalPrecisionRatio : public IEvaluator{
public:
explicit EvalPrecisionRatio(const char *name) : name_(name) {
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
if (sscanf(name, "apratio@%f", &ratio_) == 1) {
use_ap = 1;
} else {
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
use_ap = 0;
}
}
virtual float Eval(const std::vector<float> &preds,
const MetaInfo &info) const {
@ -179,13 +184,20 @@ struct EvalPrecisionRatio : public IEvaluator{
protected:
inline double CalcPRatio(const std::vector< std::pair<float, unsigned> >& rec, const MetaInfo &info) const {
size_t cutoff = static_cast<size_t>(ratio_ * rec.size());
double wt_hit = 0.0, wsum = 0.0;
double wt_hit = 0.0, wsum = 0.0, wt_sum = 0.0;
for (size_t j = 0; j < cutoff; ++j) {
wt_hit += info.labels[rec[j].second];
wsum += wt_hit / (j + 1);
}
return wsum / cutoff;
const float wt = info.GetWeight(j);
wt_hit += info.labels[rec[j].second] * wt;
wt_sum += wt;
wsum += wt_hit / wt_sum;
}
if (use_ap != 0) {
return wsum / cutoff;
} else {
return wt_hit / wt_sum;
}
}
int use_ap;
float ratio_;
std::string name_;
};

View File

@ -8,6 +8,7 @@
#include <string>
#include <vector>
#include "../utils/utils.h"
#include "./dmatrix.h"
namespace xgboost {
namespace learner {
@ -41,7 +42,8 @@ inline IEvaluator* CreateEvaluator(const char *name) {
if (!strcmp(name, "auc")) return new EvalAuc();
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
if (!strncmp(name, "pratio@", 4)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "apratio@", 8)) return new EvalPrecisionRatio(name);
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
utils::Error("unknown evaluation metric type: %s", name);

View File

@ -84,6 +84,9 @@ class BoostLearner {
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
if (!strcmp("seed", name)) random::Seed(atoi(val));
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
if (!strcmp(name, "nthread")) {
omp_set_num_threads(atoi(val));
}
if (gbm_ == NULL) {
if (!strcmp(name, "objective")) name_obj_ = val;
if (!strcmp(name, "booster")) name_gbm_ = val;