diff --git a/src/objective/rank_obj.cc b/src/objective/rank_obj.cc index 65a01d759..ed18f13c0 100644 --- a/src/objective/rank_obj.cc +++ b/src/objective/rank_obj.cc @@ -37,6 +37,7 @@ class LambdaRankObj : public ObjFunction { void Configure(const std::vector >& args) override { param_.InitAllowUnknown(args); } + void GetGradient(HostDeviceVector* preds, const MetaInfo& info, int iter, @@ -50,6 +51,7 @@ class LambdaRankObj : public ObjFunction { const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; CHECK(gptr.size() != 0 && gptr.back() == info.labels_.size()) << "group structure not consistent with #rows"; + const auto ngroup = static_cast(gptr.size() - 1); #pragma omp parallel { @@ -60,6 +62,11 @@ class LambdaRankObj : public ObjFunction { std::vector pairs; std::vector lst; std::vector< std::pair > rec; + bst_float sum_weights = 0; + for (bst_omp_uint k = 0; k < ngroup; ++k) { + sum_weights += info.GetWeight(k); + } + bst_float weight_normalization_factor = ngroup/sum_weights; #pragma omp for schedule(static) for (bst_omp_uint k = 0; k < ngroup; ++k) { lst.clear(); pairs.clear(); @@ -85,9 +92,11 @@ class LambdaRankObj : public ObjFunction { for (unsigned pid = i; pid < j; ++pid) { unsigned ridx = std::uniform_int_distribution(0, nleft + nright - 1)(rnd); if (ridx < nleft) { - pairs.emplace_back(rec[ridx].second, rec[pid].second); + pairs.emplace_back(rec[ridx].second, rec[pid].second, + info.GetWeight(k) * weight_normalization_factor); } else { - pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second); + pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second, + info.GetWeight(k) * weight_normalization_factor); } } } @@ -152,6 +161,9 @@ class LambdaRankObj : public ObjFunction { // constructor LambdaPair(unsigned pos_index, unsigned neg_index) : pos_index(pos_index), neg_index(neg_index), weight(1.0f) {} + // constructor + LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight) + : pos_index(pos_index), neg_index(neg_index), weight(weight) {} }; /*! * \brief get lambda weight for existing pairs @@ -205,7 +217,7 @@ class LambdaRankObjNDCG : public LambdaRankObj { ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv; bst_float delta = (original - changed) * IDCG; if (delta < 0.0f) delta = - delta; - pair.weight = delta; + pair.weight *= delta; } } } @@ -301,7 +313,7 @@ class LambdaRankObjMAP : public LambdaRankObj { std::vector map_stats; GetMAPStats(sorted_list, &map_stats); for (auto & pair : pairs) { - pair.weight = + pair.weight *= GetLambdaMAP(sorted_list, pair.pos_index, pair.neg_index, &map_stats); } diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc index d2d8dacd5..2c8526c1c 100644 --- a/tests/cpp/helpers.cc +++ b/tests/cpp/helpers.cc @@ -41,17 +41,13 @@ std::string CreateBigTestData(size_t n_entries) { return tmp_file; } -void CheckObjFunction(xgboost::ObjFunction * obj, +void _CheckObjFunction(xgboost::ObjFunction * obj, std::vector preds, std::vector labels, std::vector weights, + xgboost::MetaInfo info, std::vector out_grad, std::vector out_hess) { - xgboost::MetaInfo info; - info.num_row_ = labels.size(); - info.labels_ = labels; - info.weights_ = weights; - xgboost::HostDeviceVector in_preds(preds); xgboost::HostDeviceVector out_gpair; @@ -69,6 +65,37 @@ void CheckObjFunction(xgboost::ObjFunction * obj, } } +void CheckObjFunction(xgboost::ObjFunction * obj, + std::vector preds, + std::vector labels, + std::vector weights, + std::vector out_grad, + std::vector out_hess) { + xgboost::MetaInfo info; + info.num_row_ = labels.size(); + info.labels_ = labels; + info.weights_ = weights; + + _CheckObjFunction(obj, preds, labels, weights, info, out_grad, out_hess); +} + +void CheckRankingObjFunction(xgboost::ObjFunction * obj, + std::vector preds, + std::vector labels, + std::vector weights, + std::vector groups, + std::vector out_grad, + std::vector out_hess) { + xgboost::MetaInfo info; + info.num_row_ = labels.size(); + info.labels_ = labels; + info.weights_ = weights; + info.group_ptr_ = groups; + + _CheckObjFunction(obj, preds, labels, weights, info, out_grad, out_hess); +} + + xgboost::bst_float GetMetricEval(xgboost::Metric * metric, std::vector preds, std::vector labels, diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index b3fcebfb3..411b916de 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -32,6 +32,14 @@ void CheckObjFunction(xgboost::ObjFunction * obj, std::vector out_grad, std::vector out_hess); +void CheckRankingObjFunction(xgboost::ObjFunction * obj, + std::vector preds, + std::vector labels, + std::vector weights, + std::vector groups, + std::vector out_grad, + std::vector out_hess); + xgboost::bst_float GetMetricEval( xgboost::Metric * metric, std::vector preds, diff --git a/tests/cpp/objective/test_multiclass_metric.cc b/tests/cpp/metric/test_multiclass_metric.cc similarity index 100% rename from tests/cpp/objective/test_multiclass_metric.cc rename to tests/cpp/metric/test_multiclass_metric.cc diff --git a/tests/cpp/objective/test_ranking_obj.cc b/tests/cpp/objective/test_ranking_obj.cc new file mode 100644 index 000000000..8dc136648 --- /dev/null +++ b/tests/cpp/objective/test_ranking_obj.cc @@ -0,0 +1,28 @@ +// Copyright by Contributors +#include + +#include "../helpers.h" + +TEST(Objective, PairwiseRankingGPair) { + xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("rank:pairwise"); + std::vector > args; + obj->Configure(args); + // Test with setting sample weight to second query group + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {2.0f, 0.0f}, + {0, 2, 4}, + {1.9f, -1.9f, 0.0f, 0.0f}, + {1.995f, 1.995f, 0.0f, 0.0f}); + + CheckRankingObjFunction(obj, + {0, 0.1f, 0, 0.1f}, + {0, 1, 0, 1}, + {1.0f, 1.0f}, + {0, 2, 4}, + {0.95f, -0.95f, 0.95f, -0.95f}, + {0.9975f, 0.9975f, 0.9975f, 0.9975f}); + + ASSERT_NO_THROW(obj->DefaultEvalMetric()); +} \ No newline at end of file