In AUC and AUCPR metrics, detect whether weights are per-instance or per-group (#4216)

* In AUC and AUCPR metrics, detect whether weights are per-instance or per-group

* Fix C++ style check

* Add a test for weighted AUC
This commit is contained in:
Xin Yin
2019-05-04 03:53:05 -04:00
committed by Philip Hyunsu Cho
parent 9252b686ae
commit 8d1098a983
2 changed files with 144 additions and 14 deletions

View File

@@ -14,6 +14,59 @@
#include "../common/host_device_vector.h"
#include "../common/math.h"
namespace {
/*
* Adapter to access instance weights.
*
* - For ranking task, weights are per-group
* - For binary classification task, weights are per-instance
*
* WeightPolicy::GetWeightOfInstance() :
* get weight associated with an individual instance, using index into
* `info.weights`
* WeightPolicy::GetWeightOfSortedRecord() :
* get weight associated with an individual instance, using index into
* sorted records `rec` (in ascending order of predicted labels). `rec` is
* of type PredIndPairContainer
*/
// Container of (prediction, instance index) pairs: `first` holds the
// predicted value and `second` the index of the instance it belongs to.
// The metrics below fill one such container per group and sort it by
// prediction before scanning it.
using PredIndPairContainer
= std::vector<std::pair<xgboost::bst_float, unsigned>>;
/*!
 * \brief Weight policy used when `info` carries one weight per instance
 *        (the binary classification case).
 *
 * Both accessors ignore `group_id`; it is accepted only so that this
 * policy is call-compatible with PerGroupWeightPolicy.
 */
class PerInstanceWeightPolicy {
 public:
  /*!
   * \brief Weight of the instance identified by `instance_id`.
   * \param info        meta info queried through `GetWeight`
   * \param instance_id index of the instance
   * \param group_id    unused by this policy
   */
  static xgboost::bst_float
  GetWeightOfInstance(const xgboost::MetaInfo& info,
                      unsigned instance_id, unsigned group_id) {
    static_cast<void>(group_id);  // weights are keyed by instance here
    return info.GetWeight(instance_id);
  }

  /*!
   * \brief Weight of the instance referenced by the `record_id`-th entry
   *        of the sorted records `rec`.
   * \param info      meta info queried through `GetWeight`
   * \param rec       (prediction, instance index) pairs
   * \param record_id position inside `rec`
   * \param group_id  unused by this policy
   */
  static xgboost::bst_float
  GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
                          const PredIndPairContainer& rec,
                          unsigned record_id, unsigned group_id) {
    static_cast<void>(group_id);  // weights are keyed by instance here
    // The record stores the original instance index in `second`.
    const unsigned instance_id = rec[record_id].second;
    return info.GetWeight(instance_id);
  }
};
/*!
 * \brief Weight policy used when `info` carries one weight per group
 *        (the ranking case).
 *
 * Every instance of a group shares that group's weight, so both accessors
 * look the weight up by `group_id` and ignore the instance/record
 * arguments, which exist only to keep the interface identical to
 * PerInstanceWeightPolicy.
 */
class PerGroupWeightPolicy {
 public:
  /*!
   * \brief Weight of an instance, taken from the group it belongs to.
   * \param info        meta info queried through `GetWeight`
   * \param instance_id unused by this policy
   * \param group_id    index of the group containing the instance
   */
  static xgboost::bst_float
  GetWeightOfInstance(const xgboost::MetaInfo& info,
                      unsigned instance_id, unsigned group_id) {
    static_cast<void>(instance_id);  // weights are keyed by group here
    return info.GetWeight(group_id);
  }

  /*!
   * \brief Weight of a sorted record, taken from the group it belongs to.
   * \param info      meta info queried through `GetWeight`
   * \param rec       unused by this policy
   * \param record_id unused by this policy
   * \param group_id  index of the group containing the record
   */
  static xgboost::bst_float
  GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
                          const PredIndPairContainer& rec,
                          unsigned record_id, unsigned group_id) {
    static_cast<void>(rec);        // weights are keyed by group here
    static_cast<void>(record_id);
    return info.GetWeight(group_id);
  }
};
} // anonymous namespace
namespace xgboost {
namespace metric {
// tag this file, used by force static link later.
@@ -88,16 +141,18 @@ struct EvalAMS : public Metric {
/*! \brief Area Under Curve, for both classification and rank */
struct EvalAuc : public Metric {
private:
template <typename WeightPolicy>
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
bool distributed) {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label size predict size not match";
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
const std::vector<unsigned> &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
CHECK_EQ(gptr.back(), info.labels_.Size())
<< "EvalAuc: group structure must match number of prediction";
const auto ngroup = static_cast<bst_omp_uint>(gptr.size() - 1);
@@ -108,9 +163,9 @@ struct EvalAuc : public Metric {
std::vector<std::pair<bst_float, unsigned>> rec;
const auto& labels = info.labels_.HostVector();
const std::vector<bst_float>& h_preds = preds.HostVector();
for (bst_omp_uint k = 0; k < ngroup; ++k) {
for (bst_omp_uint group_id = 0; group_id < ngroup; ++group_id) {
rec.clear();
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
@@ -118,7 +173,8 @@ struct EvalAuc : public Metric {
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
for (size_t j = 0; j < rec.size(); ++j) {
const bst_float wt = info.GetWeight(rec[j].second);
const bst_float wt
= WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id);
const bst_float ctr = labels[rec[j].second];
// keep bucketing predictions in same bucket
if (j != 0 && rec[j].first != rec[j - 1].first) {
@@ -154,6 +210,21 @@ struct EvalAuc : public Metric {
return static_cast<bst_float>(sum_auc) / ngroup;
}
}
public:
/*!
 * \brief Evaluate AUC, picking the weight interpretation from `info`.
 *
 * Weights are treated as per-group (ranking task) only when group
 * boundaries are present and the number of weights differs from the
 * number of rows; in every other case they are treated as per-instance
 * (binary classification task).
 *
 * \param preds       predictions, forwarded unchanged
 * \param info        meta info providing group pointers and weights
 * \param distributed forwarded unchanged to the templated implementation
 * \return the value computed by the selected templated `Eval`
 */
bst_float Eval(const HostDeviceVector<bst_float> &preds,
               const MetaInfo &info,
               bool distributed) override {
  // No group structure, or exactly one weight per row: per-instance weights.
  if (info.group_ptr_.empty() || info.weights_.Size() == info.num_row_) {
    return Eval<PerInstanceWeightPolicy>(preds, info, distributed);
  }
  // Otherwise this is a ranking task and weights are per-group.
  return Eval<PerGroupWeightPolicy>(preds, info, distributed);
}
const char* Name() const override {
return "auc";
}
@@ -370,9 +441,11 @@ struct EvalAucPR : public Metric {
// implementation of AUC-PR for weighted data
// translated from PRROC R Package
// see https://doi.org/10.1371/journal.pone.0092209
bst_float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
private:
template <typename WeightPolicy>
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label size predict size not match";
@@ -391,13 +464,15 @@ struct EvalAucPR : public Metric {
const auto& h_labels = info.labels_.HostVector();
const std::vector<bst_float>& h_preds = preds.HostVector();
for (bst_omp_uint k = 0; k < ngroup; ++k) {
for (bst_omp_uint group_id = 0; group_id < ngroup; ++group_id) {
double total_pos = 0.0;
double total_neg = 0.0;
rec.clear();
for (unsigned j = gptr[k]; j < gptr[k + 1]; ++j) {
total_pos += info.GetWeight(j) * h_labels[j];
total_neg += info.GetWeight(j) * (1.0f - h_labels[j]);
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
const bst_float wt
= WeightPolicy::GetWeightOfInstance(info, j, group_id);
total_pos += wt * h_labels[j];
total_neg += wt * (1.0f - h_labels[j]);
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
@@ -408,8 +483,10 @@ struct EvalAucPR : public Metric {
// calculate AUC
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
for (size_t j = 0; j < rec.size(); ++j) {
tp += info.GetWeight(rec[j].second) * h_labels[rec[j].second];
fp += info.GetWeight(rec[j].second) * (1.0f - h_labels[rec[j].second]);
const bst_float wt
= WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id);
tp += wt * h_labels[rec[j].second];
fp += wt * (1.0f - h_labels[rec[j].second]);
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) || j == rec.size() - 1) {
if (tp == prevtp) {
a = 1.0;
@@ -449,6 +526,21 @@ struct EvalAucPR : public Metric {
return static_cast<bst_float>(sum_auc) / ngroup;
}
}
public:
/*!
 * \brief Evaluate AUC-PR, picking the weight interpretation from `info`.
 *
 * Weights are treated as per-group (ranking task) only when group
 * boundaries are present and the number of weights differs from the
 * number of rows; in every other case they are treated as per-instance
 * (binary classification task).
 *
 * \param preds       predictions, forwarded unchanged
 * \param info        meta info providing group pointers and weights
 * \param distributed forwarded unchanged to the templated implementation
 * \return the value computed by the selected templated `Eval`
 */
bst_float Eval(const HostDeviceVector<bst_float> &preds,
               const MetaInfo &info,
               bool distributed) override {
  // No group structure, or exactly one weight per row: per-instance weights.
  if (info.group_ptr_.empty() || info.weights_.Size() == info.num_row_) {
    return Eval<PerInstanceWeightPolicy>(preds, info, distributed);
  }
  // Otherwise this is a ranking task and weights are per-group.
  return Eval<PerGroupWeightPolicy>(preds, info, distributed);
}
const char *Name() const override { return "aucpr"; }
};