- ndcg ltr implementation on gpu (#5004)
* - ndcg ltr implementation on gpu - this is a follow-up to the pairwise ltr implementation
This commit is contained in:
@@ -98,7 +98,7 @@ struct EvalAMS : public Metric {
|
||||
for (bst_omp_uint i = 0; i < ndata; ++i) {
|
||||
rec[i] = std::make_pair(h_preds[i], i);
|
||||
}
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
auto ntop = static_cast<unsigned>(ratio_ * ndata);
|
||||
if (ntop == 0) ntop = ndata;
|
||||
const double br = 10.0;
|
||||
@@ -168,7 +168,7 @@ struct EvalAuc : public Metric {
|
||||
for (unsigned j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// calculate AUC
|
||||
double sum_pospair = 0.0;
|
||||
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
|
||||
@@ -321,7 +321,7 @@ struct EvalPrecision : public EvalRankList{
|
||||
protected:
|
||||
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
|
||||
// calculate Precision
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhit = 0;
|
||||
for (size_t j = 0; j < rec.size() && j < this->topn_; ++j) {
|
||||
nhit += (rec[j].second != 0);
|
||||
@@ -369,7 +369,7 @@ struct EvalMAP : public EvalRankList {
|
||||
|
||||
protected:
|
||||
bst_float EvalMetric(std::vector< std::pair<bst_float, unsigned> > &rec) const override {
|
||||
std::sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhits = 0;
|
||||
double sumap = 0.0;
|
||||
for (size_t i = 0; i < rec.size(); ++i) {
|
||||
@@ -481,7 +481,7 @@ struct EvalAucPR : public Metric {
|
||||
total_neg += wt * (1.0f - h_labels[j]);
|
||||
rec.emplace_back(h_preds[j], j);
|
||||
}
|
||||
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
XGBOOST_PARALLEL_STABLE_SORT(rec.begin(), rec.end(), common::CmpFirst);
|
||||
// we need pos > 0 && neg > 0
|
||||
if (0.0 == total_pos || 0.0 == total_neg) {
|
||||
auc_error += 1;
|
||||
|
||||
Reference in New Issue
Block a user