Rework the NDCG objective. (#9015)
This commit is contained in:
parent
ba9d24ff7b
commit
ef13dd31b1
@ -33,6 +33,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/objective/regression_obj.o \
|
||||
$(PKGROOT)/src/objective/multiclass_obj.o \
|
||||
$(PKGROOT)/src/objective/rank_obj.o \
|
||||
$(PKGROOT)/src/objective/lambdarank_obj.o \
|
||||
$(PKGROOT)/src/objective/hinge.o \
|
||||
$(PKGROOT)/src/objective/aft_obj.o \
|
||||
$(PKGROOT)/src/objective/adaptive.o \
|
||||
|
||||
@ -33,6 +33,7 @@ OBJECTS= \
|
||||
$(PKGROOT)/src/objective/regression_obj.o \
|
||||
$(PKGROOT)/src/objective/multiclass_obj.o \
|
||||
$(PKGROOT)/src/objective/rank_obj.o \
|
||||
$(PKGROOT)/src/objective/lambdarank_obj.o \
|
||||
$(PKGROOT)/src/objective/hinge.o \
|
||||
$(PKGROOT)/src/objective/aft_obj.o \
|
||||
$(PKGROOT)/src/objective/adaptive.o \
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
/**
|
||||
* Copyright 2015-2023 by XGBoost Contributors
|
||||
* \file math.h
|
||||
* \brief additional math utils
|
||||
* \author Tianqi Chen
|
||||
@ -7,16 +7,19 @@
|
||||
#ifndef XGBOOST_COMMON_MATH_H_
|
||||
#define XGBOOST_COMMON_MATH_H_
|
||||
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/base.h> // for XGBOOST_DEVICE
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm> // for max
|
||||
#include <cmath> // for exp, abs, log, lgamma
|
||||
#include <limits> // for numeric_limits
|
||||
#include <type_traits> // for is_floating_point, conditional, is_signed, is_same, declval, enable_if
|
||||
#include <utility> // for pair
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
/*!
 * \brief Square of the input.
 * \param w value to be squared
 * \return the product of `w` with itself
 */
template <typename T>
XGBOOST_DEVICE T Sqr(T const &w) {
  return w * w;
}
|
||||
|
||||
/*!
|
||||
* \brief calculate the sigmoid of the input.
|
||||
* \param x input parameter
|
||||
@ -30,9 +33,11 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
|
||||
return y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
|
||||
|
||||
/*!
 * \brief calculate the sigmoid of the input in double precision.
 * \param x input parameter
 * \return 1 / (1 + exp(-x))
 */
XGBOOST_DEVICE inline double Sigmoid(double x) {
  return 1.0 / (1.0 + std::exp(-x));
}
|
||||
/*!
|
||||
* \brief Equality test for both integer and floating point.
|
||||
*/
|
||||
@ -134,10 +139,6 @@ inline static bool CmpFirst(const std::pair<float, unsigned> &a,
|
||||
const std::pair<float, unsigned> &b) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
/*!
 * \brief Comparator on the second element of a pair, producing a descending order.
 * \param a left hand side
 * \param b right hand side
 * \return true if the second element of `a` is strictly greater than that of `b`
 */
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
                             const std::pair<float, unsigned> &b) {
  return b.second < a.second;
}
|
||||
|
||||
// Redefined here to workaround a VC bug that doesn't support overloading for integer
|
||||
// types.
|
||||
|
||||
@ -70,7 +70,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
// pairs
|
||||
// should be accessed by getter for auto configuration.
|
||||
// nolint so that we can keep the string name.
|
||||
PairMethod lambdarank_pair_method{PairMethod::kMean}; // NOLINT
|
||||
PairMethod lambdarank_pair_method{PairMethod::kTopK}; // NOLINT
|
||||
std::size_t lambdarank_num_pair_per_sample{NotSet()}; // NOLINT
|
||||
|
||||
public:
|
||||
@ -78,7 +78,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
|
||||
// unbiased
|
||||
bool lambdarank_unbiased{false};
|
||||
double lambdarank_bias_norm{2.0};
|
||||
double lambdarank_bias_norm{1.0};
|
||||
// ndcg
|
||||
bool ndcg_exp_gain{true};
|
||||
|
||||
@ -135,7 +135,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
|
||||
.set_default(false)
|
||||
.describe("Unbiased lambda mart. Use extended IPW to debias click position");
|
||||
DMLC_DECLARE_FIELD(lambdarank_bias_norm)
|
||||
.set_default(2.0)
|
||||
.set_default(1.0)
|
||||
.set_lower_bound(0.0)
|
||||
.describe("Lp regularization for unbiased lambdarank.");
|
||||
DMLC_DECLARE_FIELD(ndcg_exp_gain)
|
||||
|
||||
440
src/objective/lambdarank_obj.cc
Normal file
440
src/objective/lambdarank_obj.cc
Normal file
@ -0,0 +1,440 @@
|
||||
/**
|
||||
* Copyright (c) 2023, XGBoost contributors
|
||||
*/
|
||||
#include "lambdarank_obj.h"
|
||||
|
||||
#include <dmlc/registry.h> // for DMLC_REGISTRY_FILE_TAG
|
||||
|
||||
#include <algorithm> // for transform, copy, fill_n, min, max
|
||||
#include <cmath> // for pow, log2
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <map> // for operator!=
|
||||
#include <memory> // for shared_ptr, __shared_ptr_access, allocator
|
||||
#include <ostream> // for operator<<, basic_ostream
|
||||
#include <string> // for char_traits, operator<, basic_string, string
|
||||
#include <tuple> // for apply, make_tuple
|
||||
#include <type_traits> // for is_floating_point
|
||||
#include <utility> // for pair, swap
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../common/error_msg.h" // for GroupWeight, LabelScoreSize
|
||||
#include "../common/linalg_op.h" // for begin, cbegin, cend
|
||||
#include "../common/optional_weight.h" // for MakeOptionalWeights, OptionalWeights
|
||||
#include "../common/ranking_utils.h" // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
|
||||
#include "../common/threading_utils.h" // for ParallelFor, Sched
|
||||
#include "../common/transform_iterator.h" // for IndexTransformIter
|
||||
#include "init_estimation.h" // for FitIntercept
|
||||
#include "xgboost/base.h" // for bst_group_t, GradientPair, kRtEps, GradientPai...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/json.h" // for Json, get, Value, ToJson, F32Array, FromJson, IsA
|
||||
#include "xgboost/linalg.h" // for Vector, Range, TensorView, VectorView, All
|
||||
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
|
||||
#include "xgboost/objective.h" // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
|
||||
#include "xgboost/span.h" // for Span, operator!=
|
||||
#include "xgboost/string_view.h" // for operator<<, StringView
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
|
||||
namespace xgboost::obj {
|
||||
namespace cpu_impl {
|
||||
/**
 * \brief CPU implementation of the position bias update for unbiased LambdaMART.
 *
 * The per-sample losses in `li_full` / `lj_full` are first summed position-wise over
 * all query groups into `p_li` / `p_lj`, then the bias ratios are recomputed from the
 * accumulated values (eq.30 and eq.31).
 *
 * \param ctx        Context, used to fetch the group pointer from the cache.
 * \param li_full    Per-sample li buffer, indexed with the data group pointer.
 * \param lj_full    Per-sample lj buffer, indexed with the data group pointer.
 * \param p_ti_plus  [out] Bias ratio ti+, one value per tracked position.
 * \param p_tj_minus [out] Bias ratio tj-, one value per tracked position.
 * \param p_li       [in,out] Per-position accumulator for li.
 * \param p_lj       [in,out] Per-position accumulator for lj.
 * \param p_cache    Ranking cache providing groups, parameters and the position limit.
 */
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
                                  linalg::VectorView<double const> lj_full,
                                  linalg::Vector<double>* p_ti_plus,
                                  linalg::Vector<double>* p_tj_minus, linalg::Vector<double>* p_li,
                                  linalg::Vector<double>* p_lj,
                                  std::shared_ptr<ltr::RankingCache> p_cache) {
  auto ti_plus = p_ti_plus->HostView();
  auto tj_minus = p_tj_minus->HostView();
  auto li = p_li->HostView();
  auto lj = p_lj->HostView();

  auto gptr = p_cache->DataGroupPtr(ctx);
  auto n_groups = p_cache->Groups();
  auto regularizer = p_cache->Param().Regularizer();
  // Does not depend on the group, query it once instead of once per group.
  auto const max_position = p_cache->MaxPositionSize();

  // Aggregate over query groups
  for (bst_group_t g{0}; g < n_groups; ++g) {
    auto begin = gptr[g];
    auto end = gptr[g + 1];
    std::size_t group_size = end - begin;
    auto n = std::min(group_size, max_position);

    auto g_li = li_full.Slice(linalg::Range(begin, end));
    auto g_lj = lj_full.Slice(linalg::Range(begin, end));

    for (std::size_t i{0}; i < n; ++i) {
      li(i) += g_li(i);
      lj(i) += g_lj(i);
    }
  }

  auto const n_positions = ti_plus.Size();
  if (n_positions == 0) {
    // Nothing to update, and the first element must not be read from empty buffers.
    return;
  }

  // The ti+ is not guaranteed to decrease since it depends on the |\delta Z|
  //
  // The update normalizes the ti+ to make ti+(0) equal to 1, which breaks the probability
  // meaning. The reasoning behind the normalization is not clear, here we are just
  // following the authors.
  //
  // li/lj are not written in the loop below, so the normalizers and the threshold tests
  // on them are loop invariant.
  double const li_0 = li(0);
  double const lj_0 = lj(0);
  bool const update_ti = li_0 >= Eps64();
  bool const update_tj = lj_0 >= Eps64();

  for (std::size_t i = 0; i < n_positions; ++i) {
    if (update_ti) {
      ti_plus(i) = std::pow(li(i) / li_0, regularizer);  // eq.30
    }
    if (update_tj) {
      tj_minus(i) = std::pow(lj(i) / lj_0, regularizer);  // eq.31
    }
    assert(!std::isinf(ti_plus(i)));
    assert(!std::isinf(tj_minus(i)));
  }
}
|
||||
} // namespace cpu_impl
|
||||
|
||||
/**
|
||||
* \brief Base class for pair-wise learning to rank.
|
||||
*
|
||||
* See `From RankNet to LambdaRank to LambdaMART: An Overview` for a description of the
|
||||
* algorithm.
|
||||
*
|
||||
* In addition to ranking, this also implements `Unbiased LambdaMART: An Unbiased
|
||||
* Pairwise Learning-to-Rank Algorithm`.
|
||||
*/
|
||||
template <typename Loss, typename Cache>
|
||||
class LambdaRankObj : public FitIntercept {
|
||||
MetaInfo const* p_info_{nullptr};
|
||||
|
||||
// Update position biased for unbiased click data
|
||||
void UpdatePositionBias() {
|
||||
li_full_.SetDevice(ctx_->gpu_id);
|
||||
lj_full_.SetDevice(ctx_->gpu_id);
|
||||
li_.SetDevice(ctx_->gpu_id);
|
||||
lj_.SetDevice(ctx_->gpu_id);
|
||||
|
||||
if (ctx_->IsCPU()) {
|
||||
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
|
||||
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
|
||||
&li_, &lj_, p_cache_);
|
||||
} else {
|
||||
cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
|
||||
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
|
||||
&li_, &lj_, p_cache_);
|
||||
}
|
||||
|
||||
li_full_.Data()->Fill(0.0);
|
||||
lj_full_.Data()->Fill(0.0);
|
||||
|
||||
li_.Data()->Fill(0.0);
|
||||
lj_.Data()->Fill(0.0);
|
||||
}
|
||||
|
||||
protected:
|
||||
// L / tj-* (eq. 30)
|
||||
linalg::Vector<double> li_;
|
||||
// L / ti+* (eq. 31)
|
||||
linalg::Vector<double> lj_;
|
||||
// position bias ratio for relevant doc, ti+ (eq. 30)
|
||||
linalg::Vector<double> ti_plus_;
|
||||
// position bias ratio for irrelevant doc, tj- (eq. 31)
|
||||
linalg::Vector<double> tj_minus_;
|
||||
// li buffer for all samples
|
||||
linalg::Vector<double> li_full_;
|
||||
// lj buffer for all samples
|
||||
linalg::Vector<double> lj_full_;
|
||||
|
||||
ltr::LambdaRankParam param_;
|
||||
// cache
|
||||
std::shared_ptr<ltr::RankingCache> p_cache_;
|
||||
|
||||
[[nodiscard]] std::shared_ptr<Cache> GetCache() const {
|
||||
auto ptr = std::static_pointer_cast<Cache>(p_cache_);
|
||||
CHECK(ptr);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// get group view for li/lj
|
||||
linalg::VectorView<double> GroupLoss(bst_group_t g, linalg::Vector<double>* v) const {
|
||||
auto gptr = p_cache_->DataGroupPtr(ctx_);
|
||||
auto begin = gptr[g];
|
||||
auto end = gptr[g + 1];
|
||||
if (param_.lambdarank_unbiased) {
|
||||
return v->HostView().Slice(linalg::Range(begin, end));
|
||||
}
|
||||
return v->HostView();
|
||||
}
|
||||
|
||||
// Calculate lambda gradient for each group on CPU.
|
||||
template <bool unbiased, typename Delta>
|
||||
void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
|
||||
linalg::VectorView<float const> g_label, float w,
|
||||
common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
|
||||
common::Span<GradientPair> g_gpair) {
|
||||
std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{});
|
||||
auto p_gpair = g_gpair.data();
|
||||
|
||||
auto ti_plus = ti_plus_.HostView();
|
||||
auto tj_minus = tj_minus_.HostView();
|
||||
|
||||
auto li = GroupLoss(g, &li_full_);
|
||||
auto lj = GroupLoss(g, &lj_full_);
|
||||
|
||||
// Normalization, first used by LightGBM.
|
||||
// https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298
|
||||
double sum_lambda{0.0};
|
||||
|
||||
auto delta_op = [&](auto const&... args) { return delta(args..., g); };
|
||||
|
||||
auto loop = [&](std::size_t i, std::size_t j) {
|
||||
// higher/lower on the target ranked list
|
||||
std::size_t rank_high = i, rank_low = j;
|
||||
if (g_label(g_rank[rank_high]) == g_label(g_rank[rank_low])) {
|
||||
return;
|
||||
}
|
||||
if (g_label(g_rank[rank_high]) < g_label(g_rank[rank_low])) {
|
||||
std::swap(rank_high, rank_low);
|
||||
}
|
||||
|
||||
double cost;
|
||||
auto pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
|
||||
ti_plus, tj_minus, &cost);
|
||||
auto ng = Repulse(pg);
|
||||
|
||||
std::size_t idx_high = g_rank[rank_high];
|
||||
std::size_t idx_low = g_rank[rank_low];
|
||||
p_gpair[idx_high] += pg;
|
||||
p_gpair[idx_low] += ng;
|
||||
|
||||
if (unbiased) {
|
||||
auto k = ti_plus.Size();
|
||||
// We can probably use all the positions. If we skip the update due to having
|
||||
// high/low > k, we might be losing out too many pairs. On the other hand, if we
|
||||
// cap the position, then we might be accumulating too many tail bias into the
|
||||
// last tracked position.
|
||||
// We use `idx_high` since it represents the original position from the label
|
||||
// list, and label list is assumed to be sorted.
|
||||
if (idx_high < k && idx_low < k) {
|
||||
if (tj_minus(idx_low) >= Eps64()) {
|
||||
li(idx_high) += cost / tj_minus(idx_low); // eq.30
|
||||
}
|
||||
if (ti_plus(idx_high) >= Eps64()) {
|
||||
lj(idx_low) += cost / ti_plus(idx_high); // eq.31
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sum_lambda += -2.0 * static_cast<double>(pg.GetGrad());
|
||||
};
|
||||
|
||||
MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop);
|
||||
if (sum_lambda > 0.0) {
|
||||
double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
|
||||
std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(),
|
||||
[norm](GradientPair const& g) { return g * norm; });
|
||||
}
|
||||
|
||||
auto w_norm = p_cache_->WeightNorm();
|
||||
std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(),
|
||||
[&](GradientPair const& gpair) { return gpair * w * w_norm; });
|
||||
}
|
||||
|
||||
public:
|
||||
// Update the LTR parameters from the arguments, unknown keys are allowed.
void Configure(Args const& args) override {
  param_.UpdateAllowUnknown(args);
}
|
||||
// Save the objective name and the LTR parameters. With unbiased LTR, the position bias
// ratios are stored as float arrays under "ti+" and "tj-" as well.
void SaveConfig(Json* p_out) const override {
  auto& out = *p_out;
  out["name"] = String(Loss::Name());
  out["lambdarank_param"] = ToJson(param_);

  if (!param_.lambdarank_unbiased) {
    return;
  }

  // Create a float array under `key` and copy the bias into it.
  auto save_bias = [&out](char const* key, linalg::Vector<double> const& bias) {
    out[key] = F32Array();
    auto& dst = get<F32Array>(out[key]);
    dst.resize(bias.Size());
    auto h_bias = bias.HostView();
    std::copy(linalg::cbegin(h_bias), linalg::cend(h_bias), dst.begin());
  };

  save_bias("ti+", ti_plus_);
  save_bias("tj-", tj_minus_);
}
|
||||
// Load the LTR parameters and, when unbiased LTR is enabled, the position bias ratios
// stored under "ti+" and "tj-".
void LoadConfig(Json const& in) override {
  auto const& obj = get<Object const>(in);
  // The parameter is kept as it is if the key is absent from the config.
  if (obj.find("lambdarank_param") != obj.cend()) {
    FromJson(in["lambdarank_param"], &param_);
  }

  if (param_.lambdarank_unbiased) {
    // The bias is either a typed float array or a generic array of numbers, depending
    // on the serialization format.
    auto load_bias = [](Json in, linalg::Vector<double>* out) {
      if (IsA<F32Array>(in)) {
        // JSON
        auto const& array = get<F32Array>(in);
        out->Reshape(array.size());
        auto h_out = out->HostView();
        std::copy(array.cbegin(), array.cend(), linalg::begin(h_out));
      } else {
        // UBJSON
        auto const& array = get<Array>(in);
        out->Reshape(array.size());
        auto h_out = out->HostView();
        std::transform(array.cbegin(), array.cend(), linalg::begin(h_out),
                       [](Json const& v) { return get<Number const>(v); });
      }
    };
    load_bias(in["ti+"], &ti_plus_);
    load_bias(in["tj-"], &tj_minus_);
  }
}
|
||||
|
||||
// This objective is a ranking task.
[[nodiscard]] ObjInfo Task() const override {
  return ObjInfo{ObjInfo::kRanking};
}
|
||||
|
||||
// Number of targets, always 1. Labels with more than one column are rejected.
[[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
  auto const n_label_columns = info.labels.Shape(1);
  CHECK_LE(n_label_columns, 1) << "multi-output for LTR is not yet supported.";
  return bst_target_t{1};
}
|
||||
|
||||
// Build the name of the evaluation metric that matches this objective. The number of
// pairs is used for the name when truncation is enabled, otherwise the `NotSet` value.
//
// The returned pointer refers to a thread local buffer that is overwritten by the next
// call on the same thread.
[[nodiscard]] const char* RankEvalMetric(StringView metric) const {
  static thread_local std::string name;
  bool const truncated = param_.HasTruncation();
  name = truncated ? ltr::MakeMetricName(metric, param_.NumPair(), false)
                   : ltr::MakeMetricName(metric, param_.NotSet(), false);
  return name.c_str();
}
|
||||
|
||||
// Entry point for the gradient calculation. Prepares the cache and the position bias
// buffers, then dispatches to the implementation of the concrete loss.
void GetGradient(HostDeviceVector<float> const& predt, MetaInfo const& info, std::int32_t iter,
                 HostDeviceVector<GradientPair>* out_gpair) override {
  CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize();

  // The cache is renewed if it's missing, or built for a different MetaInfo object or
  // for different parameters.
  bool const renew_cache = !p_cache_ || p_info_ != &info || p_cache_->Param() != param_;
  if (renew_cache) {
    p_cache_ = std::make_shared<Cache>(ctx_, info, param_);
    p_info_ = &info;
  }

  auto n_groups = p_cache_->Groups();
  bool const has_weight = !info.weights_.Empty();
  if (has_weight) {
    // weights are defined per query group
    CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
  }

  // Lazily allocate the buffers for unbiased LTR, expected only at the first iteration.
  bool const init_bias = param_.lambdarank_unbiased && ti_plus_.Size() == 0;
  if (init_bias) {
    CHECK_EQ(iter, 0);
    ti_plus_ = linalg::Constant<double>(ctx_, 1.0, p_cache_->MaxPositionSize());
    tj_minus_ = linalg::Constant<double>(ctx_, 1.0, p_cache_->MaxPositionSize());

    li_ = linalg::Zeros<double>(ctx_, p_cache_->MaxPositionSize());
    lj_ = linalg::Zeros<double>(ctx_, p_cache_->MaxPositionSize());

    li_full_ = linalg::Zeros<double>(ctx_, info.num_row_);
    lj_full_ = linalg::Zeros<double>(ctx_, info.num_row_);
  }

  // CRTP dispatch to the concrete loss.
  auto* p_loss = static_cast<Loss*>(this);
  p_loss->GetGradientImpl(iter, predt, info, out_gpair);

  if (param_.lambdarank_unbiased) {
    this->UpdatePositionBias();
  }
}
|
||||
};
|
||||
|
||||
/**
 * \brief LambdaRank objective with NDCG as the target metric, registered under the
 *        name returned by `Name()`.
 */
