Make AUCPR work with multiple query groups (#4436)

* Make AUCPR work with multiple query groups

* Check AUCPR <= 1.0 in distributed setting
This commit is contained in:
Philip Hyunsu Cho
2019-05-03 10:34:44 -07:00
committed by GitHub
parent 2be85fc62a
commit 9252b686ae
2 changed files with 45 additions and 20 deletions

View File

@@ -101,11 +101,11 @@ struct EvalAuc : public Metric {
CHECK_EQ(gptr.back(), info.labels_.Size())
<< "EvalAuc: group structure must match number of prediction";
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum statistics
bst_float sum_auc = 0.0f;
// running sum of the per-group AUC values over all query groups
double sum_auc = 0.0;
int auc_error = 0;
// each thread takes a local rec
std::vector< std::pair<bst_float, unsigned> > rec;
std::vector<std::pair<bst_float, unsigned>> rec;
const auto& labels = info.labels_.HostVector();
const std::vector<bst_float>& h_preds = preds.HostVector();
for (bst_omp_uint k = 0; k < ngroup; ++k) {
@@ -130,7 +130,7 @@ struct EvalAuc : public Metric {
buf_pos += ctr * wt;
buf_neg += (1.0f - ctr) * wt;
}
sum_pospair += buf_neg * (sum_npos + buf_pos *0.5);
sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5);
sum_npos += buf_pos;
sum_nneg += buf_neg;
// check weird conditions
@@ -139,15 +139,15 @@ struct EvalAuc : public Metric {
continue;
}
// this is the AUC
sum_auc += sum_pospair / (sum_npos*sum_nneg);
sum_auc += sum_pospair / (sum_npos * sum_nneg);
}
CHECK(!auc_error)
<< "AUC: the dataset only contains pos or neg samples";
/* Report average AUC across all groups */
if (distributed) {
bst_float dat[2];
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup);
// approximate the overall AUC as a mean: Allreduce sums both the AUC total and the group count, then divide
rabit::Allreduce<rabit::op::Sum>(dat, 2);
return dat[0] / dat[1];
} else {
@@ -383,9 +383,9 @@ struct EvalAucPR : public Metric {
CHECK_EQ(gptr.back(), info.labels_.Size())
<< "EvalAucPR: group structure must match number of prediction";
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum statistics
double auc = 0.0;
int auc_error = 0, auc_gt_one = 0;
// running sum of the per-group AUC values over all query groups
double sum_auc = 0.0;
int auc_error = 0;
// each thread takes a local rec
std::vector<std::pair<bst_float, unsigned>> rec;
const auto& h_labels = info.labels_.HostVector();
@@ -420,14 +420,11 @@ struct EvalAucPR : public Metric {
b = (prevfp - h * prevtp) / total_pos;
}
if (0.0 != b) {
auc += (tp / total_pos - prevtp / total_pos -
b / a * (std::log(a * tp / total_pos + b) -
std::log(a * prevtp / total_pos + b))) / a;
sum_auc += (tp / total_pos - prevtp / total_pos -
b / a * (std::log(a * tp / total_pos + b) -
std::log(a * prevtp / total_pos + b))) / a;
} else {
auc += (tp / total_pos - prevtp / total_pos) / a;
}
if (auc > 1.0) {
auc_gt_one = 1;
sum_auc += (tp / total_pos - prevtp / total_pos) / a;
}
prevtp = tp;
prevfp = fp;
@@ -439,16 +436,17 @@ struct EvalAucPR : public Metric {
}
}
CHECK(!auc_error) << "AUC-PR: the dataset only contains pos or neg samples";
CHECK(!auc_gt_one) << "AUC-PR: AUC > 1.0";
/* Report average AUC across all groups */
if (distributed) {
bst_float dat[2];
dat[0] = static_cast<bst_float>(auc);
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(ngroup);
// approximate the overall AUC as a mean: Allreduce sums both the AUC total and the group count, then divide
rabit::Allreduce<rabit::op::Sum>(dat, 2);
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
} else {
return static_cast<bst_float>(auc) / ngroup;
CHECK_LE(sum_auc, static_cast<double>(ngroup)) << "AUC-PR: AUC > 1.0";
return static_cast<bst_float>(sum_auc) / ngroup;
}
}
const char *Name() const override { return "aucpr"; }