Merge branch 'unity' of ssh://github.com/tqchen/xgboost into unity
This commit is contained in:
commit
4ed67b9c27
@ -159,7 +159,12 @@ struct EvalAMS : public IEvaluator {
|
||||
struct EvalPrecisionRatio : public IEvaluator{
|
||||
public:
|
||||
explicit EvalPrecisionRatio(const char *name) : name_(name) {
|
||||
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
|
||||
if (sscanf(name, "apratio@%f", &ratio_) == 1) {
|
||||
use_ap = 1;
|
||||
} else {
|
||||
utils::Assert(sscanf(name, "pratio@%f", &ratio_) == 1, "BUG");
|
||||
use_ap = 0;
|
||||
}
|
||||
}
|
||||
virtual float Eval(const std::vector<float> &preds,
|
||||
const MetaInfo &info) const {
|
||||
@ -179,13 +184,20 @@ struct EvalPrecisionRatio : public IEvaluator{
|
||||
protected:
|
||||
inline double CalcPRatio(const std::vector< std::pair<float, unsigned> >& rec, const MetaInfo &info) const {
|
||||
size_t cutoff = static_cast<size_t>(ratio_ * rec.size());
|
||||
double wt_hit = 0.0, wsum = 0.0;
|
||||
double wt_hit = 0.0, wsum = 0.0, wt_sum = 0.0;
|
||||
for (size_t j = 0; j < cutoff; ++j) {
|
||||
wt_hit += info.labels[rec[j].second];
|
||||
wsum += wt_hit / (j + 1);
|
||||
}
|
||||
return wsum / cutoff;
|
||||
const float wt = info.GetWeight(j);
|
||||
wt_hit += info.labels[rec[j].second] * wt;
|
||||
wt_sum += wt;
|
||||
wsum += wt_hit / wt_sum;
|
||||
}
|
||||
if (use_ap != 0) {
|
||||
return wsum / cutoff;
|
||||
} else {
|
||||
return wt_hit / wt_sum;
|
||||
}
|
||||
}
|
||||
int use_ap;
|
||||
float ratio_;
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "../utils/utils.h"
|
||||
#include "./dmatrix.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace learner {
|
||||
@ -41,7 +42,8 @@ inline IEvaluator* CreateEvaluator(const char *name) {
|
||||
if (!strcmp(name, "auc")) return new EvalAuc();
|
||||
if (!strncmp(name, "ams@", 4)) return new EvalAMS(name);
|
||||
if (!strncmp(name, "pre@", 4)) return new EvalPrecision(name);
|
||||
if (!strncmp(name, "pratio@", 4)) return new EvalPrecisionRatio(name);
|
||||
if (!strncmp(name, "pratio@", 7)) return new EvalPrecisionRatio(name);
|
||||
if (!strncmp(name, "apratio@", 8)) return new EvalPrecisionRatio(name);
|
||||
if (!strncmp(name, "map", 3)) return new EvalMAP(name);
|
||||
if (!strncmp(name, "ndcg", 3)) return new EvalNDCG(name);
|
||||
utils::Error("unknown evaluation metric type: %s", name);
|
||||
|
||||
@ -84,6 +84,9 @@ class BoostLearner {
|
||||
if (!strcmp(name, "eval_metric")) evaluator_.AddEval(val);
|
||||
if (!strcmp("seed", name)) random::Seed(atoi(val));
|
||||
if (!strcmp(name, "num_class")) this->SetParam("num_output_group", val);
|
||||
if (!strcmp(name, "nthread")) {
|
||||
omp_set_num_threads(atoi(val));
|
||||
}
|
||||
if (gbm_ == NULL) {
|
||||
if (!strcmp(name, "objective")) name_obj_ = val;
|
||||
if (!strcmp(name, "booster")) name_gbm_ = val;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user