class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
 public:
  /**
   * \brief Calculate the gradient for one query group using the delta of NDCG.
   *
   * \tparam unbiased Whether unbiased LTR is used.
   * \tparam exp_gain Forwarded to DeltaNDCG as its template argument.
   *
   * \param inv_IDCG Inverse IDCG, indexed by group.
   * \param discount Discount values forwarded to DeltaNDCG.
   *
   * Other parameters are forwarded to CalcLambdaForGroup.
   */
  template <bool unbiased, bool exp_gain>
  void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span<float const> g_predt,
                              linalg::VectorView<float const> g_label, float w,
                              common::Span<std::size_t const> g_rank,
                              common::Span<GradientPair> g_gpair,
                              linalg::VectorView<double const> inv_IDCG,
                              common::Span<double const> discount, bst_group_t g) {
    auto delta_ndcg = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low,
                          bst_group_t group) {
      static_assert(std::is_floating_point<decltype(y_high)>::value);
      return DeltaNDCG<exp_gain>(y_high, y_low, rank_high, rank_low, inv_IDCG(group), discount);
    };
    this->CalcLambdaForGroup<unbiased>(iter, g_predt, g_label, w, g_rank, g, delta_ndcg, g_gpair);
  }

  /**
   * \brief Calculate the gradient for all groups. Dispatches to the CUDA implementation
   *        when the context is on CUDA, otherwise groups are processed in parallel on
   *        the host.
   */
  void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
                       const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
    if (ctx_->IsCUDA()) {
      cuda_impl::LambdaRankGetGradientNDCG(
          ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
          tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
          out_gpair);
      return;
    }

    bst_group_t n_groups = p_cache_->Groups();
    auto gptr = p_cache_->DataGroupPtr(ctx_);

    out_gpair->Resize(info.num_row_);
    auto h_gpair = out_gpair->HostSpan();
    auto h_predt = predt.ConstHostSpan();
    auto h_label = info.labels.HostView();
    auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);

    auto dct = GetCache()->Discount(ctx_);
    auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
    auto inv_IDCG = GetCache()->InvIDCG(ctx_);

    // The parameters are not changed during the gradient calculation.
    bool const unbiased = param_.lambdarank_unbiased;
    bool const exp_gain = param_.ndcg_exp_gain;

    common::ParallelFor(n_groups, ctx_->Threads(), common::Sched::Guided(), [&](auto g) {
      // slice out the data for the current group
      std::size_t cnt = gptr[g + 1] - gptr[g];
      auto w = h_weight[g];
      auto g_predt = h_predt.subspan(gptr[g], cnt);
      auto g_gpair = h_gpair.subspan(gptr[g], cnt);
      auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]), 0);
      auto g_rank = rank_idx.subspan(gptr[g], cnt);

      auto args =
          std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g_gpair, inv_IDCG, dct, g);

      // Map the runtime parameters to the compile time switches.
      if (unbiased && exp_gain) {
        std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<true, true>, args);
      } else if (unbiased) {
        std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<true, false>, args);
      } else if (exp_gain) {
        std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<false, true>, args);
      } else {
        std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<false, false>, args);
      }
    });
  }

  // Name of the objective.
  static char const* Name() { return "rank:ndcg"; }

  // Default metric is ndcg, the full name is built by RankEvalMetric.
  [[nodiscard]] const char* DefaultEvalMetric() const override {
    return this->RankEvalMetric("ndcg");
  }

  // Configuration for the default metric, carries the LTR parameters of the objective.
  [[nodiscard]] Json DefaultMetricConfig() const override {
    Json metric_config{Object{}};
    metric_config["name"] = String{DefaultEvalMetric()};
    metric_config["lambdarank_param"] = ToJson(param_);
    return metric_config;
  }
};
|
||||
|
||||
namespace cuda_impl {
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
// Stand-in for the CUDA implementation, compiled only when XGBOOST_USE_CUDA is not
// defined. All arguments are ignored and the function only calls
// `common::AssertGPUSupport` (presumably reporting the missing GPU support -- the
// helper is defined elsewhere, confirm there).
//
// The two mutable views are passed li_full/lj_full by the caller in LambdaRankNDCG.
void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector<float> const&,
|
||||
const MetaInfo&, std::shared_ptr<ltr::NDCGCache>,
|
||||
linalg::VectorView<double const>, // input bias ratio
|
||||
linalg::VectorView<double const>, // input bias ratio
|
||||
linalg::VectorView<double>, linalg::VectorView<double>,
|
||||
HostDeviceVector<GradientPair>*) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
|
||||
// Stand-in for the CUDA position bias update, compiled only when XGBOOST_USE_CUDA is
// not defined. The parameter list mirrors cpu_impl::LambdaRankUpdatePositionBias, all
// arguments are ignored and the function only calls `common::AssertGPUSupport`
// (presumably reporting the missing GPU support -- confirm against its definition).
void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView<double const>,
|
||||
linalg::VectorView<double const>, linalg::Vector<double>*,
|
||||
linalg::Vector<double>*, linalg::Vector<double>*,
|
||||
linalg::Vector<double>*, std::shared_ptr<ltr::RankingCache>) {
|
||||
common::AssertGPUSupport();
|
||||
}
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
} // namespace cuda_impl
|
||||
|
||||
// Register the NDCG objective under the name returned by `LambdaRankNDCG::Name()`. The
// body is a factory returning a newly allocated objective.
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
|
||||
.describe("LambdaRank with NDCG loss as objective")
|
||||
.set_body([]() { return new LambdaRankNDCG{}; });
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(lambdarank_obj);
|
||||
} // namespace xgboost::obj
|
||||
@ -37,6 +37,312 @@ namespace xgboost::obj {
|
||||
DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu);
|
||||
|
||||
namespace cuda_impl {
|
||||
namespace {
|
||||
/**
|
||||
* \brief Calculate minimum value of bias for floating point truncation.
|
||||
*/
|
||||
void MinBias(Context const* ctx, std::shared_ptr<ltr::RankingCache> p_cache,
|
||||
linalg::VectorView<double const> t_plus, linalg::VectorView<double const> tj_minus,
|
||||
common::Span<double> d_min) {
|
||||
CHECK_EQ(d_min.size(), 2);
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
|
||||
auto k = t_plus.Size();
|
||||
auto const& p = p_cache->Param();
|
||||
CHECK_GT(k, 0);
|
||||
CHECK_EQ(k, p_cache->MaxPositionSize());
|
||||
|
||||
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return i * k; });
|
||||
auto val_it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) {
|
||||
if (i >= k) {
|
||||
return std::abs(tj_minus(i - k));
|
||||
}
|
||||
return std::abs(t_plus(i));
|
||||
});
|
||||
std::size_t bytes;
|
||||
cub::DeviceSegmentedReduce::Min(nullptr, bytes, val_it, d_min.data(), 2, key_it, key_it + 1,
|
||||
cuctx->Stream());
|
||||
dh::TemporaryArray<char> temp(bytes);
|
||||
cub::DeviceSegmentedReduce::Min(temp.data().get(), bytes, val_it, d_min.data(), 2, key_it,
|
||||
key_it + 1, cuctx->Stream());
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Type for gradient statistic. (Gradient, cost for unbiased LTR, normalization factor)
|
||||
*/
|
||||
using GradCostNorm = thrust::tuple<GradientPair, double, double>;
|
||||
|
||||
/**
|
||||
* \brief Obtain and update the gradient for one pair.
|
||||
*/
|
||||
template <bool unbiased, bool has_truncation, typename Delta>
|
||||
struct GetGradOp {
|
||||
MakePairsOp<has_truncation> make_pair;
|
||||
Delta delta;
|
||||
|
||||
bool need_update;
|
||||
|
||||
auto __device__ operator()(std::size_t idx) -> GradCostNorm {
|
||||
auto const& args = make_pair.args;
|
||||
auto g = dh::SegmentId(args.d_threads_group_ptr, idx);
|
||||
|
||||
auto data_group_begin = static_cast<std::size_t>(args.d_group_ptr[g]);
|
||||
std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin;
|
||||
// obtain group segment data.
|
||||
auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0);
|
||||
auto g_predt = args.predts.subspan(data_group_begin, n_data);
|
||||
auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data();
|
||||
auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data);
|
||||
|
||||
auto [i, j] = make_pair(idx, g);
|
||||
|
||||
std::size_t rank_high = i, rank_low = j;
|
||||
if (g_label(g_rank[i]) == g_label(g_rank[j])) {
|
||||
return thrust::make_tuple(GradientPair{}, 0.0, 0.0);
|
||||
}
|
||||
if (g_label(g_rank[i]) < g_label(g_rank[j])) {
|
||||
thrust::swap(rank_high, rank_low);
|
||||
}
|
||||
|
||||
double cost{0};
|
||||
|
||||
auto delta_op = [&](auto const&... args) { return delta(args..., g); };
|
||||
GradientPair pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
|
||||
args.ti_plus, args.tj_minus, &cost);
|
||||
|
||||
std::size_t idx_high = g_rank[rank_high];
|
||||
std::size_t idx_low = g_rank[rank_low];
|
||||
|
||||
if (need_update) {
|
||||
// second run, update the gradient
|
||||
|
||||
auto ng = Repulse(pg);
|
||||
|
||||
auto gr = args.d_roundings(g);
|
||||
// positive gradient truncated
|
||||
auto pgt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), pg.GetGrad()),
|
||||
common::TruncateWithRounding(gr.GetHess(), pg.GetHess())};
|
||||
// negative gradient truncated
|
||||
auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()),
|
||||
common::TruncateWithRounding(gr.GetHess(), ng.GetHess())};
|
||||
|
||||
dh::AtomicAddGpair(g_gpair + idx_high, pgt);
|
||||
dh::AtomicAddGpair(g_gpair + idx_low, ngt);
|
||||
}
|
||||
|
||||
if (unbiased && need_update) {
|
||||
// second run, update the cost
|
||||
assert(args.tj_minus.Size() == args.ti_plus.Size() && "Invalid size of position bias");
|
||||
|
||||
auto g_li = args.li.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
|
||||
auto g_lj = args.lj.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
|
||||
|
||||
if (idx_high < args.ti_plus.Size() && idx_low < args.ti_plus.Size()) {
|
||||
if (args.tj_minus(idx_low) >= Eps64()) {
|
||||
// eq.30
|
||||
atomicAdd(&g_li(idx_high), common::TruncateWithRounding(args.d_cost_rounding[0],
|
||||
cost / args.tj_minus(idx_low)));
|
||||
}
|
||||
if (args.ti_plus(idx_high) >= Eps64()) {
|
||||
// eq.31
|
||||
atomicAdd(&g_lj(idx_low), common::TruncateWithRounding(args.d_cost_rounding[0],
|
||||
cost / args.ti_plus(idx_high)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())},
|
||||
std::abs(cost), -2.0 * static_cast<double>(pg.GetGrad()));
|
||||
}
|
||||
};
|
||||
|
||||
template <bool unbiased, bool has_truncation, typename Delta>
|
||||
struct MakeGetGrad {
|
||||
MakePairsOp<has_truncation> make_pair;
|
||||
Delta delta;
|
||||
|
||||
[[nodiscard]] KernelInputs const& Args() const { return make_pair.args; }
|
||||
|
||||
MakeGetGrad(KernelInputs args, Delta d) : make_pair{args}, delta{std::move(d)} {}
|
||||
|
||||
GetGradOp<unbiased, has_truncation, Delta> operator()(bool need_update) {
|
||||
return GetGradOp<unbiased, has_truncation, Delta>{make_pair, delta, need_update};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Calculate gradient for all pairs using update op created by make_get_grad.
|
||||
*
|
||||
* We need to run gradient calculation twice, the first time gathers infomation like
|
||||
* maximum gradient, maximum cost, and the normalization term using reduction. The second
|
||||
* time performs the actual update.
|
||||
*
|
||||
* Without normalization, we only need to run it once since we can manually calculate
|
||||
* the bounds of gradient (NDCG \in [0, 1], delta_NDCG \in [0, 1], ti+/tj- are from the
|
||||
* previous iteration so the bound can be calculated for current iteration). However, if
|
||||
* normalization is used, the delta score is un-bounded and we need to obtain the sum
|
||||
* gradient. As a tradeoff, we simply run the kernel twice, once as reduction, second
|
||||
* one as for_each.
|
||||
*
|
||||
* Alternatively, we can bound the delta score by limiting the output of the model using
|
||||
* sigmoid for binary output and some normalization for multi-level. But effect to the
|
||||
* accuracy is not known yet, and it's only used by GPU.
|
||||
*
|
||||
* For performance, the segmented sort for sorted scores is the bottleneck and takes up
|
||||
* about half of the time, while the reduction and for_each takes up the second half.
|
||||
*/
|
||||
template <bool unbiased, bool has_truncation, typename Delta>
|
||||
void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::RankingCache> p_cache,
|
||||
MakeGetGrad<unbiased, has_truncation, Delta> make_get_grad) {
|
||||
auto n_groups = p_cache->Groups();
|
||||
auto d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr();
|
||||
auto d_gptr = p_cache->DataGroupPtr(ctx);
|
||||
auto d_gpair = make_get_grad.Args().gpairs;
|
||||
|
||||
/**
|
||||
* First pass, gather info for normalization and rounding factor.
|
||||
*/
|
||||
auto val_it = dh::MakeTransformIterator<GradCostNorm>(thrust::make_counting_iterator(0ul),
|
||||
make_get_grad(false));
|
||||
auto reduction_op = [] XGBOOST_DEVICE(GradCostNorm const& l,
|
||||
GradCostNorm const& r) -> GradCostNorm {
|
||||
// get maximum gradient for each group, along with cost and the normalization term
|
||||
auto const& lg = thrust::get<0>(l);
|
||||
auto const& rg = thrust::get<0>(r);
|
||||
auto grad = std::max(lg.GetGrad(), rg.GetGrad());
|
||||
auto hess = std::max(lg.GetHess(), rg.GetHess());
|
||||
auto cost = std::max(thrust::get<1>(l), thrust::get<1>(r));
|
||||
double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r);
|
||||
return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda);
|
||||
};
|
||||
auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0);
|
||||
common::Span<GradCostNorm> d_max_lambdas = p_cache->MaxLambdas<GradCostNorm>(ctx, n_groups);
|
||||
CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes());
|
||||
|
||||
std::size_t bytes;
|
||||
cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups,
|
||||
d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1,
|
||||
reduction_op, init, ctx->CUDACtx()->Stream());
|
||||
dh::TemporaryArray<char> temp(bytes);
|
||||
cub::DeviceSegmentedReduce::Reduce(
|
||||
temp.data().get(), bytes, val_it, d_max_lambdas.data(), n_groups, d_threads_group_ptr.data(),
|
||||
d_threads_group_ptr.data() + 1, reduction_op, init, ctx->CUDACtx()->Stream());
|
||||
|
||||
dh::TemporaryArray<double> min_bias(2);
|
||||
auto d_min_bias = dh::ToSpan(min_bias);
|
||||
if (unbiased) {
|
||||
MinBias(ctx, p_cache, make_get_grad.Args().ti_plus, make_get_grad.Args().tj_minus, d_min_bias);
|
||||
}
|
||||
/**
|
||||
* Create rounding factors
|
||||
*/
|
||||
auto d_cost_rounding = p_cache->CUDACostRounding(ctx);
|
||||
auto d_rounding = p_cache->CUDARounding(ctx);
|
||||
dh::LaunchN(n_groups, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t g) mutable {
|
||||
auto group_size = d_gptr[g + 1] - d_gptr[g];
|
||||
auto const& max_grad = thrust::get<0>(d_max_lambdas[g]);
|
||||
// float group size
|
||||
auto fgs = static_cast<float>(group_size);
|
||||
auto grad = common::CreateRoundingFactor(fgs * max_grad.GetGrad(), group_size);
|
||||
auto hess = common::CreateRoundingFactor(fgs * max_grad.GetHess(), group_size);
|
||||
d_rounding(g) = GradientPair{grad, hess};
|
||||
|
||||
auto cost = thrust::get<1>(d_max_lambdas[g]);
|
||||
if (unbiased) {
|
||||
cost /= std::min(d_min_bias[0], d_min_bias[1]);
|
||||
d_cost_rounding[0] = common::CreateRoundingFactor(fgs * cost, group_size);
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Second pass, actual update to gradient and bias.
|
||||
*/
|
||||
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul),
|
||||
p_cache->CUDAThreads(), make_get_grad(true));
|
||||
|
||||
/**
|
||||
* Lastly, normalization and weight.
|
||||
*/
|
||||
auto d_weights = common::MakeOptionalWeights(ctx, info.weights_);
|
||||
auto w_norm = p_cache->WeightNorm();
|
||||
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) {
|
||||
auto g = dh::SegmentId(d_gptr, i);
|
||||
auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
|
||||
// Normalization
|
||||
if (sum_lambda > 0.0) {
|
||||
double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
|
||||
d_gpair[i] *= norm;
|
||||
}
|
||||
d_gpair[i] *= (d_weights[g] * w_norm);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Handles boilerplate code like getting device span.
|
||||
*/
|
||||
template <typename Delta>
|
||||
void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const& preds,
|
||||
const MetaInfo& info, std::shared_ptr<ltr::RankingCache> p_cache, Delta delta,
|
||||
linalg::VectorView<double const> ti_plus, // input bias ratio
|
||||
linalg::VectorView<double const> tj_minus, // input bias ratio
|
||||
linalg::VectorView<double> li, linalg::VectorView<double> lj,
|
||||
HostDeviceVector<GradientPair>* out_gpair) {
|
||||
// boilerplate
|
||||
std::int32_t device_id = ctx->gpu_id;
|
||||
dh::safe_cuda(cudaSetDevice(device_id));
|
||||
auto n_groups = p_cache->Groups();
|
||||
|
||||
info.labels.SetDevice(device_id);
|
||||
preds.SetDevice(device_id);
|
||||
out_gpair->SetDevice(device_id);
|
||||
out_gpair->Resize(preds.Size());
|
||||
|
||||
CHECK(p_cache);
|
||||
|
||||
auto d_rounding = p_cache->CUDARounding(ctx);
|
||||
auto d_cost_rounding = p_cache->CUDACostRounding(ctx);
|
||||
|
||||
CHECK_NE(d_rounding.Size(), 0);
|
||||
|
||||
auto label = info.labels.View(ctx->gpu_id);
|
||||
auto predts = preds.ConstDeviceSpan();
|
||||
auto gpairs = out_gpair->DeviceSpan();
|
||||
thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f});
|
||||
|
||||
auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr();
|
||||
auto const d_gptr = p_cache->DataGroupPtr(ctx);
|
||||
auto const rank_idx = p_cache->SortedIdx(ctx, predts);
|
||||
|
||||
auto const unbiased = p_cache->Param().lambdarank_unbiased;
|
||||
|
||||
common::Span<std::size_t const> d_y_sorted_idx;
|
||||
if (!p_cache->Param().HasTruncation()) {
|
||||
d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache);
|
||||
}
|
||||
|
||||
KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr,
|
||||
rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(),
|
||||
d_y_sorted_idx, iter};
|
||||
|
||||
// dispatch based on unbiased and truncation
|
||||
if (p_cache->Param().HasTruncation()) {
|
||||
if (unbiased) {
|
||||
CalcGrad(ctx, info, p_cache, MakeGetGrad<true, true, Delta>{args, delta});
|
||||
} else {
|
||||
CalcGrad(ctx, info, p_cache, MakeGetGrad<false, true, Delta>{args, delta});
|
||||
}
|
||||
} else {
|
||||
if (unbiased) {
|
||||
CalcGrad(ctx, info, p_cache, MakeGetGrad<true, false, Delta>{args, delta});
|
||||
} else {
|
||||
CalcGrad(ctx, info, p_cache, MakeGetGrad<false, false, Delta>{args, delta});
|
||||
}
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
|
||||
common::Span<std::size_t const> d_rank,
|
||||
std::shared_ptr<ltr::RankingCache> p_cache) {
|
||||
@ -58,5 +364,116 @@ common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
|
||||
common::SegmentedArgSort<false, true>(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx);
|
||||
return d_y_sorted_idx;
|
||||
}
|
||||
|
||||
/**
 * \brief Compute LambdaRank gradients with the NDCG delta on the CUDA device.
 *
 * Builds a device functor that evaluates DeltaNDCG (with or without the exponential
 * gain, as selected by the `ndcg_exp_gain` parameter) from the cached inverse IDCG and
 * discount, then forwards everything to Launch.
 *
 * \param ctx       Context carrying the device ordinal.
 * \param iter      Boosting iteration.
 * \param preds     Model predictions.
 * \param info      Meta info holding the labels.
 * \param p_cache   NDCG cache providing the inverse IDCG and the discount.
 * \param ti_plus   Input bias ratio.
 * \param tj_minus  Input bias ratio.
 * \param li, lj    Forwarded to Launch.
 * \param out_gpair Output gradient pairs.
 */
