Rework the NDCG objective. (#9015)

Jiaming Yuan 2023-04-18 21:16:06 +08:00 committed by GitHub
parent ba9d24ff7b
commit ef13dd31b1
15 changed files with 1082 additions and 351 deletions

View File

@ -33,6 +33,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \

View File

@ -33,6 +33,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
$(PKGROOT)/src/objective/rank_obj.o \
$(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2015 by Contributors
/**
* Copyright 2015-2023 by XGBoost Contributors
* \file math.h
* \brief additional math utils
* \author Tianqi Chen
@ -7,16 +7,19 @@
#ifndef XGBOOST_COMMON_MATH_H_
#define XGBOOST_COMMON_MATH_H_
#include <xgboost/base.h>
#include <xgboost/base.h> // for XGBOOST_DEVICE
#include <algorithm>
#include <cmath>
#include <limits>
#include <utility>
#include <vector>
#include <algorithm> // for max
#include <cmath> // for exp, abs, log, lgamma
#include <limits> // for numeric_limits
#include <type_traits> // for is_floating_point, conditional, is_signed, is_same, declval, enable_if
#include <utility> // for pair
namespace xgboost {
namespace common {
template <typename T> XGBOOST_DEVICE T Sqr(T const &w) { return w * w; }
/*!
* \brief calculate the sigmoid of the input.
* \param x input parameter
@ -30,9 +33,11 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return y;
}
template <typename T>
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
XGBOOST_DEVICE inline double Sigmoid(double x) {
auto denom = std::exp(-x) + 1.0;
auto y = 1.0 / denom;
return y;
}
/*!
* \brief Equality test for both integer and floating point.
*/
@ -134,10 +139,6 @@ inline static bool CmpFirst(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.first > b.first;
}
inline static bool CmpSecond(const std::pair<float, unsigned> &a,
const std::pair<float, unsigned> &b) {
return a.second > b.second;
}
// Redefined here to work around a VC bug that doesn't support overloading for integer
// types.

View File

@ -70,7 +70,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
// pairs
// should be accessed by getter for auto configuration.
// nolint so that we can keep the string name.
PairMethod lambdarank_pair_method{PairMethod::kMean}; // NOLINT
PairMethod lambdarank_pair_method{PairMethod::kTopK}; // NOLINT
std::size_t lambdarank_num_pair_per_sample{NotSet()}; // NOLINT
public:
@ -78,7 +78,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
// unbiased
bool lambdarank_unbiased{false};
double lambdarank_bias_norm{2.0};
double lambdarank_bias_norm{1.0};
// ndcg
bool ndcg_exp_gain{true};
@ -135,7 +135,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
.set_default(false)
.describe("Unbiased lambda mart. Use extended IPW to debias click position");
DMLC_DECLARE_FIELD(lambdarank_bias_norm)
.set_default(2.0)
.set_default(1.0)
.set_lower_bound(0.0)
.describe("Lp regularization for unbiased lambdarank.");
DMLC_DECLARE_FIELD(ndcg_exp_gain)

View File

@ -0,0 +1,440 @@
/**
* Copyright (c) 2023, XGBoost contributors
*/
#include "lambdarank_obj.h"
#include <dmlc/registry.h> // for DMLC_REGISTRY_FILE_TAG
#include <algorithm> // for transform, copy, fill_n, min, max
#include <cmath> // for pow, log2
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <map> // for operator!=
#include <memory> // for shared_ptr, __shared_ptr_access, allocator
#include <ostream> // for operator<<, basic_ostream
#include <string> // for char_traits, operator<, basic_string, string
#include <tuple> // for apply, make_tuple
#include <type_traits> // for is_floating_point
#include <utility> // for pair, swap
#include <vector> // for vector
#include "../common/error_msg.h" // for GroupWeight, LabelScoreSize
#include "../common/linalg_op.h" // for begin, cbegin, cend
#include "../common/optional_weight.h" // for MakeOptionalWeights, OptionalWeights
#include "../common/ranking_utils.h" // for RankingCache, LambdaRankParam, MAPCache, NDCGC...
#include "../common/threading_utils.h" // for ParallelFor, Sched
#include "../common/transform_iterator.h" // for IndexTransformIter
#include "init_estimation.h" // for FitIntercept
#include "xgboost/base.h" // for bst_group_t, GradientPair, kRtEps, GradientPai...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, get, Value, ToJson, F32Array, FromJson, IsA
#include "xgboost/linalg.h" // for Vector, Range, TensorView, VectorView, All
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_LE, CHE...
#include "xgboost/objective.h" // for ObjFunctionReg, XGBOOST_REGISTER_OBJECTIVE
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for operator<<, StringView
#include "xgboost/task.h" // for ObjInfo
namespace xgboost::obj {
namespace cpu_impl {
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full,
linalg::Vector<double>* p_ti_plus,
linalg::Vector<double>* p_tj_minus, linalg::Vector<double>* p_li,
linalg::Vector<double>* p_lj,
std::shared_ptr<ltr::RankingCache> p_cache) {
auto ti_plus = p_ti_plus->HostView();
auto tj_minus = p_tj_minus->HostView();
auto li = p_li->HostView();
auto lj = p_lj->HostView();
auto gptr = p_cache->DataGroupPtr(ctx);
auto n_groups = p_cache->Groups();
auto regularizer = p_cache->Param().Regularizer();
// Aggregate over query groups
for (bst_group_t g{0}; g < n_groups; ++g) {
auto begin = gptr[g];
auto end = gptr[g + 1];
std::size_t group_size = end - begin;
auto n = std::min(group_size, p_cache->MaxPositionSize());
auto g_li = li_full.Slice(linalg::Range(begin, end));
auto g_lj = lj_full.Slice(linalg::Range(begin, end));
for (std::size_t i{0}; i < n; ++i) {
li(i) += g_li(i);
lj(i) += g_lj(i);
}
}
// The ti+ is not guaranteed to decrease since it depends on the |\delta Z|.
//
// The update normalizes ti+ so that ti+(0) equals 1, which breaks the probability
// interpretation. The reasoning behind the normalization is not clear; here we are
// simply following the authors.
for (std::size_t i = 0; i < ti_plus.Size(); ++i) {
if (li(0) >= Eps64()) {
ti_plus(i) = std::pow(li(i) / li(0), regularizer); // eq.30
}
if (lj(0) >= Eps64()) {
tj_minus(i) = std::pow(lj(i) / lj(0), regularizer); // eq.31
}
assert(!std::isinf(ti_plus(i)));
assert(!std::isinf(tj_minus(i)));
}
}
} // namespace cpu_impl
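For reference, the update above implements equations 30 and 31 of `Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`. Written out, assuming `Regularizer()` returns the exponent \(1/(p+1)\) with \(p\) being `lambdarank_bias_norm`:

\[
t_i^+ = \left(\frac{L_i^+}{L_1^+}\right)^{\frac{1}{p+1}} \quad \text{(eq. 30)}, \qquad
t_j^- = \left(\frac{L_j^-}{L_1^-}\right)^{\frac{1}{p+1}} \quad \text{(eq. 31)}
\]

where \(L_i^+\) and \(L_j^-\) are the accumulated costs `li(i)` and `lj(i)`, with index 0 corresponding to the top position.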
/**
* \brief Base class for pair-wise learning to rank.
*
* See `From RankNet to LambdaRank to LambdaMART: An Overview` for a description of the
* algorithm.
*
* In addition to ranking, this also implements `Unbiased LambdaMART: An Unbiased
* Pairwise Learning-to-Rank Algorithm`.
*/
template <typename Loss, typename Cache>
class LambdaRankObj : public FitIntercept {
MetaInfo const* p_info_{nullptr};
// Update position bias for unbiased click data
void UpdatePositionBias() {
li_full_.SetDevice(ctx_->gpu_id);
lj_full_.SetDevice(ctx_->gpu_id);
li_.SetDevice(ctx_->gpu_id);
lj_.SetDevice(ctx_->gpu_id);
if (ctx_->IsCPU()) {
cpu_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
} else {
cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id),
lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_,
&li_, &lj_, p_cache_);
}
li_full_.Data()->Fill(0.0);
lj_full_.Data()->Fill(0.0);
li_.Data()->Fill(0.0);
lj_.Data()->Fill(0.0);
}
protected:
// L / tj-* (eq. 30)
linalg::Vector<double> li_;
// L / ti+* (eq. 31)
linalg::Vector<double> lj_;
// position bias ratio for relevant doc, ti+ (eq. 30)
linalg::Vector<double> ti_plus_;
// position bias ratio for irrelevant doc, tj- (eq. 31)
linalg::Vector<double> tj_minus_;
// li buffer for all samples
linalg::Vector<double> li_full_;
// lj buffer for all samples
linalg::Vector<double> lj_full_;
ltr::LambdaRankParam param_;
// cache
std::shared_ptr<ltr::RankingCache> p_cache_;
[[nodiscard]] std::shared_ptr<Cache> GetCache() const {
auto ptr = std::static_pointer_cast<Cache>(p_cache_);
CHECK(ptr);
return ptr;
}
// get group view for li/lj
linalg::VectorView<double> GroupLoss(bst_group_t g, linalg::Vector<double>* v) const {
auto gptr = p_cache_->DataGroupPtr(ctx_);
auto begin = gptr[g];
auto end = gptr[g + 1];
if (param_.lambdarank_unbiased) {
return v->HostView().Slice(linalg::Range(begin, end));
}
return v->HostView();
}
// Calculate lambda gradient for each group on CPU.
template <bool unbiased, typename Delta>
void CalcLambdaForGroup(std::int32_t iter, common::Span<float const> g_predt,
linalg::VectorView<float const> g_label, float w,
common::Span<std::size_t const> g_rank, bst_group_t g, Delta delta,
common::Span<GradientPair> g_gpair) {
std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{});
auto p_gpair = g_gpair.data();
auto ti_plus = ti_plus_.HostView();
auto tj_minus = tj_minus_.HostView();
auto li = GroupLoss(g, &li_full_);
auto lj = GroupLoss(g, &lj_full_);
// Normalization, first used by LightGBM.
// https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298
double sum_lambda{0.0};
auto delta_op = [&](auto const&... args) { return delta(args..., g); };
auto loop = [&](std::size_t i, std::size_t j) {
// higher/lower on the target ranked list
std::size_t rank_high = i, rank_low = j;
if (g_label(g_rank[rank_high]) == g_label(g_rank[rank_low])) {
return;
}
if (g_label(g_rank[rank_high]) < g_label(g_rank[rank_low])) {
std::swap(rank_high, rank_low);
}
double cost;
auto pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
ti_plus, tj_minus, &cost);
auto ng = Repulse(pg);
std::size_t idx_high = g_rank[rank_high];
std::size_t idx_low = g_rank[rank_low];
p_gpair[idx_high] += pg;
p_gpair[idx_low] += ng;
if (unbiased) {
auto k = ti_plus.Size();
// We can probably use all the positions. If we skip the update due to having
// high/low > k, we might lose too many pairs. On the other hand, if we cap the
// position, then we might accumulate too much tail bias into the last tracked
// position.
// We use `idx_high` since it represents the original position in the label
// list, and the label list is assumed to be sorted.
if (idx_high < k && idx_low < k) {
if (tj_minus(idx_low) >= Eps64()) {
li(idx_high) += cost / tj_minus(idx_low); // eq.30
}
if (ti_plus(idx_high) >= Eps64()) {
lj(idx_low) += cost / ti_plus(idx_high); // eq.31
}
}
}
sum_lambda += -2.0 * static_cast<double>(pg.GetGrad());
};
MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop);
if (sum_lambda > 0.0) {
double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(),
[norm](GradientPair const& g) { return g * norm; });
}
auto w_norm = p_cache_->WeightNorm();
std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(),
[&](GradientPair const& gpair) { return gpair * w * w_norm; });
}
public:
void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(Loss::Name());
out["lambdarank_param"] = ToJson(param_);
auto save_bias = [](linalg::Vector<double> const& in, Json out) {
auto& out_array = get<F32Array>(out);
out_array.resize(in.Size());
auto h_in = in.HostView();
std::copy(linalg::cbegin(h_in), linalg::cend(h_in), out_array.begin());
};
if (param_.lambdarank_unbiased) {
out["ti+"] = F32Array();
save_bias(ti_plus_, out["ti+"]);
out["tj-"] = F32Array();
save_bias(tj_minus_, out["tj-"]);
}
}
void LoadConfig(Json const& in) override {
auto const& obj = get<Object const>(in);
if (obj.find("lambdarank_param") != obj.cend()) {
FromJson(in["lambdarank_param"], &param_);
}
if (param_.lambdarank_unbiased) {
auto load_bias = [](Json in, linalg::Vector<double>* out) {
if (IsA<F32Array>(in)) {
// JSON
auto const& array = get<F32Array>(in);
out->Reshape(array.size());
auto h_out = out->HostView();
std::copy(array.cbegin(), array.cend(), linalg::begin(h_out));
} else {
// UBJSON
auto const& array = get<Array>(in);
out->Reshape(array.size());
auto h_out = out->HostView();
std::transform(array.cbegin(), array.cend(), linalg::begin(h_out),
[](Json const& v) { return get<Number const>(v); });
}
};
load_bias(in["ti+"], &ti_plus_);
load_bias(in["tj-"], &tj_minus_);
}
}
[[nodiscard]] ObjInfo Task() const override { return ObjInfo{ObjInfo::kRanking}; }
[[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override {
CHECK_LE(info.labels.Shape(1), 1) << "multi-output for LTR is not yet supported.";
return 1;
}
[[nodiscard]] const char* RankEvalMetric(StringView metric) const {
static thread_local std::string name;
if (param_.HasTruncation()) {
name = ltr::MakeMetricName(metric, param_.NumPair(), false);
} else {
name = ltr::MakeMetricName(metric, param_.NotSet(), false);
}
return name.c_str();
}
void GetGradient(HostDeviceVector<float> const& predt, MetaInfo const& info, std::int32_t iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize();
// init/renew cache
if (!p_cache_ || p_info_ != &info || p_cache_->Param() != param_) {
p_cache_ = std::make_shared<Cache>(ctx_, info, param_);
p_info_ = &info;
}
auto n_groups = p_cache_->Groups();
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
}
if (ti_plus_.Size() == 0 && param_.lambdarank_unbiased) {
CHECK_EQ(iter, 0);
ti_plus_ = linalg::Constant<double>(ctx_, 1.0, p_cache_->MaxPositionSize());
tj_minus_ = linalg::Constant<double>(ctx_, 1.0, p_cache_->MaxPositionSize());
li_ = linalg::Zeros<double>(ctx_, p_cache_->MaxPositionSize());
lj_ = linalg::Zeros<double>(ctx_, p_cache_->MaxPositionSize());
li_full_ = linalg::Zeros<double>(ctx_, info.num_row_);
lj_full_ = linalg::Zeros<double>(ctx_, info.num_row_);
}
static_cast<Loss*>(this)->GetGradientImpl(iter, predt, info, out_gpair);
if (param_.lambdarank_unbiased) {
this->UpdatePositionBias();
}
}
};
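The per-group normalization at the end of `CalcLambdaForGroup`, written out as a restatement of the code: with \(\Lambda = \sum_{(i,j)} -2\,\lambda_{ij}\) accumulated over all generated pairs (`sum_lambda`), each gradient pair in the group is scaled by

\[
\frac{\log_2(1 + \Lambda)}{\Lambda} \cdot w_g \cdot w_{\text{norm}}
\]

where \(w_g\) is the query-group weight and \(w_{\text{norm}}\) is `RankingCache::WeightNorm()`.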
class LambdaRankNDCG : public LambdaRankObj<LambdaRankNDCG, ltr::NDCGCache> {
public:
template <bool unbiased, bool exp_gain>
void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span<float const> g_predt,
linalg::VectorView<float const> g_label, float w,
common::Span<std::size_t const> g_rank,
common::Span<GradientPair> g_gpair,
linalg::VectorView<double const> inv_IDCG,
common::Span<double const> discount, bst_group_t g) {
auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low,
bst_group_t g) {
static_assert(std::is_floating_point<decltype(y_high)>::value);
return DeltaNDCG<exp_gain>(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount);
};
this->CalcLambdaForGroup<unbiased>(iter, g_predt, g_label, w, g_rank, g, delta, g_gpair);
}
void GetGradientImpl(std::int32_t iter, const HostDeviceVector<float>& predt,
const MetaInfo& info, HostDeviceVector<GradientPair>* out_gpair) {
if (ctx_->IsCUDA()) {
cuda_impl::LambdaRankGetGradientNDCG(
ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id),
tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id),
out_gpair);
return;
}
bst_group_t n_groups = p_cache_->Groups();
auto gptr = p_cache_->DataGroupPtr(ctx_);
out_gpair->Resize(info.num_row_);
auto h_gpair = out_gpair->HostSpan();
auto h_predt = predt.ConstHostSpan();
auto h_label = info.labels.HostView();
auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_);
auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); };
auto dct = GetCache()->Discount(ctx_);
auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt);
auto inv_IDCG = GetCache()->InvIDCG(ctx_);
common::ParallelFor(n_groups, ctx_->Threads(), common::Sched::Guided(), [&](auto g) {
std::size_t cnt = gptr[g + 1] - gptr[g];
auto w = h_weight[g];
auto g_predt = h_predt.subspan(gptr[g], cnt);
auto g_gpair = h_gpair.subspan(gptr[g], cnt);
auto g_label = h_label.Slice(make_range(g), 0);
auto g_rank = rank_idx.subspan(gptr[g], cnt);
auto args =
std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g_gpair, inv_IDCG, dct, g);
if (param_.lambdarank_unbiased) {
if (param_.ndcg_exp_gain) {
std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<true, true>, args);
} else {
std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<true, false>, args);
}
} else {
if (param_.ndcg_exp_gain) {
std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<false, true>, args);
} else {
std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG<false, false>, args);
}
}
});
}
static char const* Name() { return "rank:ndcg"; }
[[nodiscard]] const char* DefaultEvalMetric() const override {
return this->RankEvalMetric("ndcg");
}
[[nodiscard]] Json DefaultMetricConfig() const override {
Json config{Object{}};
config["name"] = String{DefaultEvalMetric()};
config["lambdarank_param"] = ToJson(param_);
return config;
}
};
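As a quick orientation, here is a minimal sketch of how the reworked objective is created and configured through the `ObjFunction` factory, mirroring the unit tests later in this commit (the parameter values are illustrative, not defaults):

#include <memory>               // for unique_ptr
#include "xgboost/base.h"       // for Args
#include "xgboost/context.h"    // for Context
#include "xgboost/objective.h"  // for ObjFunction

void ExampleCreateNDCG() {
  xgboost::Context ctx;
  std::unique_ptr<xgboost::ObjFunction> obj{
      xgboost::ObjFunction::Create("rank:ndcg", &ctx)};
  // Parameters declared in LambdaRankParam; the values here are illustrative.
  obj->Configure(xgboost::Args{{"lambdarank_pair_method", "topk"},
                               {"lambdarank_num_pair_per_sample", "8"},
                               {"lambdarank_unbiased", "true"},
                               {"ndcg_exp_gain", "true"}});
}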
namespace cuda_impl {
#if !defined(XGBOOST_USE_CUDA)
void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector<float> const&,
const MetaInfo&, std::shared_ptr<ltr::NDCGCache>,
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double const>, // input bias ratio
linalg::VectorView<double>, linalg::VectorView<double>,
HostDeviceVector<GradientPair>*) {
common::AssertGPUSupport();
}
void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView<double const>,
linalg::VectorView<double const>, linalg::Vector<double>*,
linalg::Vector<double>*, linalg::Vector<double>*,
linalg::Vector<double>*, std::shared_ptr<ltr::RankingCache>) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace cuda_impl
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name())
.describe("LambdaRank with NDCG loss as objective")
.set_body([]() { return new LambdaRankNDCG{}; });
DMLC_REGISTRY_FILE_TAG(lambdarank_obj);
} // namespace xgboost::obj

View File

@ -37,6 +37,312 @@ namespace xgboost::obj {
DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu);
namespace cuda_impl {
namespace {
/**
* \brief Calculate minimum value of bias for floating point truncation.
*/
void MinBias(Context const* ctx, std::shared_ptr<ltr::RankingCache> p_cache,
linalg::VectorView<double const> t_plus, linalg::VectorView<double const> tj_minus,
common::Span<double> d_min) {
CHECK_EQ(d_min.size(), 2);
auto cuctx = ctx->CUDACtx();
auto k = t_plus.Size();
auto const& p = p_cache->Param();
CHECK_GT(k, 0);
CHECK_EQ(k, p_cache->MaxPositionSize());
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return i * k; });
auto val_it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) {
if (i >= k) {
return std::abs(tj_minus(i - k));
}
return std::abs(t_plus(i));
});
std::size_t bytes;
cub::DeviceSegmentedReduce::Min(nullptr, bytes, val_it, d_min.data(), 2, key_it, key_it + 1,
cuctx->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Min(temp.data().get(), bytes, val_it, d_min.data(), 2, key_it,
key_it + 1, cuctx->Stream());
}
/**
* \brief Type for gradient statistic. (Gradient, cost for unbiased LTR, normalization factor)
*/
using GradCostNorm = thrust::tuple<GradientPair, double, double>;
/**
* \brief Obtain and update the gradient for one pair.
*/
template <bool unbiased, bool has_truncation, typename Delta>
struct GetGradOp {
MakePairsOp<has_truncation> make_pair;
Delta delta;
bool need_update;
auto __device__ operator()(std::size_t idx) -> GradCostNorm {
auto const& args = make_pair.args;
auto g = dh::SegmentId(args.d_threads_group_ptr, idx);
auto data_group_begin = static_cast<std::size_t>(args.d_group_ptr[g]);
std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin;
// obtain group segment data.
auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0);
auto g_predt = args.predts.subspan(data_group_begin, n_data);
auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data();
auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data);
auto [i, j] = make_pair(idx, g);
std::size_t rank_high = i, rank_low = j;
if (g_label(g_rank[i]) == g_label(g_rank[j])) {
return thrust::make_tuple(GradientPair{}, 0.0, 0.0);
}
if (g_label(g_rank[i]) < g_label(g_rank[j])) {
thrust::swap(rank_high, rank_low);
}
double cost{0};
auto delta_op = [&](auto const&... args) { return delta(args..., g); };
GradientPair pg = LambdaGrad<unbiased>(g_label, g_predt, g_rank, rank_high, rank_low, delta_op,
args.ti_plus, args.tj_minus, &cost);
std::size_t idx_high = g_rank[rank_high];
std::size_t idx_low = g_rank[rank_low];
if (need_update) {
// second run, update the gradient
auto ng = Repulse(pg);
auto gr = args.d_roundings(g);
// positive gradient truncated
auto pgt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), pg.GetGrad()),
common::TruncateWithRounding(gr.GetHess(), pg.GetHess())};
// negative gradient truncated
auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()),
common::TruncateWithRounding(gr.GetHess(), ng.GetHess())};
dh::AtomicAddGpair(g_gpair + idx_high, pgt);
dh::AtomicAddGpair(g_gpair + idx_low, ngt);
}
if (unbiased && need_update) {
// second run, update the cost
assert(args.tj_minus.Size() == args.ti_plus.Size() && "Invalid size of position bias");
auto g_li = args.li.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
auto g_lj = args.lj.Slice(linalg::Range(data_group_begin, data_group_begin + n_data));
if (idx_high < args.ti_plus.Size() && idx_low < args.ti_plus.Size()) {
if (args.tj_minus(idx_low) >= Eps64()) {
// eq.30
atomicAdd(&g_li(idx_high), common::TruncateWithRounding(args.d_cost_rounding[0],
cost / args.tj_minus(idx_low)));
}
if (args.ti_plus(idx_high) >= Eps64()) {
// eq.31
atomicAdd(&g_lj(idx_low), common::TruncateWithRounding(args.d_cost_rounding[0],
cost / args.ti_plus(idx_high)));
}
}
}
return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())},
std::abs(cost), -2.0 * static_cast<double>(pg.GetGrad()));
}
};
template <bool unbiased, bool has_truncation, typename Delta>
struct MakeGetGrad {
MakePairsOp<has_truncation> make_pair;
Delta delta;
[[nodiscard]] KernelInputs const& Args() const { return make_pair.args; }
MakeGetGrad(KernelInputs args, Delta d) : make_pair{args}, delta{std::move(d)} {}
GetGradOp<unbiased, has_truncation, Delta> operator()(bool need_update) {
return GetGradOp<unbiased, has_truncation, Delta>{make_pair, delta, need_update};
}
};
/**
* \brief Calculate gradient for all pairs using update op created by make_get_grad.
*
* We need to run the gradient calculation twice: the first pass gathers information
* like the maximum gradient, the maximum cost, and the normalization term using a
* reduction. The second pass performs the actual update.
*
* Without normalization, we would only need to run it once since we can manually
* calculate the bounds of the gradient (NDCG \in [0, 1], delta_NDCG \in [0, 1], and
* ti+/tj- come from the previous iteration, so the bound can be calculated for the
* current iteration). However, if normalization is used, the delta score is unbounded
* and we need to obtain the gradient sum. As a tradeoff, we simply run the kernel
* twice, first as a reduction, then as a for_each.
*
* Alternatively, we could bound the delta score by limiting the output of the model,
* using a sigmoid for binary output and some normalization for multi-level output.
* But the effect on accuracy is not known yet, and it's only used by the GPU.
*
* For performance, the segmented sort of the scores is the bottleneck and takes up
* about half of the time, while the reduction and for_each take up the other half.
*/
template <bool unbiased, bool has_truncation, typename Delta>
void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::RankingCache> p_cache,
MakeGetGrad<unbiased, has_truncation, Delta> make_get_grad) {
auto n_groups = p_cache->Groups();
auto d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr();
auto d_gptr = p_cache->DataGroupPtr(ctx);
auto d_gpair = make_get_grad.Args().gpairs;
/**
* First pass, gather info for normalization and rounding factor.
*/
auto val_it = dh::MakeTransformIterator<GradCostNorm>(thrust::make_counting_iterator(0ul),
make_get_grad(false));
auto reduction_op = [] XGBOOST_DEVICE(GradCostNorm const& l,
GradCostNorm const& r) -> GradCostNorm {
// get maximum gradient for each group, along with cost and the normalization term
auto const& lg = thrust::get<0>(l);
auto const& rg = thrust::get<0>(r);
auto grad = std::max(lg.GetGrad(), rg.GetGrad());
auto hess = std::max(lg.GetHess(), rg.GetHess());
auto cost = std::max(thrust::get<1>(l), thrust::get<1>(r));
double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r);
return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda);
};
auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0);
common::Span<GradCostNorm> d_max_lambdas = p_cache->MaxLambdas<GradCostNorm>(ctx, n_groups);
CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes());
std::size_t bytes;
cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups,
d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1,
reduction_op, init, ctx->CUDACtx()->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Reduce(
temp.data().get(), bytes, val_it, d_max_lambdas.data(), n_groups, d_threads_group_ptr.data(),
d_threads_group_ptr.data() + 1, reduction_op, init, ctx->CUDACtx()->Stream());
dh::TemporaryArray<double> min_bias(2);
auto d_min_bias = dh::ToSpan(min_bias);
if (unbiased) {
MinBias(ctx, p_cache, make_get_grad.Args().ti_plus, make_get_grad.Args().tj_minus, d_min_bias);
}
/**
* Create rounding factors
*/
auto d_cost_rounding = p_cache->CUDACostRounding(ctx);
auto d_rounding = p_cache->CUDARounding(ctx);
dh::LaunchN(n_groups, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t g) mutable {
auto group_size = d_gptr[g + 1] - d_gptr[g];
auto const& max_grad = thrust::get<0>(d_max_lambdas[g]);
// float group size
auto fgs = static_cast<float>(group_size);
auto grad = common::CreateRoundingFactor(fgs * max_grad.GetGrad(), group_size);
auto hess = common::CreateRoundingFactor(fgs * max_grad.GetHess(), group_size);
d_rounding(g) = GradientPair{grad, hess};
auto cost = thrust::get<1>(d_max_lambdas[g]);
if (unbiased) {
cost /= std::min(d_min_bias[0], d_min_bias[1]);
d_cost_rounding[0] = common::CreateRoundingFactor(fgs * cost, group_size);
}
});
/**
* Second pass, actual update to gradient and bias.
*/
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul),
p_cache->CUDAThreads(), make_get_grad(true));
/**
* Lastly, normalization and weight.
*/
auto d_weights = common::MakeOptionalWeights(ctx, info.weights_);
auto w_norm = p_cache->WeightNorm();
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
auto g = dh::SegmentId(d_gptr, i);
auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
// Normalization
if (sum_lambda > 0.0) {
double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
d_gpair[i] *= norm;
}
d_gpair[i] *= (d_weights[g] * w_norm);
});
}
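The rounding factors created in the kernel above are what make the atomic updates deterministic. Below is a self-contained CPU sketch of the underlying trick (an illustration of the idea only, not the actual `common::CreateRoundingFactor` implementation): snap every addend onto a fixed power-of-two grid derived from a bound on the sum, so the additions become exact and therefore order-independent.

#include <cmath>   // for exp2, ceil, log2
#include <cstdio>  // for printf

// Illustrative only: a power of two no smaller than the largest possible sum.
double RoundingFactor(double max_abs_addend, int n) {
  return std::exp2(std::ceil(std::log2(max_abs_addend * n)));
}

// (factor + x) - factor rounds x onto the ULP grid of `factor`.
double TruncateWithRounding(double factor, double x) { return (factor + x) - factor; }

int main() {
  double f = RoundingFactor(0.3, 3);
  double a = TruncateWithRounding(f, 0.1);
  double b = TruncateWithRounding(f, 0.3);
  // Grid-aligned addends sum exactly, so the order of atomic adds cannot matter.
  std::printf("%d\n", (a + b) + a == a + (b + a));
  return 0;
}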
/**
* \brief Handles boilerplate code like getting device span.
*/
template <typename Delta>
void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector<float> const& preds,
const MetaInfo& info, std::shared_ptr<ltr::RankingCache> p_cache, Delta delta,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) {
// boilerplate
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
auto n_groups = p_cache->Groups();
info.labels.SetDevice(device_id);
preds.SetDevice(device_id);
out_gpair->SetDevice(device_id);
out_gpair->Resize(preds.Size());
CHECK(p_cache);
auto d_rounding = p_cache->CUDARounding(ctx);
auto d_cost_rounding = p_cache->CUDACostRounding(ctx);
CHECK_NE(d_rounding.Size(), 0);
auto label = info.labels.View(ctx->gpu_id);
auto predts = preds.ConstDeviceSpan();
auto gpairs = out_gpair->DeviceSpan();
thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f});
auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr();
auto const d_gptr = p_cache->DataGroupPtr(ctx);
auto const rank_idx = p_cache->SortedIdx(ctx, predts);
auto const unbiased = p_cache->Param().lambdarank_unbiased;
common::Span<std::size_t const> d_y_sorted_idx;
if (!p_cache->Param().HasTruncation()) {
d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache);
}
KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr,
rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(),
d_y_sorted_idx, iter};
// dispatch based on unbiased and truncation
if (p_cache->Param().HasTruncation()) {
if (unbiased) {
CalcGrad(ctx, info, p_cache, MakeGetGrad<true, true, Delta>{args, delta});
} else {
CalcGrad(ctx, info, p_cache, MakeGetGrad<false, true, Delta>{args, delta});
}
} else {
if (unbiased) {
CalcGrad(ctx, info, p_cache, MakeGetGrad<true, false, Delta>{args, delta});
} else {
CalcGrad(ctx, info, p_cache, MakeGetGrad<false, false, Delta>{args, delta});
}
}
}
} // anonymous namespace
common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
common::Span<std::size_t const> d_rank,
std::shared_ptr<ltr::RankingCache> p_cache) {
@ -58,5 +364,116 @@ common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
common::SegmentedArgSort<false, true>(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx);
return d_y_sorted_idx;
}
void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
const HostDeviceVector<float>& preds, const MetaInfo& info,
std::shared_ptr<ltr::NDCGCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair) {
// boilerplate
std::int32_t device_id = ctx->gpu_id;
dh::safe_cuda(cudaSetDevice(device_id));
auto const d_inv_IDCG = p_cache->InvIDCG(ctx);
auto const discount = p_cache->Discount(ctx);
info.labels.SetDevice(device_id);
preds.SetDevice(device_id);
auto const exp_gain = p_cache->Param().ndcg_exp_gain;
auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high,
std::size_t rank_low, bst_group_t g) {
return exp_gain ? DeltaNDCG<true>(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount)
: DeltaNDCG<false>(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount);
};
Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair);
}
namespace {
struct ReduceOp {
template <typename Tup>
Tup XGBOOST_DEVICE operator()(Tup const& l, Tup const& r) {
return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r),
thrust::get<1>(l) + thrust::get<1>(r));
}
};
} // namespace
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full,
linalg::Vector<double>* p_ti_plus,
linalg::Vector<double>* p_tj_minus,
linalg::Vector<double>* p_li, // loss
linalg::Vector<double>* p_lj,
std::shared_ptr<ltr::RankingCache> p_cache) {
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = d_group_ptr.size() - 1;
auto ti_plus = p_ti_plus->View(ctx->gpu_id);
auto tj_minus = p_tj_minus->View(ctx->gpu_id);
auto li = p_li->View(ctx->gpu_id);
auto lj = p_lj->View(ctx->gpu_id);
CHECK_EQ(li.Size(), ti_plus.Size());
auto const& param = p_cache->Param();
auto regularizer = param.Regularizer();
std::size_t k = p_cache->MaxPositionSize();
CHECK_EQ(li.Size(), k);
CHECK_EQ(lj.Size(), k);
// reduce li_full to li for each group.
auto make_iter = [&](linalg::VectorView<double const> l_full) {
auto l_it = [=] XGBOOST_DEVICE(std::size_t i) {
// group index
auto g = i % n_groups;
// rank is the position within a group, also the segment index
auto r = i / n_groups;
auto begin = d_group_ptr[g];
std::size_t group_size = d_group_ptr[g + 1] - begin;
auto n = std::min(group_size, k);
// r can be greater than n since we allocate threads based on truncation level
// instead of actual group size.
if (r >= n) {
return 0.0;
}
return l_full(r + begin);
};
return l_it;
};
auto li_it =
dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), make_iter(li_full));
auto lj_it =
dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), make_iter(lj_full));
// k segments, each segment has size n_groups.
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return i * n_groups; });
auto val_it = thrust::make_zip_iterator(thrust::make_tuple(li_it, lj_it));
auto out_it =
thrust::make_zip_iterator(thrust::make_tuple(li.Values().data(), lj.Values().data()));
auto init = thrust::make_tuple(0.0, 0.0);
std::size_t bytes;
cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, out_it, k, key_it, key_it + 1,
ReduceOp{}, init, ctx->CUDACtx()->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Reduce(temp.data().get(), bytes, val_it, out_it, k, key_it,
key_it + 1, ReduceOp{}, init, ctx->CUDACtx()->Stream());
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), li.Size(),
[=] XGBOOST_DEVICE(std::size_t i) mutable {
if (li(0) >= Eps64()) {
ti_plus(i) = std::pow(li(i) / li(0), regularizer);
}
if (lj(0) >= Eps64()) {
tj_minus(i) = std::pow(lj(i) / lj(0), regularizer);
}
assert(!std::isinf(ti_plus(i)));
assert(!std::isinf(tj_minus(i)));
});
}
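The transform iterators above view the per-sample buffer as a k-by-n_groups matrix in position-major order, so segment r sums position r across all query groups. A plain C++ sketch of the same aggregation (equivalent to the group loop in the CPU implementation earlier in this commit):

#include <algorithm>  // for min
#include <cstddef>    // for size_t
#include <vector>     // for vector

// out[r] accumulates position r of every group, capped at k tracked positions.
std::vector<double> ReduceOverGroups(std::vector<double> const& l_full,
                                     std::vector<std::size_t> const& gptr,
                                     std::size_t k) {
  std::size_t n_groups = gptr.size() - 1;
  std::vector<double> out(k, 0.0);
  for (std::size_t r = 0; r < k; ++r) {
    for (std::size_t g = 0; g < n_groups; ++g) {
      std::size_t n = std::min(gptr[g + 1] - gptr[g], k);
      if (r < n) {
        out[r] += l_full[gptr[g] + r];
      }
    }
  }
  return out;
}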
} // namespace cuda_impl
} // namespace xgboost::obj

View File

@ -1,5 +1,15 @@
/**
* Copyright 2023 XGBoost contributors
* Copyright 2023, XGBoost contributors
*
* Vocabulary explanation:
*
* There are two different lists we need to handle in the objective. The first is the
* list of labels (relevance degrees) provided by the user; its order has no particular
* meaning when bias estimation is NOT used. The second is generated by our model: an
* index sorted by prediction scores. `rank_high` refers to a position on the model's
* rank list that is higher than `rank_low`, while `idx_high` refers to where the
* `rank_high` sample comes from. Simply put, `rank_high` indexes into the rank list
* obtained from the model, while `idx_high` indexes into the user-provided sample list.
*/
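A hypothetical three-document example of this vocabulary (the names below are illustrative only):

#include <cstddef>  // for size_t
#include <vector>   // for vector

// User-provided relevance labels, in the original data order.
std::vector<float> const labels = {3.0f, 0.0f, 2.0f};
// Model rank list: document indices sorted by descending prediction score.
std::vector<std::size_t> const sorted_idx = {1, 2, 0};
std::size_t const rank_high = 0;                     // position on the model rank list
std::size_t const idx_high = sorted_idx[rank_high];  // = 1, position on the label list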
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
@ -25,14 +35,19 @@
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
double constexpr Eps64() { return 1e-16; }
template <bool exp>
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low,
double inv_IDCG, common::Span<double const> discount) {
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t rank_high,
std::size_t rank_low, double inv_IDCG,
common::Span<double const> discount) {
// Use rank_high instead of idx_high as we are calculating discount based on ranks
// provided by the model.
double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high;
double discount_high = discount[r_high];
double discount_high = discount[rank_high];
double gain_low = exp ? ltr::CalcDCGGain(y_low) : y_low;
double discount_low = discount[r_low];
double discount_low = discount[rank_low];
double original = gain_high * discount_high + gain_low * discount_low;
double changed = gain_low * discount_high + gain_high * discount_low;
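Algebraically, the quantity whose absolute value `LambdaGrad` uses as \(|\Delta \mathrm{NDCG}|\) simplifies to (assuming the usual gain \(G(y) = 2^y - 1\) when `exp` is set, else \(G(y) = y\), and discount \(D(r) = 1/\log_2(r + 2)\)):

\[
\Delta \mathrm{NDCG} = \frac{\bigl(G(y_{high}) - G(y_{low})\bigr)\,\bigl(D(r_{high}) - D(r_{low})\bigr)}{\mathrm{IDCG}}
\]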
@ -70,9 +85,9 @@ template <bool unbiased, typename Delta>
XGBOOST_DEVICE GradientPair
LambdaGrad(linalg::VectorView<float const> labels, common::Span<float const> predts,
common::Span<size_t const> sorted_idx,
std::size_t rank_high, // cordiniate
std::size_t rank_low, // cordiniate
Delta delta, // delta score
std::size_t rank_high, // higher index on the model rank list
std::size_t rank_low, // lower index on the model rank list
Delta delta, // function to calculate delta score
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
double* p_cost) {
@ -95,30 +110,34 @@ LambdaGrad(linalg::VectorView<float const> labels, common::Span<float const> pre
// Use double whenever possible as we are working in the exp space.
double delta_score = std::abs(s_high - s_low);
double sigmoid = common::Sigmoid(s_high - s_low);
double const sigmoid = common::Sigmoid(s_high - s_low);
// Change in metric score like \delta NDCG or \delta MAP
double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low));
if (best_score != worst_score) {
delta_metric /= (delta_score + kRtEps);
delta_metric /= (delta_score + 0.01);
}
if (unbiased) {
*p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric;
}
constexpr double kEps = 1e-16;
auto lambda_ij = (sigmoid - 1.0) * delta_metric;
auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0;
auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric * 2.0;
auto k = t_plus.Size();
assert(t_minus.Size() == k && "Invalid size of position bias");
if (unbiased && idx_high < k && idx_low < k) {
lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
// We need to skip samples that exceed the maximum number of tracked positions, and
// samples with very small bias estimates that might cause floating point issues.
if (unbiased && idx_high < k && idx_low < k && t_minus(idx_low) >= Eps64() &&
t_plus(idx_high) >= Eps64()) {
// The index should be ranks[idx_low]; since we assume the label list is sorted, this
// reduces to `idx_low`, which represents the position on the input list, as explained
// in the file header.
lambda_ij /= (t_plus(idx_high) * t_minus(idx_low));
hessian_ij /= (t_plus(idx_high) * t_minus(idx_low));
}
auto pg = GradientPair{static_cast<float>(lambda_ij), static_cast<float>(hessian_ij)};
return pg;
}
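Putting the pieces of `LambdaGrad` together, with \(s = s_{high} - s_{low}\), \(\sigma = \mathrm{Sigmoid}(s)\), and \(|\Delta m|\) the (scaled) change in the metric, a restatement of the code:

\[
\lambda_{ij} = (\sigma - 1)\,|\Delta m|, \qquad
h_{ij} = 2\,\max\bigl(\sigma(1 - \sigma), \varepsilon\bigr)\,|\Delta m|, \qquad
c_{ij} = \log\!\left(\frac{1}{1 - \sigma}\right)|\Delta m|,
\]

and in the unbiased case both \(\lambda_{ij}\) and \(h_{ij}\) are further divided by the propensity product \(t^+_{idx_{high}}\, t^-_{idx_{low}}\).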
@ -137,27 +156,6 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
/**
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
*/
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
std::shared_ptr<ltr::MAPCache> p_cache);
void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache,
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info,
std::shared_ptr<ltr::RankingCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full,
@ -167,18 +165,6 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double
std::shared_ptr<ltr::RankingCache> p_cache);
} // namespace cuda_impl
namespace cpu_impl {
/**
* \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart.
*
* \param label Ground truth relevance label.
* \param rank_idx Sorted index of prediction.
* \param p_cache An initialized MAPCache.
*/
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache);
} // namespace cpu_impl
/**
* \brief Construct pairs on CPU
*

View File

@ -48,12 +48,15 @@ DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu);
#else
DMLC_REGISTRY_LINK_TAG(regression_obj);
DMLC_REGISTRY_LINK_TAG(quantile_obj);
DMLC_REGISTRY_LINK_TAG(hinge_obj);
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
DMLC_REGISTRY_LINK_TAG(rank_obj);
DMLC_REGISTRY_LINK_TAG(lambdarank_obj);
#endif // XGBOOST_USE_CUDA
} // namespace obj
} // namespace xgboost

View File

@ -207,174 +207,6 @@ class IndexablePredictionSorter {
};
#endif
// beta version: NDCG lambda rank
class NDCGLambdaWeightComputer
#if defined(__CUDACC__)
: public IndexablePredictionSorter
#endif
{
public:
#if defined(__CUDACC__)
// This function object computes the item's DCG value
class ComputeItemDCG : public thrust::unary_function<uint32_t, float> {
public:
XGBOOST_DEVICE ComputeItemDCG(const common::Span<const float> &dsorted_labels,
const common::Span<const uint32_t> &dgroups,
const common::Span<const uint32_t> &gidxs)
: dsorted_labels_(dsorted_labels),
dgroups_(dgroups),
dgidxs_(gidxs) {}
// Compute DCG for the item at 'idx'
__device__ __forceinline__ float operator()(uint32_t idx) const {
return ComputeItemDCGWeight(dsorted_labels_[idx], idx - dgroups_[dgidxs_[idx]]);
}
private:
const common::Span<const float> dsorted_labels_; // Labels sorted within a group
const common::Span<const uint32_t> dgroups_; // The group indices - where each group
// begins and ends
const common::Span<const uint32_t> dgidxs_; // The group each items belongs to
};
// Type containing device pointers that can be cheaply copied on the kernel
class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier {
public:
NDCGLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
const NDCGLambdaWeightComputer &lwc)
: BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()),
dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {}
// Adjust the items weight by this value
__device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const {
if (dgroup_dcgs_[gidx] == 0.0) return 0.0f;
uint32_t group_begin = dgroups_[gidx];
auto pos_lab_orig_posn = dorig_pos_[pidx];
auto neg_lab_orig_posn = dorig_pos_[nidx];
KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn);
// Note: the label positive and negative indices are relative to the entire dataset.
// Hence, scale them back to an index within the group
auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin;
auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin;
return NDCGLambdaWeightComputer::ComputeDeltaWeight(
pos_pred_pos, neg_pred_pos,
static_cast<int>(dsorted_labels_[pidx]), static_cast<int>(dsorted_labels_[nidx]),
dgroup_dcgs_[gidx]);
}
private:
const common::Span<const float> dgroup_dcgs_; // Group DCG values
};
NDCGLambdaWeightComputer(const bst_float *dpreds,
const bst_float*,
const dh::SegmentSorter<float> &segment_label_sorter)
: IndexablePredictionSorter(dpreds, segment_label_sorter),
dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f),
weight_multiplier_(segment_label_sorter, *this) {
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
// Allocator to be used for managing space overhead while performing transformed reductions
dh::XGBCachingDeviceAllocator<char> alloc;
// Compute each elements DCG values and reduce them across groups concurrently.
auto end_range =
thrust::reduce_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
thrust::make_transform_iterator(
// The indices need not be sequential within a group, as we care only
// about the sum of items DCG values within a group
dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()),
ComputeItemDCG(segment_label_sorter.GetItemsSpan(),
segment_label_sorter.GetGroupsSpan(),
group_segments)),
thrust::make_discard_iterator(), // We don't care for the group indices
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
CHECK_EQ(static_cast<unsigned>(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size());
}
inline const common::Span<const float> GetGroupDcgsSpan() const {
return { dgroup_dcg_.data().get(), dgroup_dcg_.size() };
}
inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const {
return weight_multiplier_;
}
#endif
static void GetLambdaWeight(const std::vector<ListEntry> &sorted_list,
std::vector<LambdaPair> *io_pairs) {
std::vector<LambdaPair> &pairs = *io_pairs;
float IDCG; // NOLINT
{
std::vector<bst_float> labels(sorted_list.size());
for (size_t i = 0; i < sorted_list.size(); ++i) {
labels[i] = sorted_list[i].label;
}
std::stable_sort(labels.begin(), labels.end(), std::greater<>());
IDCG = ComputeGroupDCGWeight(&labels[0], labels.size());
}
if (IDCG == 0.0) {
for (auto & pair : pairs) {
pair.weight = 0.0f;
}
} else {
for (auto & pair : pairs) {
unsigned pos_idx = pair.pos_index;
unsigned neg_idx = pair.neg_index;
pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx,
sorted_list[pos_idx].label, sorted_list[neg_idx].label,
IDCG);
}
}
}
static char const* Name() {
return "rank:ndcg";
}
inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels, uint32_t size) {
double sumdcg = 0.0;
for (uint32_t i = 0; i < size; ++i) {
sumdcg += ComputeItemDCGWeight(sorted_labels[i], i);
}
return static_cast<bst_float>(sumdcg);
}
private:
XGBOOST_DEVICE inline static bst_float ComputeItemDCGWeight(unsigned label, uint32_t idx) {
return (label != 0) ? (((1 << label) - 1) / std::log2(static_cast<bst_float>(idx + 2))) : 0;
}
// Compute the weight adjustment for an item within a group:
// pos_pred_pos => Where does the positive label live, had the list been sorted by prediction
// neg_pred_pos => Where does the negative label live, had the list been sorted by prediction
// pos_label => positive label value from sorted label list
// neg_label => negative label value from sorted label list
XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t pos_pred_pos,
uint32_t neg_pred_pos,
int pos_label, int neg_label,
float idcg) {
float pos_loginv = 1.0f / std::log2(pos_pred_pos + 2.0f);
float neg_loginv = 1.0f / std::log2(neg_pred_pos + 2.0f);
bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv;
float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv;
bst_float delta = (original - changed) * (1.0f / idcg);
if (delta < 0.0f) delta = - delta;
return delta;
}
#if defined(__CUDACC__)
dh::caching_device_vector<float> dgroup_dcg_;
// This computes the adjustment to the weight
const NDCGLambdaWeightMultiplier weight_multiplier_;
#endif
};
class MAPLambdaWeightComputer
#if defined(__CUDACC__)
: public IndexablePredictionSorter
@ -948,10 +780,6 @@ XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name()
.describe("Pairwise rank objective.")
.set_body([]() { return new LambdaRankObj<PairwiseLambdaWeightComputer>(); });
XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name())
.describe("LambdaRank with NDCG as objective.")
.set_body([]() { return new LambdaRankObj<NDCGLambdaWeightComputer>(); });
XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name())
.describe("LambdaRank with MAP as objective.")
.set_body([]() { return new LambdaRankObj<MAPLambdaWeightComputer>(); });

View File

@ -5,6 +5,7 @@
#include <gtest/gtest.h> // for Test, Message, TestPartResult, CmpHel...
#include <algorithm> // for sort
#include <cstddef> // for size_t
#include <initializer_list> // for initializer_list
#include <map> // for map
@ -13,7 +14,6 @@
#include <string> // for char_traits, basic_string, string
#include <vector> // for vector
#include "../../../src/common/ranking_utils.h" // for LambdaRankParam
#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam
#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload
#include "xgboost/base.h" // for GradientPair, bst_group_t, Args
@ -25,6 +25,126 @@
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
TEST(LambdaRank, NDCGJsonIO) {
Context ctx;
TestNDCGJsonIO(&ctx);
}
void TestNDCGGPair(Context const* ctx) {
{
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{{"lambdarank_pair_method", "topk"}});
CheckConfigReload(obj, "rank:ndcg");
// No gain in swapping 2 documents.
CheckRankingObjFunction(obj,
{1, 1, 1, 1},
{1, 1, 1, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.0f, -0.0f, 0.0f, 0.0f},
{0.0f, 0.0f, 0.0f, 0.0f});
}
{
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{{"lambdarank_pair_method", "topk"}});
// Test with setting the sample weight of the second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{2.06611f, -2.06611f, 0.0f, 0.0f},
{2.169331f, 2.169331f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 2.0f},
{0, 2, 4},
{2.06611f, -2.06611f, 2.06611f, -2.06611f},
{2.169331f, 2.169331f, 2.169331f, 2.169331f});
}
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{{"lambdarank_pair_method", "topk"}});
HostDeviceVector<float> predts{0, 1, 0, 1};
MetaInfo info;
info.labels = linalg::Tensor<float, 2>{{0, 1, 0, 1}, {4, 1}, GPUIDX};
info.group_ptr_ = {0, 2, 4};
info.num_row_ = 4;
HostDeviceVector<GradientPair> gpairs;
obj->GetGradient(predts, info, 0, &gpairs);
ASSERT_EQ(gpairs.Size(), predts.Size());
{
predts = {1, 0, 1, 0};
HostDeviceVector<GradientPair> gpairs;
obj->GetGradient(predts, info, 0, &gpairs);
for (size_t i = 0; i < gpairs.Size(); ++i) {
ASSERT_GT(gpairs.HostSpan()[i].GetHess(), 0);
}
ASSERT_LT(gpairs.HostSpan()[1].GetGrad(), 0);
ASSERT_LT(gpairs.HostSpan()[3].GetGrad(), 0);
ASSERT_GT(gpairs.HostSpan()[0].GetGrad(), 0);
ASSERT_GT(gpairs.HostSpan()[2].GetGrad(), 0);
info.weights_ = {2, 3};
HostDeviceVector<GradientPair> weighted_gpairs;
obj->GetGradient(predts, info, 0, &weighted_gpairs);
auto const& h_gpairs = gpairs.ConstHostSpan();
auto const& h_weighted_gpairs = weighted_gpairs.ConstHostSpan();
for (size_t i : {0ul, 1ul}) {
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 2.0f);
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 2.0f);
}
for (size_t i : {2ul, 3ul}) {
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 3.0f);
ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 3.0f);
}
}
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(LambdaRank, NDCGGPair) {
Context ctx;
TestNDCGGPair(&ctx);
}
void TestUnbiasedNDCG(Context const* ctx) {
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{{"lambdarank_pair_method", "topk"},
{"lambdarank_unbiased", "true"},
{"lambdarank_bias_norm", "0"}});
std::shared_ptr<DMatrix> p_fmat{RandomDataGenerator{10, 1, 0.0f}.GenerateDMatrix(true, false, 2)};
auto h_label = p_fmat->Info().labels.HostView().Values();
// Move clicked samples to the beginning.
std::sort(h_label.begin(), h_label.end(), std::greater<>{});
HostDeviceVector<float> predt(p_fmat->Info().num_row_, 1.0f);
HostDeviceVector<GradientPair> out_gpair;
obj->GetGradient(predt, p_fmat->Info(), 0, &out_gpair);
Json config{Object{}};
obj->SaveConfig(&config);
auto ti_plus = get<F32Array const>(config["ti+"]);
ASSERT_FLOAT_EQ(ti_plus[0], 1.0);
// Bias is non-increasing when the prediction is constant (constant cost on swapping documents).
for (std::size_t i = 1; i < ti_plus.size(); ++i) {
ASSERT_LE(ti_plus[i], ti_plus[i - 1]);
}
auto tj_minus = get<F32Array const>(config["tj-"]);
ASSERT_FLOAT_EQ(tj_minus[0], 1.0);
}
TEST(LambdaRank, UnbiasedNDCG) {
Context ctx;
TestUnbiasedNDCG(&ctx);
}
void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector<float>* out_predt) {
out_predt->SetDevice(ctx->gpu_id);
MetaInfo& info = *out_info;

View File

@ -12,6 +12,18 @@
#include "test_lambdarank_obj.h"
namespace xgboost::obj {
TEST(LambdaRank, GPUNDCGJsonIO) {
Context ctx;
ctx.gpu_id = 0;
TestNDCGJsonIO(&ctx);
}
TEST(LambdaRank, GPUNDCGGPair) {
Context ctx;
ctx.gpu_id = 0;
TestNDCGGPair(&ctx);
}
void TestGPUMakePair() {
Context ctx;
ctx.gpu_id = 0;
@ -107,6 +119,12 @@ void TestGPUMakePair() {
TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); }
TEST(LambdaRank, GPUUnbiasedNDCG) {
Context ctx;
ctx.gpu_id = 0;
TestUnbiasedNDCG(&ctx);
}
template <typename CountFunctor>
void RankItemCountImpl(std::vector<std::uint32_t> const &sorted_items, CountFunctor f,
std::uint32_t find_val, std::uint32_t exp_val) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright (c) 2023, XGBoost Contributors
*/
#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
@ -18,6 +18,25 @@
#include "../helpers.h" // for EmptyDMatrix
namespace xgboost::obj {
inline void TestNDCGJsonIO(Context const* ctx) {
std::unique_ptr<xgboost::ObjFunction> obj{ObjFunction::Create("rank:ndcg", ctx)};
obj->Configure(Args{});
Json j_obj{Object()};
obj->SaveConfig(&j_obj);
ASSERT_EQ(get<String>(j_obj["name"]), "rank:ndcg");
auto const& j_param = j_obj["lambdarank_param"];
ASSERT_EQ(get<String>(j_param["ndcg_exp_gain"]), "1");
ASSERT_EQ(get<String>(j_param["lambdarank_num_pair_per_sample"]),
std::to_string(ltr::LambdaRankParam::NotSet()));
}
void TestNDCGGPair(Context const* ctx);
void TestUnbiasedNDCG(Context const* ctx);
/**
* \brief Initialize test data for make pair tests.
*/

View File

@ -35,24 +35,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) {
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(NDCG_JsonIO)) {
xgboost::Context ctx;
ctx.UpdateAllowUnknown(Args{});
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)};
obj->Configure(Args{});
Json j_obj {Object()};
obj->SaveConfig(&j_obj);
ASSERT_EQ(get<String>(j_obj["name"]), "rank:ndcg");;
auto const& j_param = j_obj["lambda_rank_param"];
ASSERT_EQ(get<String>(j_param["num_pairsample"]), "1");
ASSERT_EQ(get<String>(j_param["fix_list_weight"]), "0");
}
TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
@ -71,33 +53,6 @@ TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) {
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)};
obj->Configure(args);
CheckConfigReload(obj, "rank:ndcg");
// Test with setting sample weight to second query group
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{2.0f, 0.0f},
{0, 2, 4},
{0.7f, -0.7f, 0.0f, 0.0f},
{0.74f, 0.74f, 0.0f, 0.0f});
CheckRankingObjFunction(obj,
{0, 0.1f, 0, 0.1f},
{0, 1, 0, 1},
{1.0f, 1.0f},
{0, 2, 4},
{0.35f, -0.35f, 0.35f, -0.35f},
{0.368f, 0.368f, 0.368f, 0.368f});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) {
std::vector<std::pair<std::string, std::string>> args;
xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);

View File

@ -89,62 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) {
5, 4, 6});
}
TEST(Objective, NDCGLambdaWeightComputerTest) {
std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels
7.8f, 5.01f, 6.96f,
10.3f, 8.7f, 11.4f, 9.45f, 11.4f};
dh::device_vector<bst_float> dlabels(hlabels);
auto segment_label_sorter = RankSegmentSorterTestImpl<float>(
{0, 4, 7, 12}, // Groups
hlabels,
{4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels
7.8f, 6.96f, 5.01f,
11.4f, 11.4f, 10.3f, 9.45f, 8.7f},
{3, 0, 2, 1, // Expected original positions
4, 6, 5,
9, 11, 7, 10, 8});
// Created segmented predictions for the labels from above
std::vector<bst_float> hpreds{-9.78f, 24.367f, 0.908f, -11.47f,
-1.03f, -2.79f, -3.1f,
104.22f, 103.1f, -101.7f, 100.5f, 45.1f};
dh::device_vector<bst_float> dpreds(hpreds);
xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(),
dlabels.data().get(),
*segment_label_sorter);
// Where will the predictions move from its current position, if they were sorted
// descendingly?
auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan();
std::vector<uint32_t> hsorted_pred_pos(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos);
std::vector<uint32_t> expected_sorted_pred_pos{2, 0, 1, 3,
4, 5, 6,
7, 8, 11, 9, 10};
EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos);
// Check group DCG values
std::vector<float> hgroup_dcgs(segment_label_sorter->GetNumGroups());
dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan());
std::vector<uint32_t> hgroups(segment_label_sorter->GetNumGroups() + 1);
dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan());
EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups());
std::vector<float> hsorted_labels(segment_label_sorter->GetNumItems());
dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan());
for (size_t i = 0; i < hgroup_dcgs.size(); ++i) {
// Compute group DCG value on CPU and compare
auto gbegin = hgroups[i];
auto gend = hgroups[i + 1];
EXPECT_NEAR(
hgroup_dcgs[i],
xgboost::obj::NDCGLambdaWeightComputer::ComputeGroupDCGWeight(&hsorted_labels[gbegin],
gend - gbegin),
0.01f);
}
}
TEST(Objective, IndexableSortedItemsTest) {
std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels
7.8f, 5.01f, 6.96f,

View File

@ -1,3 +1,4 @@
import json
import sys
import pytest
@ -36,19 +37,16 @@ class TestGPUEvalMetrics:
Xy = xgboost.DMatrix(X, y, group=group)
cpu = xgboost.train(
booster = xgboost.train(
{"tree_method": "hist", "eval_metric": "auc", "objective": "rank:ndcg"},
Xy,
num_boost_round=10,
)
cpu_auc = float(cpu.eval(Xy).split(":")[1])
gpu = xgboost.train(
{"tree_method": "gpu_hist", "eval_metric": "auc", "objective": "rank:ndcg"},
Xy,
num_boost_round=10,
)
gpu_auc = float(gpu.eval(Xy).split(":")[1])
cpu_auc = float(booster.eval(Xy).split(":")[1])
booster.set_param({"gpu_id": "0"})
assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0"
gpu_auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(cpu_auc, gpu_auc)