- ndcg ltr implementation on gpu (#5004)

* - ndcg ltr implementation on gpu
  - this is a follow-up to the pairwise ltr implementation
This commit is contained in:
sriramch
2019-11-12 14:21:04 -08:00
committed by Rory Mitchell
parent f4e7b707c9
commit 2abe69d774
5 changed files with 780 additions and 202 deletions

View File

@@ -98,7 +98,7 @@ struct EvalAMS : public Metric {
for (bst_omp_uint i = 0; i < ndata; ++i) {
rec[i] = std::make_pair(h_preds[i], i);
}
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
auto ntop = static_cast<unsigned>(ratio_ * ndata);
if (ntop == 0) ntop = ndata;
const double br = 10.0;
@@ -168,7 +168,7 @@ struct EvalAuc : public Metric {
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// calculate AUC
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
@@ -321,7 +321,7 @@ struct EvalPrecision : public EvalRankList{
protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
// calculate Precision
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
nhit += (rec[j].second != 0);
@@ -369,7 +369,7 @@ struct EvalMAP : public EvalRankList {
protected:
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
std::sort(rec.begin(), rec.end(), common::CmpFirst);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
@@ -481,7 +481,7 @@ struct EvalAucPR : public Metric {
total_neg += wt * (1.0f - h_labels[j]);
rec.emplace_back(h_preds[j], j);
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
// we need pos > 0 && neg > 0
if (0.0 == total_pos || 0.0 == total_neg) {
auc_error += 1;