void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
                               const HostDeviceVector<float>& preds, const MetaInfo& info,
                               std::shared_ptr<ltr::NDCGCache> p_cache,
                               linalg::VectorView<double const> ti_plus,   // input bias ratio
                               linalg::VectorView<double const> tj_minus,  // input bias ratio
                               linalg::VectorView<double> li, linalg::VectorView<double> lj,
                               HostDeviceVector<GradientPair>* out_gpair) {
  // Select the device before touching any cached device buffer.
  std::int32_t const device = ctx->gpu_id;
  dh::safe_cuda(cudaSetDevice(device));

  auto const d_inv_IDCG = p_cache->InvIDCG(ctx);
  auto const discount = p_cache->Discount(ctx);

  info.labels.SetDevice(device);
  preds.SetDevice(device);

  auto const exp_gain = p_cache->Param().ndcg_exp_gain;
  // Delta functor: change in NDCG of group `g` when the two documents are swapped.
  auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
                                       std::size_t rank_low, bst_group_t g) {
    if (exp_gain) {
      return DeltaNDCG<true>(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount);
    }
    return DeltaNDCG<false>(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount);
  };

  Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair);
}
|
||||
|
||||
namespace {
|
||||
struct ReduceOp {
|
||||
template <typename Tup>
|
||||
Tup XGBOOST_DEVICE operator()(Tup const& l, Tup const& r) {
|
||||
return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r),
|
||||
thrust::get<1>(l) + thrust::get<1>(r));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/**
 * \brief Update the position bias ratios from the per-sample losses.
 *
 * The per-sample losses `li_full` / `lj_full` are summed across groups for each
 * position (up to `p_cache->MaxPositionSize()` positions) into `p_li` / `p_lj`. The
 * bias ratios are then set to `(l(i) / l(0)) ^ regularizer`, but only when `l(0)` is at
 * least `Eps64()`; otherwise the previous ratio is left untouched.
 *
 * \param ctx        Context carrying the device ordinal and the CUDA stream.
 * \param li_full    Per-sample loss, indexed by position in the data.
 * \param lj_full    Per-sample loss, indexed by position in the data.
 * \param p_ti_plus  Bias ratio to update, same size as `p_li`.
 * \param p_tj_minus Bias ratio to update, same size as `p_lj`.
 * \param p_li       Output for the reduced loss, one value per position.
 * \param p_lj       Output for the reduced loss, one value per position.
 * \param p_cache    Ranking cache providing group pointers and parameters.
 */
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
                                  linalg::VectorView<double const> lj_full,
                                  linalg::Vector<double>* p_ti_plus,
                                  linalg::Vector<double>* p_tj_minus,
                                  linalg::Vector<double>* p_li,  // loss
                                  linalg::Vector<double>* p_lj,
                                  std::shared_ptr<ltr::RankingCache> p_cache) {
  auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
  auto n_groups = d_group_ptr.size() - 1;

  auto ti_plus = p_ti_plus->View(ctx->gpu_id);
  auto tj_minus = p_tj_minus->View(ctx->gpu_id);

  auto li = p_li->View(ctx->gpu_id);
  auto lj = p_lj->View(ctx->gpu_id);
  CHECK_EQ(li.Size(), ti_plus.Size());
  // The update kernel below writes tj_minus(i) for every i < li.Size(); make sure that
  // is in bounds as well.
  CHECK_EQ(lj.Size(), tj_minus.Size());

  auto const& param = p_cache->Param();
  auto regularizer = param.Regularizer();
  std::size_t k = p_cache->MaxPositionSize();

  CHECK_EQ(li.Size(), k);
  CHECK_EQ(lj.Size(), k);
  // reduce li_full to li for each group.
  auto make_iter = [&](linalg::VectorView<double const> l_full) {
    auto l_it = [=] XGBOOST_DEVICE(std::size_t i) {
      // group index
      auto g = i % n_groups;
      // rank is the position within a group, also the segment index
      auto r = i / n_groups;

      auto begin = d_group_ptr[g];
      std::size_t group_size = d_group_ptr[g + 1] - begin;
      auto n = std::min(group_size, k);
      // r can be greater than n since we allocate threads based on truncation level
      // instead of actual group size.
      if (r >= n) {
        return 0.0;
      }
      return l_full(r + begin);
    };
    return l_it;
  };
  auto li_it =
      dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), make_iter(li_full));
  auto lj_it =
      dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), make_iter(lj_full));
  // k segments, each segment has size n_groups.
  auto key_it = dh::MakeTransformIterator<std::size_t>(
      thrust::make_counting_iterator(0ul),
      [=] XGBOOST_DEVICE(std::size_t i) { return i * n_groups; });
  auto val_it = thrust::make_zip_iterator(thrust::make_tuple(li_it, lj_it));
  auto out_it =
      thrust::make_zip_iterator(thrust::make_tuple(li.Values().data(), lj.Values().data()));

  auto init = thrust::make_tuple(0.0, 0.0);
  // First call only queries the required temporary storage size.
  std::size_t bytes = 0;
  dh::safe_cuda(cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, out_it, k, key_it,
                                                   key_it + 1, ReduceOp{}, init,
                                                   ctx->CUDACtx()->Stream()));
  dh::TemporaryArray<char> temp(bytes);
  dh::safe_cuda(cub::DeviceSegmentedReduce::Reduce(temp.data().get(), bytes, val_it, out_it, k,
                                                   key_it, key_it + 1, ReduceOp{}, init,
                                                   ctx->CUDACtx()->Stream()));

  thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), li.Size(),
                     [=] XGBOOST_DEVICE(std::size_t i) mutable {
                       // Skip the update when the loss at the first position is too
                       // small to be a safe denominator.
                       if (li(0) >= Eps64()) {
                         ti_plus(i) = std::pow(li(i) / li(0), regularizer);
                       }
                       if (lj(0) >= Eps64()) {
                         tj_minus(i) = std::pow(lj(i) / lj(0), regularizer);
                       }
                       assert(!std::isinf(ti_plus(i)));
                       assert(!std::isinf(tj_minus(i)));
                     });
}
|
||||
} // namespace cuda_impl
|
||||
} // namespace xgboost::obj
|
||||
|
||||
@ -1,5 +1,15 @@
|
||||
/**
|
||||
* Copyright 2023 XGBoost contributors
|
||||
* Copyright 2023, XGBoost contributors
|
||||
*
|
||||
* Vocabulary explanation:
|
||||
*
|
||||
* There are two different lists we need to handle in the objective, first is the list of
|
||||
* labels (relevance degree) provided by the user. Its order has no particular meaning
|
||||
* when bias estimation is NOT used. Another one is generated by our model, sorted index
|
||||
* based on prediction scores. `rank_high` refers to the position index of the model rank
|
||||
* list that is higher than `rank_low`, while `idx_high` refers to where the
|
||||
* `rank_high` sample comes from. Simply put, `rank_high` indexes into the rank list
|
||||
* obtained from the model, while `idx_high` indexes into the user provided sample list.
|
||||
*/
|
||||
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
|
||||
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
|
||||
@ -25,14 +35,19 @@
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::obj {
|
||||
/**
 * \brief Small positive threshold (1e-16) used for double precision computation, both
 *        as a lower bound for the hessian term and as the minimum value a bias or loss
 *        needs to reach before it is used as a divisor.
 */
constexpr double Eps64() { return 1e-16; }
|
||||
|
||||
template <bool exp>
|
||||
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low,
|
||||
double inv_IDCG, common::Span<double const> discount) {
|
||||
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t rank_high,
|
||||
std::size_t rank_low, double inv_IDCG,
|
||||
common::Span<double const> discount) {
|
||||
// Use rank_high instead of idx_high as we are calculating discount based on ranks
|
||||
// provided by the model.
|
||||
double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high;
|
||||
double discount_high = discount[r_high];
|
||||
double discount_high = discount[rank_high];
|
||||
|
||||
double gain_low = exp ? ltr::CalcDCGGain(y_low) : y_low;
|
||||
double discount_low = discount[r_low];
|
||||
double discount_low = discount[rank_low];
|
||||
|
||||
double original = gain_high * discount_high + gain_low * discount_low;
|
||||
double changed = gain_low * discount_high + gain_high * discount_low;
|
||||
@ -70,9 +85,9 @@ template <bool unbiased, typename Delta>
|
||||
XGBOOST_DEVICE GradientPair
|
||||
LambdaGrad(linalg::VectorView<float const> labels, common::Span<float const> predts,
|
||||
common::Span<size_t const> sorted_idx,
|
||||
std::size_t rank_high, // cordiniate
|
||||
std::size_t rank_low, // cordiniate
|
||||
Delta delta, // delta score
|
||||
std::size_t rank_high, // higher index on the model rank list
|
||||
std::size_t rank_low, // lower index on the model rank list
|
||||
Delta delta, // function to calculate delta score
|
||||
linalg::VectorView<double const> t_plus, // input bias ratio
|
||||
linalg::VectorView<double const> t_minus, // input bias ratio
|
||||
double* p_cost) {
|
||||
@ -95,30 +110,34 @@ LambdaGrad(linalg::VectorView<float const> labels, common::Span<float const> pre
|
||||
|
||||
// Use double whenever possible as we are working on the exp space.
|
||||
double delta_score = std::abs(s_high - s_low);
|
||||
double sigmoid = common::Sigmoid(s_high - s_low);
|
||||
double const sigmoid = common::Sigmoid(s_high - s_low);
|
||||
// Change in metric score like \delta NDCG or \delta MAP
|
||||
double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low));
|
||||
|
||||
if (best_score != worst_score) {
|
||||
delta_metric /= (delta_score + kRtEps);
|
||||
delta_metric /= (delta_score + 0.01);
|
||||
}
|
||||
|
||||
if (unbiased) {
|
||||
*p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric;
|
||||
}
|
||||
|
||||
constexpr double kEps = 1e-16;
|
||||
auto lambda_ij = (sigmoid - 1.0) * delta_metric;
|
||||
auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0;
|
||||
auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric * 2.0;
|
||||
|
||||
auto k = t_plus.Size();
|
||||
assert(t_minus.Size() == k && "Invalid size of position bias");
|
||||
|
||||
if (unbiased && idx_high < k && idx_low < k) {
|
||||
lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
|
||||
hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
|
||||
// We need to skip samples that exceed the maximum number of tracked positions, and
|
||||
// samples that have low probability and might bring us floating point issues.
|
||||
if (unbiased && idx_high < k && idx_low < k && t_minus(idx_low) >= Eps64() &&
|
||||
t_plus(idx_high) >= Eps64()) {
|
||||
// The index should be ranks[idx_low], since we assume label is sorted, this reduces
|
||||
// to `idx_low`, which represents the position on the input list, as explained in the
|
||||
// file header.
|
||||
lambda_ij /= (t_plus(idx_high) * t_minus(idx_low));
|
||||
hessian_ij /= (t_plus(idx_high) * t_minus(idx_low));
|
||||
}
|
||||
|
||||
auto pg = GradientPair{static_cast<float>(lambda_ij), static_cast<float>(hessian_ij)};
|
||||
return pg;
|
||||
}
|
||||
@ -137,27 +156,6 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
|
||||
linalg::VectorView<double> li, linalg::VectorView<double> lj,
|
||||
HostDeviceVector<GradientPair>* out_gpair);
|
||||
|
||||
/**
|
||||
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
|
||||
*/
|
||||
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache);
|
||||
|
||||
void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
|
||||
HostDeviceVector<float> const& predt, MetaInfo const& info,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache,
|
||||
linalg::VectorView<double const> t_plus, // input bias ratio
|
||||
linalg::VectorView<double const> t_minus, // input bias ratio
|
||||
linalg::VectorView<double> li, linalg::VectorView<double> lj,
|
||||
HostDeviceVector<GradientPair>* out_gpair);
|
||||
|
||||
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
|
||||
HostDeviceVector<float> const& predt, const MetaInfo& info,
|
||||
std::shared_ptr<ltr::RankingCache> p_cache,
|
||||
linalg::VectorView<double const> ti_plus, // input bias ratio
|
||||
linalg::VectorView<double const> tj_minus, // input bias ratio
|
||||
linalg::VectorView<double> li, linalg::VectorView<double> lj,
|
||||
HostDeviceVector<GradientPair>* out_gpair);
|
||||
|
||||
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
|
||||
linalg::VectorView<double const> lj_full,
|
||||
@ -167,18 +165,6 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
|
||||
std::shared_ptr<ltr::RankingCache> p_cache);
|
||||
} // namespace cuda_impl
|
||||
|
||||
namespace cpu_impl {
|
||||
/**
|
||||
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
|
||||
*
|
||||
* \param label Ground truth relevance label.
|
||||
* \param rank_idx Sorted index of prediction.
|
||||
* \param p_cache An initialized MAPCache.
|
||||
*/
|
||||
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
|
||||
common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache);
|
||||
} // namespace cpu_impl
|
||||
|
||||
/**
|
||||
* \brief Construct pairs on CPU
|
||||
*
|
||||
|
||||
@ -48,12 +48,15 @@ DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu);
|
||||
#else
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(quantile_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(hinge_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
|
||||
#endif // XGBOOST_USE_CUDA
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
@ -207,174 +207,6 @@ class IndexablePredictionSorter {
|
||||
};
|
||||
#endif
|
||||
|
||||
// beta version: NDCG lambda rank
|
||||
class NDCGLambdaWeightComputer
|
||||
#if defined(__CUDACC__)
|
||||
: public IndexablePredictionSorter
|
||||
#endif
|
||||
{
|
||||
public:
|
||||
#if defined(__CUDACC__)
|
||||
// This function object computes the item's DCG value
|
||||
class ComputeItemDCG : public thrust::unary_function<uint32_t, float> {
|
||||
public:
|
||||
XGBOOST_DEVICE ComputeItemDCG(const common::Span<const float> &dsorted_labels,
|
||||
const common::Span<const uint32_t> &dgroups,
|
||||
const common::Span<const uint32_t> &gidxs)
|
||||
: dsorted_labels_(dsorted_labels),
|
||||
dgroups_(dgroups),
|
||||
dgidxs_(gidxs) {}
|
||||
|
||||
// Compute DCG for the item at 'idx'
|
||||
__device__ __forceinline__ float operator()(uint32_t idx) const {
|
||||
return ComputeItemDCGWeight(dsorted_labels_[idx], idx - dgroups_[dgidxs_[idx]]);
|
||||
}
|
||||
|
||||
private:
|
||||
const common::Span<const float> dsorted_labels_; // Labels sorted within a group
|
||||
const common::Span<const uint32_t> dgroups_; // The group indices - where each group
|
||||
// begins and ends
|
||||
const common::Span<const uint32_t> dgidxs_; // The group each items belongs to
|
||||
};
|
||||
|
||||
// Type containing device pointers that can be cheaply copied on the kernel
|
||||
class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
|
||||
public:
|
||||
NDCGLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
|
||||
const NDCGLambdaWeightComputer &lwc)
|
||||
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
|
||||
dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {}
|
||||
|
||||
// Adjust the items weight by this value
|
||||
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
|
||||
if (dgroup_dcgs_[gidx] == 0.0) return 0.0f;
|
||||
|
||||
uint32_t group_begin = dgroups_[gidx];
|
||||
|
||||
auto pos_lab_orig_posn = dorig_pos_[pidx];
|
||||
auto neg_lab_orig_posn = dorig_pos_[nidx];
|
||||
KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn);
|
||||
|
||||
// Note: the label positive and negative indices are relative to the entire dataset.
|
||||
// Hence, scale them back to an index within the group
|
||||
auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
|
||||
auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
|
||||
return NDCGLambdaWeightComputer::ComputeDeltaWeight(
|
||||
pos_pred_pos, neg_pred_pos,
|
||||
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
|
||||
dgroup_dcgs_[gidx]);
|
||||
}
|
||||
|
||||
private:
|
||||
const common::Span<const float> dgroup_dcgs_; // Group DCG values
|
||||
};
|
||||
|
||||
NDCGLambdaWeightComputer(const bst_float *dpreds,
|
||||
const bst_float*,
|
||||
const dh::SegmentSorter<float> &segment_label_sorter)
|
||||
: IndexablePredictionSorter(dpreds, segment_label_sorter),
|
||||
dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f),
|
||||
weight_multiplier_(segment_label_sorter, *this) {
|
||||
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
|
||||
|
||||
// Allocator to be used for managing space overhead while performing transformed reductions
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
// Compute each elements DCG values and reduce them across groups concurrently.
|
||||
auto end_range =
|
||||
thrust::reduce_by_key(thrust::cuda::par(alloc),
|
||||
dh::tcbegin(group_segments), dh::tcend(group_segments),
|
||||
thrust::make_transform_iterator(
|
||||
// The indices need not be sequential within a group, as we care only
|
||||
// about the sum of items DCG values within a group
|
||||
dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()),
|
||||
ComputeItemDCG(segment_label_sorter.GetItemsSpan(),
|
||||
segment_label_sorter.GetGroupsSpan(),
|
||||
group_segments)),
|
||||
thrust::make_discard_iterator(), // We don't care for the group indices
|
||||
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
|
||||
CHECK_EQ(static_cast<unsigned>(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size());
|
||||
}
|
||||
|
||||
inline const common::Span<const float> GetGroupDcgsSpan() const {
|
||||
return { dgroup_dcg_.data().get(), dgroup_dcg_.size() };
|
||||
}
|
||||
|
||||
inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const {
|
||||
return weight_multiplier_;
|
||||
}
|
||||
#endif
|
||||
|
||||
static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
|
||||
std::vector<LambdaPair> *io_pairs) {
|
||||
std::vector<LambdaPair> &pairs = *io_pairs;
|
||||
float IDCG; // NOLINT
|
||||
{
|
||||
std::vector<bst_float> labels(sorted_list.size());
|
||||
for (size_t i = 0; i < sorted_list.size(); ++i) {
|
||||
labels[i] = sorted_list[i].label;
|
||||
}
|
||||
std::stable_sort(labels.begin(), labels.end(), std::greater<>());
|
||||
IDCG = ComputeGroupDCGWeight(&labels[0], labels.size());
|
||||
}
|
||||
if (IDCG == 0.0) {
|
||||
for (auto & pair : pairs) {
|
||||
pair.weight = 0.0f;
|
||||
}
|
||||
} else {
|
||||
for (auto & pair : pairs) {
|
||||
unsigned pos_idx = pair.pos_index;
|
||||
unsigned neg_idx = pair.neg_index;
|
||||
pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx,
|
||||
sorted_list[pos_idx].label, sorted_list[neg_idx].label,
|
||||
IDCG);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static char const* Name() {
|
||||
return "rank:ndcg";
|
||||
}
|
||||
|
||||
inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels, uint32_t size) {
|
||||
double sumdcg = 0.0;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
sumdcg += ComputeItemDCGWeight(sorted_labels[i], i);
|
||||
}
|
||||
|
||||
return static_cast<bst_float>(sumdcg);
|
||||
}
|
||||
|
||||
private:
|
||||
XGBOOST_DEVICE inline static bst_float ComputeItemDCGWeight(unsigned label, uint32_t idx) {
|
||||
return (label != 0) ? (((1 << label) - 1) / std::log2(static_cast<bst_float>(idx + 2))) : 0;
|
||||
}
|
||||
|
||||
// Compute the weight adjustment for an item within a group:
|
||||
// pos_pred_pos => Where does the positive label live, had the list been sorted by prediction
|
||||
// neg_pred_pos => Where does the negative label live, had the list been sorted by prediction
|
||||
// pos_label => positive label value from sorted label list
|
||||
// neg_label => negative label value from sorted label list
|
||||
XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t pos_pred_pos,
|
||||
uint32_t neg_pred_pos,
|
||||
int pos_label, int neg_label,
|
||||
float idcg) {
|
||||
float pos_loginv = 1.0f / std::log2(pos_pred_pos + 2.0f);
|
||||
float neg_loginv = 1.0f / std::log2(neg_pred_pos + 2.0f);
|
||||
bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
|
||||
float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
|
||||
bst_float delta = (original - changed) * (1.0f / idcg);
|
||||
if (delta < 0.0f) delta = - delta;
|
||||
return delta;
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
dh::caching_device_vector<float> dgroup_dcg_;
|
||||
// This computes the adjustment to the weight
|
||||
const NDCGLambdaWeightMultiplier weight_multiplier_;
|
||||
#endif
|
||||
};
|
||||
|
||||
class MAPLambdaWeightComputer
|
||||
#if defined(__CUDACC__)
|
||||
: public IndexablePredictionSorter
|
||||
@ -948,10 +780,6 @@ XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name()
|
||||
.describe("Pairwise rank objective.")
|
||||
.set_body([]() { return new LambdaRankObj<PairwiseLambdaWeightComputer>(); });
|
||||
|
||||
// Register LambdaRankObj<NDCGLambdaWeightComputer> under the name reported by
// NDCGLambdaWeightComputer::Name().
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name())
    .describe("LambdaRank with NDCG as objective.")
    .set_body([] { return new LambdaRankObj<NDCGLambdaWeightComputer>(); });
|
||||
|
||||
// Register LambdaRankObj<MAPLambdaWeightComputer> under the name reported by
// MAPLambdaWeightComputer::Name().
XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
    .describe("LambdaRank with MAP as objective.")
    .set_body([] { return new LambdaRankObj<MAPLambdaWeightComputer>(); });
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
|
||||
#include <gtest/gtest.h> // for Test, Message, TestPartResult, CmpHel...
|
||||
|
||||
#include <algorithm> // for sort
|
||||
#include <cstddef> // for size_t
|
||||
#include <initializer_list> // for initializer_list
|
||||
#include <map> // for map
|
||||
@ -13,7 +14,6 @@
|
||||
#include <string> // for char_traits, basic_string, string
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../../../src/common/ranking_utils.h" // for LambdaRankParam
|
||||
#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam
|
||||
#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload
|
||||
#include "xgboost/base.h" // for GradientPair, bst_group_t, Args
|
||||
@ -25,6 +25,126 @@
|
||||
#include "xgboost/span.h" // for Span
|
||||
|
||||
namespace xgboost::obj {
|
||||
// Runs the rank:ndcg JSON configuration check with a default-constructed context.
TEST(LambdaRank, NDCGJsonIO) {
  Context default_ctx;
  TestNDCGJsonIO(&default_ctx);
}
|
||||
|
||||
/**
 * \brief Check the gradient pairs produced by the rank:ndcg objective.
 *
 * Covers: identical labels (no gain from swapping), per-group weights with fixed
 * expected values, sign of the gradient / hessian for mis-ranked pairs, and linear
 * scaling of the gradients with the group weight.
 */
void TestNDCGGPair(Context const* ctx) {
  {
    std::unique_ptr<xgboost::ObjFunction> ndcg{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
    ndcg->Configure(Args{{"lambdarank_pair_method", "topk"}});
    CheckConfigReload(ndcg, "rank:ndcg");

    // No gain in swapping 2 documents.
    CheckRankingObjFunction(ndcg,
                            {1, 1, 1, 1},
                            {1, 1, 1, 1},
                            {1.0f, 1.0f},
                            {0, 2, 4},
                            {0.0f, -0.0f, 0.0f, 0.0f},
                            {0.0f, 0.0f, 0.0f, 0.0f});
  }
  {
    std::unique_ptr<xgboost::ObjFunction> ndcg{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
    ndcg->Configure(Args{{"lambdarank_pair_method", "topk"}});
    // Test with setting sample weight to second query group
    CheckRankingObjFunction(ndcg,
                            {0, 0.1f, 0, 0.1f},
                            {0, 1, 0, 1},
                            {2.0f, 0.0f},
                            {0, 2, 4},
                            {2.06611f, -2.06611f, 0.0f, 0.0f},
                            {2.169331f, 2.169331f, 0.0f, 0.0f});

    CheckRankingObjFunction(ndcg,
                            {0, 0.1f, 0, 0.1f},
                            {0, 1, 0, 1},
                            {2.0f, 2.0f},
                            {0, 2, 4},
                            {2.06611f, -2.06611f, 2.06611f, -2.06611f},
                            {2.169331f, 2.169331f, 2.169331f, 2.169331f});
  }

  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
  obj->Configure(Args{{"lambdarank_pair_method", "topk"}});

  // Two groups of two documents each; the predictions agree with the labels.
  HostDeviceVector<float> predts{0, 1, 0, 1};
  MetaInfo info;
  info.labels = linalg::Tensor<float, 2>{{0, 1, 0, 1}, {4, 1}, GPUIDX};
  info.group_ptr_ = {0, 2, 4};
  info.num_row_ = 4;
  HostDeviceVector<GradientPair> initial_gpairs;
  obj->GetGradient(predts, info, 0, &initial_gpairs);
  ASSERT_EQ(initial_gpairs.Size(), predts.Size());

  {
    // Flip the predictions so that every pair is ranked in the wrong order.
    predts = {1, 0, 1, 0};
    HostDeviceVector<GradientPair> flipped_gpairs;
    obj->GetGradient(predts, info, 0, &flipped_gpairs);
    for (size_t idx = 0; idx < flipped_gpairs.Size(); ++idx) {
      ASSERT_GT(flipped_gpairs.HostSpan()[idx].GetHess(), 0);
    }
    // Relevant documents (index 1 and 3) get a negative gradient.
    ASSERT_LT(flipped_gpairs.HostSpan()[1].GetGrad(), 0);
    ASSERT_LT(flipped_gpairs.HostSpan()[3].GetGrad(), 0);

    // Irrelevant documents (index 0 and 2) get a positive gradient.
    ASSERT_GT(flipped_gpairs.HostSpan()[0].GetGrad(), 0);
    ASSERT_GT(flipped_gpairs.HostSpan()[2].GetGrad(), 0);

    // Gradients scale linearly with the weight of the group the sample belongs to.
    info.weights_ = {2, 3};
    HostDeviceVector<GradientPair> weighted_gpairs;
    obj->GetGradient(predts, info, 0, &weighted_gpairs);
    auto const& h_plain = flipped_gpairs.ConstHostSpan();
    auto const& h_weighted = weighted_gpairs.ConstHostSpan();
    for (size_t idx = 0; idx < 4; ++idx) {
      // Samples 0 and 1 form the first group (weight 2), samples 2 and 3 the second
      // group (weight 3).
      float const group_weight = idx < 2 ? 2.0f : 3.0f;
      ASSERT_FLOAT_EQ(h_weighted[idx].GetGrad(), h_plain[idx].GetGrad() * group_weight);
      ASSERT_FLOAT_EQ(h_weighted[idx].GetHess(), h_plain[idx].GetHess() * group_weight);
    }
  }

  ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
|
||||
|
||||
// Runs the rank:ndcg gradient checks with a default-constructed context.
TEST(LambdaRank, NDCGGPair) {
  Context default_ctx;
  TestNDCGGPair(&default_ctx);
}
|
||||
|
||||
/**
 * \brief Check the position bias stored by the unbiased rank:ndcg objective.
 *
 * Runs one gradient computation with constant predictions on generated data, then
 * verifies through the saved configuration that both bias vectors start at 1 and that
 * `ti+` is non-increasing.
 */
void TestUnbiasedNDCG(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> objective{
      xgboost::ObjFunction::Create("rank:ndcg", ctx)};
  objective->Configure(Args{{"lambdarank_pair_method", "topk"},
                            {"lambdarank_unbiased", "true"},
                            {"lambdarank_bias_norm", "0"}});
  std::shared_ptr<DMatrix> dmat{RandomDataGenerator{10, 1, 0.0f}.GenerateDMatrix(true, false, 2)};
  auto labels = dmat->Info().labels.HostView().Values();
  // Move clicked samples to the beginning.
  std::sort(labels.begin(), labels.end(), std::greater<>{});
  HostDeviceVector<float> constant_predt(dmat->Info().num_row_, 1.0f);

  HostDeviceVector<GradientPair> gradients;
  objective->GetGradient(constant_predt, dmat->Info(), 0, &gradients);

  Json saved{Object{}};
  objective->SaveConfig(&saved);
  auto ti_plus = get<F32Array const>(saved["ti+"]);
  ASSERT_FLOAT_EQ(ti_plus[0], 1.0);
  // bias is non-increasing when prediction is constant. (constant cost on swapping documents)
  for (std::size_t pos = 1; pos < ti_plus.size(); ++pos) {
    ASSERT_LE(ti_plus[pos], ti_plus[pos - 1]);
  }
  auto tj_minus = get<F32Array const>(saved["tj-"]);
  ASSERT_FLOAT_EQ(tj_minus[0], 1.0);
}
|
||||
|
||||
// Runs the unbiased rank:ndcg check with a default-constructed context.
TEST(LambdaRank, UnbiasedNDCG) {
  Context default_ctx;
  TestUnbiasedNDCG(&default_ctx);
}
|
||||
|
||||
void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector<float>* out_predt) {
|
||||
out_predt->SetDevice(ctx->gpu_id);
|
||||
MetaInfo& info = *out_info;
|
||||
|
||||
@ -12,6 +12,18 @@
|
||||
#include "test_lambdarank_obj.h"
|
||||
|
||||
namespace xgboost::obj {
|
||||
// Runs TestNDCGJsonIO with the context pointed at device ordinal 0.
TEST(LambdaRank, GPUNDCGJsonIO) {
  Context gpu_ctx;
  gpu_ctx.gpu_id = 0;
  TestNDCGJsonIO(&gpu_ctx);
}
|
||||
|
||||
// Runs TestNDCGGPair with the context pointed at device ordinal 0.
TEST(LambdaRank, GPUNDCGGPair) {
  Context gpu_ctx;
  gpu_ctx.gpu_id = 0;
  TestNDCGGPair(&gpu_ctx);
}
|
||||
|
||||
void TestGPUMakePair() {
|
||||
Context ctx;
|
||||
ctx.gpu_id = 0;
|
||||
@ -107,6 +119,12 @@ void TestGPUMakePair() {
|
||||
|
||||
// Thin gtest entry point around TestGPUMakePair.
TEST(LambdaRank, GPUMakePair) {
  TestGPUMakePair();
}
|
||||
|
||||
// Runs TestUnbiasedNDCG with the context pointed at device ordinal 0.
TEST(LambdaRank, GPUUnbiasedNDCG) {
  Context gpu_ctx;
  gpu_ctx.gpu_id = 0;
  TestUnbiasedNDCG(&gpu_ctx);
}
|
||||
|
||||
template <typename CountFunctor>
|
||||
void RankItemCountImpl(std::vector<std::uint32_t> const &sorted_items, CountFunctor f,
|
||||
std::uint32_t find_val, std::uint32_t exp_val) {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/**
|
||||
* Copyright 2023, XGBoost Contributors
|
||||
* Copyright (c) 2023, XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
|
||||
#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
|
||||
@ -18,6 +18,25 @@
|
||||
#include "../helpers.h" // for EmptyDMatrix
|
||||
|
||||
namespace xgboost::obj {
|
||||
/**
 * \brief Check the JSON configuration saved by a freshly configured "rank:ndcg" objective.
 *
 * Creates the objective, configures it with an empty argument list, saves its configuration
 * and asserts the objective name plus two entries found under "lambdarank_param".
 *
 * \param ctx Context forwarded to ObjFunction::Create.
 */
inline void TestNDCGJsonIO(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> ndcg_obj{ObjFunction::Create("rank:ndcg", ctx)};
  ndcg_obj->Configure(Args{});

  Json saved{Object()};
  ndcg_obj->SaveConfig(&saved);
  ASSERT_EQ(get<String>(saved["name"]), "rank:ndcg");

  // Parameter values are compared as strings, which is how they are read back from the config.
  auto const& rank_param = saved["lambdarank_param"];
  ASSERT_EQ(get<String>(rank_param["ndcg_exp_gain"]), "1");

  auto const expected_num_pair = std::to_string(ltr::LambdaRankParam::NotSet());
  ASSERT_EQ(get<String>(rank_param["lambdarank_num_pair_per_sample"]), expected_num_pair);
}
|
||||
|
||||
/**
 * \brief Declaration of the shared NDCG gradient-pair test routine; the body is not part of
 *        this header excerpt.
 *
 * \param ctx Context the routine runs with; test cases pass the address of a local Context.
 */
void TestNDCGGPair(Context const* ctx);
|
||||
|
||||
/**
 * \brief Declaration of the shared unbiased-NDCG test routine; the body is not part of this
 *        header excerpt.
 *
 * \param ctx Context the routine runs with; test cases pass the address of a local Context.
 */
void TestUnbiasedNDCG(Context const* ctx);
|
||||
|
||||
/**
|
||||
* \brief Initialize test data for make pair tests.
|
||||
*/
|
||||
|
||||
@ -35,24 +35,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) {
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
/**
 * \brief Check the JSON configuration saved by the "rank:ndcg" objective.
 *
 * Creates the objective with a context updated from an empty argument list, saves its
 * configuration and asserts the objective name plus two entries under "lambda_rank_param".
 */
TEST(Objective, DeclareUnifiedTest(NDCG_JsonIO)) {
  xgboost::Context ctx;
  ctx.UpdateAllowUnknown(Args{});

  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)};

  obj->Configure(Args{});
  Json j_obj{Object()};
  obj->SaveConfig(&j_obj);

  // The original line ended in ";;", leaving a stray empty statement after the assertion.
  ASSERT_EQ(get<String>(j_obj["name"]), "rank:ndcg");

  auto const& j_param = j_obj["lambda_rank_param"];

  // Parameter values are compared as strings.
  ASSERT_EQ(get<String>(j_param["num_pairsample"]), "1");
  ASSERT_EQ(get<String>(j_param["fix_list_weight"]), "0");
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
@ -71,33 +53,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
}
|
||||
|
||||
/**
 * \brief Gradient-pair checks for the "rank:ndcg" objective.
 *
 * Configures the objective with no arguments, verifies the config reload helper, then runs
 * CheckRankingObjFunction on two inputs that differ only in the third argument.
 *
 * NOTE(review): the argument order of CheckRankingObjFunction is presumably (predictions,
 * labels, weights, group offsets, expected gradients, expected hessians) -- confirm against
 * the helper's declaration.
 */
TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) {
  xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
  std::vector<std::pair<std::string, std::string>> no_args;

  std::unique_ptr<xgboost::ObjFunction> ndcg{xgboost::ObjFunction::Create("rank:ndcg", &ctx)};
  ndcg->Configure(no_args);
  CheckConfigReload(ndcg, "rank:ndcg");

  // Third argument {2, 0}: the expected values for the last two rows are all zero.
  CheckRankingObjFunction(ndcg,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {2.0f, 0.0f},
                          {0, 2, 4},
                          {0.7f, -0.7f, 0.0f, 0.0f},
                          {0.74f, 0.74f, 0.0f, 0.0f});

  // Third argument {1, 1}: both halves share the same expected values.
  CheckRankingObjFunction(ndcg,
                          {0, 0.1f, 0, 0.1f},
                          {0, 1, 0, 1},
                          {1.0f, 1.0f},
                          {0, 2, 4},
                          {0.35f, -0.35f, 0.35f, -0.35f},
                          {0.368f, 0.368f, 0.368f, 0.368f});

  ASSERT_NO_THROW(ndcg->DefaultEvalMetric());
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) {
|
||||
std::vector<std::pair<std::string, std::string>> args;
|
||||
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||
|
||||
@ -89,62 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) {
|
||||
5, 4, 6});
|
||||
}
|
||||
|
||||
/**
 * \brief Exercises NDCGLambdaWeightComputer on three segments of labels and predictions.
 *
 * Checks the sorted-prediction positions reported by the computer against a fixed
 * expectation, then compares each per-group DCG value copied back from the device with
 * NDCGLambdaWeightComputer::ComputeGroupDCGWeight evaluated on the host-side sorted labels.
 */
TEST(Objective, NDCGLambdaWeightComputerTest) {
  std::vector<float> host_labels{3.1f, 1.2f, 2.3f, 4.4f,  // Labels
                                 7.8f, 5.01f, 6.96f,
                                 10.3f, 8.7f, 11.4f, 9.45f, 11.4f};
  dh::device_vector<bst_float> dev_labels(host_labels);

  auto label_sorter = RankSegmentSorterTestImpl<float>(
      {0, 4, 7, 12},  // Groups
      host_labels,
      {4.4f, 3.1f, 2.3f, 1.2f,  // Expected sorted labels
       7.8f, 6.96f, 5.01f,
       11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
      {3, 0, 2, 1,  // Expected original positions
       4, 6, 5,
       9, 11, 7, 10, 8});

  // Segmented predictions laid out like the labels above.
  std::vector<bst_float> host_preds{-9.78f, 24.367f, 0.908f, -11.47f,
                                    -1.03f, -2.79f, -3.1f,
                                    104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
  dh::device_vector<bst_float> dev_preds(host_preds);

  xgboost::obj::NDCGLambdaWeightComputer weight_computer(
      dev_preds.data().get(), dev_labels.data().get(), *label_sorter);

  // Position each prediction would move to if the predictions were sorted in descending order.
  auto dev_pred_pos = weight_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan();
  std::vector<uint32_t> host_pred_pos(label_sorter->GetNumItems());
  dh::CopyDeviceSpanToVector(&host_pred_pos, dev_pred_pos);
  std::vector<uint32_t> expected_pred_pos{2, 0, 1, 3,
                                          4, 5, 6,
                                          7, 8, 11, 9, 10};
  EXPECT_EQ(expected_pred_pos, host_pred_pos);

  // Per-group DCG values: copy everything needed back to the host first.
  std::vector<float> host_group_dcgs(label_sorter->GetNumGroups());
  dh::CopyDeviceSpanToVector(&host_group_dcgs, weight_computer.GetGroupDcgsSpan());
  std::vector<uint32_t> host_groups(label_sorter->GetNumGroups() + 1);
  dh::CopyDeviceSpanToVector(&host_groups, label_sorter->GetGroupsSpan());
  EXPECT_EQ(host_group_dcgs.size(), label_sorter->GetNumGroups());
  std::vector<float> host_sorted_labels(label_sorter->GetNumItems());
  dh::CopyDeviceSpanToVector(&host_sorted_labels, label_sorter->GetItemsSpan());

  for (size_t group = 0; group < host_group_dcgs.size(); ++group) {
    // Recompute the group's DCG on the host from its slice of sorted labels and compare.
    auto group_begin = host_groups[group];
    auto group_end = host_groups[group + 1];
    auto const host_dcg = xgboost::obj::NDCGLambdaWeightComputer::ComputeGroupDCGWeight(
        &host_sorted_labels[group_begin], group_end - group_begin);
    EXPECT_NEAR(host_group_dcgs[group], host_dcg, 0.01f);
  }
}
|
||||
|
||||
TEST(Objective, IndexableSortedItemsTest) {
|
||||
std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels
|
||||
7.8f, 5.01f, 6.96f,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
@ -36,19 +37,16 @@ class TestGPUEvalMetrics:
|
||||
|
||||
Xy = xgboost.DMatrix(X, y, group=group)
|
||||
|
||||
cpu = xgboost.train(
|
||||
booster = xgboost.train(
|
||||
{"tree_method": "hist", "eval_metric": "auc", "objective": "rank:ndcg"},
|
||||
Xy,
|
||||
num_boost_round=10,
|
||||
)
|
||||
cpu_auc = float(cpu.eval(Xy).split(":")[1])
|
||||
|
||||
gpu = xgboost.train(
|
||||
{"tree_method": "gpu_hist", "eval_metric": "auc", "objective": "rank:ndcg"},
|
||||
Xy,
|
||||
num_boost_round=10,
|
||||
)
|
||||
gpu_auc = float(gpu.eval(Xy).split(":")[1])
|
||||
cpu_auc = float(booster.eval(Xy).split(":")[1])
|
||||
booster.set_param({"gpu_id": "0"})
|
||||
assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
gpu_auc = float(booster.eval(Xy).split(":")[1])
|
||||
assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
|
||||
np.testing.assert_allclose(cpu_auc, gpu_auc)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user