Define pair generation strategies for LTR. (#8984)

Jiaming Yuan 2023-03-30 12:00:35 +08:00 committed by GitHub
parent d385cc64e2
commit d062a9e009
11 changed files with 770 additions and 48 deletions

View File

@@ -123,7 +123,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
   DMLC_DECLARE_PARAMETER(LambdaRankParam) {
     DMLC_DECLARE_FIELD(lambdarank_pair_method)
-        .set_default(PairMethod::kMean)
+        .set_default(PairMethod::kTopK)
         .add_enum("mean", PairMethod::kMean)
         .add_enum("topk", PairMethod::kTopK)
         .describe("Method for constructing pairs.");
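For reference, a minimal sketch of how this parameter is consumed from C++ (hedged: it uses only names visible in this patch — ltr::LambdaRankParam, Args, UpdateAllowUnknown, HasTruncation — and the value 8 is purely illustrative; the include path mirrors the ones used elsewhere in the patch):

  #include "../common/ranking_utils.h"  // ltr::LambdaRankParam
  #include "xgboost/base.h"             // Args

  void ConfigurePairMethod() {
    xgboost::ltr::LambdaRankParam param;
    // "topk" (the new default) enumerates pairs against the top-ranked documents;
    // "mean" instead samples a fixed number of partner documents per sample.
    param.UpdateAllowUnknown(xgboost::Args{{"lambdarank_pair_method", "topk"},
                                           {"lambdarank_num_pair_per_sample", "8"}});
    bool truncated = param.HasTruncation();  // true for "topk", false for "mean"
    (void)truncated;
  }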

View File

@@ -112,7 +112,6 @@ class PerGroupWeightPolicy {
     return info.GetWeight(group_id);
   }
 };
 }  // anonymous namespace
 namespace xgboost::metric {

View File

@@ -14,8 +14,7 @@
 #include "xgboost/linalg.h"  // Tensor,Vector
 #include "xgboost/task.h"    // ObjInfo
-namespace xgboost {
-namespace obj {
+namespace xgboost::obj {
 void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const {
   if (this->Task().task == ObjInfo::kRegression) {
     CheckInitInputs(info);
@@ -31,14 +30,13 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector<float>* b
       ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
   new_obj->LoadConfig(config);
   new_obj->GetGradient(dummy_predt, info, 0, &gpair);
   bst_target_t n_targets = this->Targets(info);
   linalg::Vector<float> leaf_weight;
   tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight);
   // workaround, we don't support multi-target due to binary model serialization for
   // base margin.
   common::Mean(this->ctx_, leaf_weight, base_score);
   this->PredTransform(base_score->Data());
 }
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj

View File

@@ -7,8 +7,7 @@
 #include "xgboost/linalg.h"     // Tensor
 #include "xgboost/objective.h"  // ObjFunction
-namespace xgboost {
-namespace obj {
+namespace xgboost::obj {
 class FitIntercept : public ObjFunction {
   void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override;
 };
@@ -20,6 +19,5 @@ inline void CheckInitInputs(MetaInfo const& info) {
       << "Number of weights should be equal to number of data points.";
   }
 }
-}  // namespace obj
-}  // namespace xgboost
+}  // namespace xgboost::obj
 #endif  // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_

View File

@@ -0,0 +1,62 @@
/**
* Copyright 2015-2023 by XGBoost contributors
*
* \brief CUDA implementation of lambdarank.
*/
#include <thrust/fill.h> // for fill_n
#include <thrust/for_each.h> // for for_each_n
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/iterator/zip_iterator.h> // for make_zip_iterator
#include <thrust/tuple.h> // for make_tuple, tuple, tie, get
#include <algorithm> // for min
#include <cassert> // for assert
#include <cmath> // for abs, log2, isinf
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <memory> // for shared_ptr
#include <utility>
#include "../common/algorithm.cuh" // for SegmentedArgSort
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/deterministic.cuh" // for CreateRoundingFactor, TruncateWithRounding
#include "../common/device_helpers.cuh" // for SegmentId, TemporaryArray, AtomicAddGpair
#include "../common/optional_weight.h" // for MakeOptionalWeights
#include "../common/ranking_utils.h" // for NDCGCache, LambdaRankParam, rel_degree_t
#include "lambdarank_obj.cuh"
#include "lambdarank_obj.h"
#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE, GradientPair
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for VectorView, Range, Vector
#include "xgboost/logging.h"
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu);
namespace cuda_impl {
common::Span<std::size_t const> SortY(Context const* ctx, MetaInfo const& info,
common::Span<std::size_t const> d_rank,
std::shared_ptr<ltr::RankingCache> p_cache) {
auto const d_group_ptr = p_cache->DataGroupPtr(ctx);
auto label = info.labels.View(ctx->gpu_id);
// The buffer for ranked y is necessary as cub segmented sort accepts only pointers.
auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_);
thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(),
[=] XGBOOST_DEVICE(std::size_t i) {
auto g = dh::SegmentId(d_group_ptr, i);
auto g_label =
label.Slice(linalg::Range(d_group_ptr[g], d_group_ptr[g + 1]), 0);
auto g_rank_idx = d_rank.subspan(d_group_ptr[g], g_label.Size());
i -= d_group_ptr[g];
auto g_y_ranked = d_y_ranked.subspan(d_group_ptr[g], g_label.Size());
g_y_ranked[i] = g_label(g_rank_idx[i]);
});
auto d_y_sorted_idx = p_cache->SortedIdxY(ctx, info.num_row_);
common::SegmentedArgSort<false, true>(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx);
return d_y_sorted_idx;
}
} // namespace cuda_impl
} // namespace xgboost::obj
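For readers without a GPU at hand, here is a single-group CPU reference for what SortY computes (a hedged sketch: the name SortYReference is mine, and I assume the returned argsort indices are local to the group, as their use in MakePairsOp below suggests; the CUDA version applies the same gather-then-argsort to every query group segment):

  #include <algorithm>  // std::stable_sort
  #include <cstddef>    // std::size_t
  #include <numeric>    // std::iota
  #include <vector>

  // Gather labels in the order given by the rank list, then argsort the gathered
  // labels in descending order.
  std::vector<std::size_t> SortYReference(std::vector<float> const& labels,
                                          std::vector<std::size_t> const& rank_idx) {
    std::vector<float> y_ranked(labels.size());
    for (std::size_t i = 0; i < rank_idx.size(); ++i) {
      y_ranked[i] = labels[rank_idx[i]];  // label of the document ranked at position i
    }
    std::vector<std::size_t> y_sorted_idx(y_ranked.size());
    std::iota(y_sorted_idx.begin(), y_sorted_idx.end(), std::size_t{0});
    std::stable_sort(y_sorted_idx.begin(), y_sorted_idx.end(),
                     [&](std::size_t l, std::size_t r) { return y_ranked[l] > y_ranked[r]; });
    return y_sorted_idx;
  }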

View File

@@ -0,0 +1,172 @@
/**
* Copyright 2023 XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_
#include <thrust/binary_search.h> // for lower_bound, upper_bound
#include <thrust/functional.h> // for greater
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/random/linear_congruential_engine.h> // for minstd_rand
#include <thrust/random/uniform_int_distribution.h> // for uniform_int_distribution
#include <cassert>  // for assert
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <tuple> // for make_tuple, tuple
#include "../common/device_helpers.cuh" // for MakeTransformIterator
#include "../common/ranking_utils.cuh" // for PairsForGroup
#include "../common/ranking_utils.h" // for RankingCache
#include "../common/threading_utils.cuh" // for UnravelTrapeziodIdx
#include "xgboost/base.h" // for bst_group_t, GradientPair, XGBOOST_DEVICE
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/linalg.h" // for VectorView, Range, UnravelIndex
#include "xgboost/span.h" // for Span
namespace xgboost::obj::cuda_impl {
/**
 * \brief Find the number of elements to the left of the label bucket.
*/
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheLeftOf(It items, std::size_t n, T v) {
return thrust::lower_bound(thrust::seq, items, items + n, v, thrust::greater<T>{}) - items;
}
/**
 * \brief Find the number of elements to the right of the label bucket.
*/
template <typename It, typename T = typename std::iterator_traits<It>::value_type>
XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheRightOf(It items, std::size_t n, T v) {
return n - (thrust::upper_bound(thrust::seq, items, items + n, v, thrust::greater<T>{}) - items);
}
/**
* \brief Sort labels according to rank list for making pairs.
*/
common::Span<std::size_t const> SortY(Context const *ctx, MetaInfo const &info,
common::Span<std::size_t const> d_rank,
std::shared_ptr<ltr::RankingCache> p_cache);
/**
* \brief Parameters needed for calculating gradient
*/
struct KernelInputs {
linalg::VectorView<double const> ti_plus; // input bias ratio
linalg::VectorView<double const> tj_minus; // input bias ratio
linalg::VectorView<double> li;
linalg::VectorView<double> lj;
common::Span<bst_group_t const> d_group_ptr;
common::Span<std::size_t const> d_threads_group_ptr;
common::Span<std::size_t const> d_sorted_idx;
linalg::MatrixView<float const> labels;
common::Span<float const> predts;
common::Span<GradientPair> gpairs;
linalg::VectorView<GradientPair const> d_roundings;
double const *d_cost_rounding;
common::Span<std::size_t const> d_y_sorted_idx;
std::int32_t iter;
};
/**
* \brief Functor for generating pairs
*/
template <bool has_truncation>
struct MakePairsOp {
KernelInputs args;
/**
 * \brief Make a pair for the topk pair method.
*/
XGBOOST_DEVICE std::tuple<std::size_t, std::size_t> WithTruncation(std::size_t idx,
bst_group_t g) const {
auto thread_group_begin = args.d_threads_group_ptr[g];
auto idx_in_thread_group = idx - thread_group_begin;
auto data_group_begin = static_cast<std::size_t>(args.d_group_ptr[g]);
std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin;
// obtain group segment data.
auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0);
auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data);
std::size_t i = 0, j = 0;
common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j);
std::size_t rank_high = i, rank_low = j;
return std::make_tuple(rank_high, rank_low);
}
/**
 * \brief Make a pair for the mean pair method.
*/
XGBOOST_DEVICE std::tuple<std::size_t, std::size_t> WithSampling(std::size_t idx,
bst_group_t g) const {
std::size_t n_samples = args.labels.Size();
assert(n_samples == args.predts.size());
// Constructed from ranking cache.
std::size_t n_pairs =
ltr::cuda_impl::PairsForGroup(args.d_threads_group_ptr[g + 1] - args.d_threads_group_ptr[g],
args.d_group_ptr[g + 1] - args.d_group_ptr[g]);
assert(n_pairs > 0);
auto [sample_idx, sample_pair_idx] = linalg::UnravelIndex(idx, {n_samples, n_pairs});
auto g_begin = static_cast<std::size_t>(args.d_group_ptr[g]);
std::size_t n_data = args.d_group_ptr[g + 1] - g_begin;
auto g_label = args.labels.Slice(linalg::Range(g_begin, g_begin + n_data));
auto g_rank_idx = args.d_sorted_idx.subspan(args.d_group_ptr[g], n_data);
auto g_y_sorted_idx = args.d_y_sorted_idx.subspan(g_begin, n_data);
std::size_t const i = sample_idx - g_begin;
assert(sample_pair_idx < n_samples);
assert(i <= sample_idx);
auto g_sorted_label = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0ul),
[&](std::size_t i) { return g_label(g_rank_idx[g_y_sorted_idx[i]]); });
// Are the labels diverse enough? If they are all the same, there is nothing to pick
// from another bucket; bail out early.
if (g_label.Size() == 0 || g_sorted_label[0] == g_sorted_label[n_data - 1]) {
auto z = static_cast<std::size_t>(0ul);
return std::make_tuple(z, z);
}
std::size_t n_lefts = CountNumItemsToTheLeftOf(g_sorted_label, i + 1, g_sorted_label[i]);
std::size_t n_rights =
CountNumItemsToTheRightOf(g_sorted_label + i, n_data - i, g_sorted_label[i]);
// The index pointing to the first element of the next bucket
std::size_t right_bound = n_data - n_rights;
thrust::minstd_rand rng(args.iter);
auto pair_idx = i;
rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme
thrust::uniform_int_distribution<std::size_t> dist(0, n_lefts + n_rights - 1);
auto ridx = dist(rng);
SPAN_CHECK(ridx < n_lefts + n_rights);
if (ridx >= n_lefts) {
ridx = ridx - n_lefts + right_bound; // fixme
}
auto idx0 = g_y_sorted_idx[pair_idx];
auto idx1 = g_y_sorted_idx[ridx];
return std::make_tuple(idx0, idx1);
}
/**
* \brief Generate a single pair.
*
* \param idx Pair index (CUDA thread index).
* \param g Query group index.
*/
XGBOOST_DEVICE auto operator()(std::size_t idx, bst_group_t g) const {
if (has_truncation) {
return this->WithTruncation(idx, g);
} else {
return this->WithSampling(idx, g);
}
}
};
} // namespace xgboost::obj::cuda_impl
#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_
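The two counting helpers above rely on the labels being sorted descendingly, hence the flipped comparator. A host-side illustration with the same data as the unit tests further down (std:: algorithms stand in for the thrust:: calls; this is a sketch, not part of the patch):

  #include <algorithm>   // std::lower_bound, std::upper_bound
  #include <cstddef>     // std::size_t
  #include <functional>  // std::greater
  #include <vector>

  int main() {
    // Descendingly sorted labels; the bucket of value 4 spans indices [3, 7).
    std::vector<int> sorted{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
    int v = 4;
    // Items strictly to the left of the bucket of 4: the two 10s and the 6 -> 3.
    std::size_t n_left =
        std::lower_bound(sorted.begin(), sorted.end(), v, std::greater<int>{}) - sorted.begin();
    // Items strictly to the right of the bucket of 4: the five 1s and the 0 -> 6.
    std::size_t n_right =
        sorted.end() - std::upper_bound(sorted.begin(), sorted.end(), v, std::greater<int>{});
    return (n_left == 3 && n_right == 6) ? 0 : 1;
  }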

View File

@@ -0,0 +1,260 @@
/**
* Copyright 2023 XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
#include <algorithm> // for min, max
#include <cassert> // for assert
#include <cmath> // for log, abs
#include <cstddef> // for size_t
#include <functional> // for greater
#include <memory> // for shared_ptr
#include <random> // for minstd_rand, uniform_int_distribution
#include <vector> // for vector
#include "../common/algorithm.h" // for ArgSort
#include "../common/math.h" // for Sigmoid
#include "../common/ranking_utils.h" // for CalcDCGGain
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
#include "xgboost/base.h" // for GradientPair, XGBOOST_DEVICE, kRtEps
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for VectorView, Vector
#include "xgboost/logging.h" // for CHECK_EQ
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
template <bool exp>
XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low,
double inv_IDCG, common::Span<double const> discount) {
double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high;
double discount_high = discount[r_high];
double gain_low = exp ? ltr::CalcDCGGain(y_low) : y_low;
double discount_low = discount[r_low];
double original = gain_high * discount_high + gain_low * discount_low;
double changed = gain_low * discount_high + gain_high * discount_low;
double delta_NDCG = (original - changed) * inv_IDCG;
assert(delta_NDCG >= -1.0);
assert(delta_NDCG <= 1.0);
return delta_NDCG;
}
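// Note: the expression above simplifies to
//   delta_NDCG = (gain_high - gain_low) * (discount_high - discount_low) * inv_IDCG,
// i.e. swapping the pair changes NDCG only through the product of the gain gap and the
// positional discount gap, normalised by the ideal DCG.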
XGBOOST_DEVICE inline double DeltaMAP(float y_high, float y_low, std::size_t rank_high,
std::size_t rank_low, common::Span<double const> n_rel,
common::Span<double const> acc) {
double r_h = static_cast<double>(rank_high) + 1.0;
double r_l = static_cast<double>(rank_low) + 1.0;
double delta{0.0};
double n_total_relevances = n_rel.back();
assert(n_total_relevances > 0.0);
auto m = n_rel[rank_low];
double n = n_rel[rank_high];
if (y_high < y_low) {
auto a = m / r_l - (n + 1.0) / r_h;
auto b = acc[rank_low - 1] - acc[rank_high];
delta = (a - b) / n_total_relevances;
} else {
auto a = n / r_h - m / r_l;
auto b = acc[rank_low - 1] - acc[rank_high];
delta = (a + b) / n_total_relevances;
}
return delta;
}
template <bool unbiased, typename Delta>
XGBOOST_DEVICE GradientPair
LambdaGrad(linalg::VectorView<float const> labels, common::Span<float const> predts,
common::Span<size_t const> sorted_idx,
           std::size_t rank_high,                   // coordinate
           std::size_t rank_low,                    // coordinate
Delta delta, // delta score
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
double* p_cost) {
assert(sorted_idx.size() > 0 && "Empty sorted idx for a group.");
std::size_t idx_high = sorted_idx[rank_high];
std::size_t idx_low = sorted_idx[rank_low];
if (labels(idx_high) == labels(idx_low)) {
*p_cost = 0;
return {0.0f, 0.0f};
}
auto best_score = predts[sorted_idx.front()];
auto worst_score = predts[sorted_idx.back()];
auto y_high = labels(idx_high);
float s_high = predts[idx_high];
auto y_low = labels(idx_low);
float s_low = predts[idx_low];
// Use double whenever possible as we are working on the exp space.
double delta_score = std::abs(s_high - s_low);
double sigmoid = common::Sigmoid(s_high - s_low);
// Change in metric score like \delta NDCG or \delta MAP
double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low));
if (best_score != worst_score) {
delta_metric /= (delta_score + kRtEps);
}
if (unbiased) {
*p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric;
}
constexpr double kEps = 1e-16;
auto lambda_ij = (sigmoid - 1.0) * delta_metric;
auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0;
auto k = t_plus.Size();
assert(t_minus.Size() == k && "Invalid size of position bias");
if (unbiased && idx_high < k && idx_low < k) {
lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps);
}
auto pg = GradientPair{static_cast<float>(lambda_ij), static_cast<float>(hessian_ij)};
return pg;
}
XGBOOST_DEVICE inline GradientPair Repulse(GradientPair pg) {
auto ng = GradientPair{-pg.GetGrad(), pg.GetHess()};
return ng;
}
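// In summary, with s = Sigmoid(s_high - s_low) and |delta Z| the absolute metric change from
// swapping the pair, LambdaGrad computes
//   lambda_ij  = (s - 1) * |delta Z|
//   hessian_ij = 2 * max(s * (1 - s), 1e-16) * |delta Z|
// and Repulse negates the gradient to obtain the contribution for the other document of the
// pair. The unbiased variant additionally divides both terms by
// t_plus(idx_high) * t_minus(idx_low) (plus kRtEps).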
namespace cuda_impl {
void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<ltr::NDCGCache> p_cache,
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
/**
 * \brief Generate statistics for MAP used for calculating \Delta Z in LambdaMART.
*/
void MAPStat(Context const* ctx, MetaInfo const& info, common::Span<std::size_t const> d_rank_idx,
std::shared_ptr<ltr::MAPCache> p_cache);
void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache,
linalg::VectorView<double const> t_plus, // input bias ratio
linalg::VectorView<double const> t_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter,
HostDeviceVector<float> const& predt, const MetaInfo& info,
std::shared_ptr<ltr::RankingCache> p_cache,
linalg::VectorView<double const> ti_plus, // input bias ratio
linalg::VectorView<double const> tj_minus, // input bias ratio
linalg::VectorView<double> li, linalg::VectorView<double> lj,
HostDeviceVector<GradientPair>* out_gpair);
void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView<double const> li_full,
linalg::VectorView<double const> lj_full,
linalg::Vector<double>* p_ti_plus,
linalg::Vector<double>* p_tj_minus, linalg::Vector<double>* p_li,
linalg::Vector<double>* p_lj,
std::shared_ptr<ltr::RankingCache> p_cache);
} // namespace cuda_impl
namespace cpu_impl {
/**
 * \brief Generate statistics for MAP used for calculating \Delta Z in LambdaMART.
*
* \param label Ground truth relevance label.
* \param rank_idx Sorted index of prediction.
* \param p_cache An initialized MAPCache.
*/
void MAPStat(Context const* ctx, linalg::VectorView<float const> label,
common::Span<std::size_t const> rank_idx, std::shared_ptr<ltr::MAPCache> p_cache);
} // namespace cpu_impl
/**
 * \brief Construct pairs on the CPU.
*
* \tparam Op Functor for upgrading a pair of gradients.
*
* \param ctx The global context.
* \param iter The boosting iteration.
* \param cache ltr cache.
* \param g The current query group
 * \param g_label The labels for the current query group.
* \param g_rank Sorted index of model scores for the current query group.
 * \param op A callable that accepts two indices for a pair of documents. The indices are
 *        into the ranked list (labels sorted according to model scores).
*/
template <typename Op>
void MakePairs(Context const* ctx, std::int32_t iter,
std::shared_ptr<ltr::RankingCache> const cache, bst_group_t g,
linalg::VectorView<float const> g_label, common::Span<std::size_t const> g_rank,
Op op) {
auto group_ptr = cache->DataGroupPtr(ctx);
ltr::position_t cnt = group_ptr[g + 1] - group_ptr[g];
if (cache->Param().HasTruncation()) {
for (std::size_t i = 0; i < std::min(cnt, cache->Param().NumPair()); ++i) {
for (std::size_t j = i + 1; j < cnt; ++j) {
op(i, j);
}
}
} else {
CHECK_EQ(g_rank.size(), g_label.Size());
std::minstd_rand rnd(iter);
rnd.discard(g); // fixme(jiamingy): honor the global seed
// sort label according to the rank list
auto it = common::MakeIndexTransformIter(
[&g_rank, &g_label](std::size_t idx) { return g_label(g_rank[idx]); });
std::vector<std::size_t> y_sorted_idx =
common::ArgSort<std::size_t>(ctx, it, it + cnt, std::greater<>{});
// permutation iterator to get the original label
auto rev_it = common::MakeIndexTransformIter(
[&](std::size_t idx) { return g_label(g_rank[y_sorted_idx[idx]]); });
for (std::size_t i = 0; i < cnt;) {
std::size_t j = i + 1;
// find the bucket boundary
while (j < cnt && rev_it[i] == rev_it[j]) {
++j;
}
// Bucket [i,j), construct n_samples pairs for each sample inside the bucket with
// another sample outside the bucket.
//
// n elements left to the bucket, and n elements right to the bucket
std::size_t n_lefts = i, n_rights = static_cast<std::size_t>(cnt - j);
if (n_lefts + n_rights == 0) {
i = j;
continue;
}
auto n_samples = cache->Param().NumPair();
// for each pair specified by the user
while (n_samples--) {
// for each sample in the bucket
for (std::size_t pair_idx = i; pair_idx < j; ++pair_idx) {
std::size_t ridx = std::uniform_int_distribution<std::size_t>(
static_cast<std::size_t>(0), n_lefts + n_rights - 1)(rnd);
if (ridx >= n_lefts) {
ridx = ridx - i + j; // shift to the right of the bucket
}
// index that points to the rank list.
auto idx0 = y_sorted_idx[pair_idx];
auto idx1 = y_sorted_idx[ridx];
op(idx0, idx1);
}
}
i = j;
}
}
}
} // namespace xgboost::obj
#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_
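To make the index arithmetic in the sampling branch above concrete, here is a small worked example (a sketch: the values and the function name BucketSamplingExample are illustrative and not taken from the tests):

  #include <cassert>
  #include <cstddef>
  #include <random>

  void BucketSamplingExample() {
    // Labels gathered in rank order and then sorted descendingly: {2, 2, 1, 1, 1, 0}.
    // The bucket holding label 1 is [i, j) = [2, 5): n_lefts = 2 items precede it and
    // n_rights = 1 item follows it, so a partner index is drawn uniformly from [0, 3).
    std::size_t cnt = 6, i = 2, j = 5;
    std::size_t n_lefts = i, n_rights = cnt - j;  // 2 and 1
    std::minstd_rand rnd(0);
    std::size_t ridx =
        std::uniform_int_distribution<std::size_t>(0, n_lefts + n_rights - 1)(rnd);
    if (ridx >= n_lefts) {
      ridx = ridx - i + j;  // shift past the bucket, e.g. 2 -> 5
    }
    // ridx now points at a document whose label differs from the bucket's label.
    assert(ridx < i || ridx >= j);
  }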

View File

@@ -0,0 +1,106 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include "test_lambdarank_obj.h"
#include <gtest/gtest.h> // for Test, Message, TestPartResult, CmpHel...
#include <cstddef> // for size_t
#include <initializer_list> // for initializer_list
#include <map> // for map
#include <memory> // for unique_ptr, shared_ptr, make_shared
#include <numeric> // for iota
#include <string> // for char_traits, basic_string, string
#include <vector> // for vector
#include "../../../src/common/ranking_utils.h" // for LambdaRankParam
#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam
#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload
#include "xgboost/base.h" // for GradientPair, bst_group_t, Args
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo, DMatrix
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for Tensor, All, TensorView
#include "xgboost/objective.h" // for ObjFunction
#include "xgboost/span.h" // for Span
namespace xgboost::obj {
void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector<float>* out_predt) {
out_predt->SetDevice(ctx->gpu_id);
MetaInfo& info = *out_info;
info.num_row_ = 128;
info.labels.ModifyInplace([&](HostDeviceVector<float>* data, common::Span<std::size_t> shape) {
shape[0] = info.num_row_;
shape[1] = 1;
auto& h_data = data->HostVector();
h_data.resize(shape[0]);
for (std::size_t i = 0; i < h_data.size(); ++i) {
h_data[i] = i % 2;
}
});
std::vector<float> predt(info.num_row_);
std::iota(predt.rbegin(), predt.rend(), 0.0f);
out_predt->HostVector() = predt;
}
TEST(LambdaRank, MakePair) {
Context ctx;
MetaInfo info;
HostDeviceVector<float> predt;
InitMakePairTest(&ctx, &info, &predt);
ltr::LambdaRankParam param;
param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}});
ASSERT_TRUE(param.HasTruncation());
std::shared_ptr<ltr::RankingCache> p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
auto const& h_predt = predt.ConstHostVector();
{
auto rank_idx = p_cache->SortedIdx(&ctx, h_predt);
for (std::size_t i = 0; i < h_predt.size(); ++i) {
ASSERT_EQ(rank_idx[i], static_cast<std::size_t>(*(h_predt.crbegin() + i)));
}
std::int32_t n_pairs{0};
MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
[&](auto i, auto j) {
ASSERT_GT(j, i);
ASSERT_LT(i, p_cache->Param().NumPair());
++n_pairs;
});
ASSERT_EQ(n_pairs, 3568);
}
auto const h_label = info.labels.HostView();
{
param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}});
auto p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
ASSERT_FALSE(param.HasTruncation());
std::int32_t n_pairs = 0;
auto rank_idx = p_cache->SortedIdx(&ctx, h_predt);
MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
[&](auto i, auto j) {
++n_pairs;
// Not in the same bucket
ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j]));
});
ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair());
}
{
param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}});
auto p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
auto rank_idx = p_cache->SortedIdx(&ctx, h_predt);
std::int32_t n_pairs = 0;
MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
[&](auto i, auto j) {
++n_pairs;
// Not in the same bucket
ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j]));
});
ASSERT_EQ(param.NumPair(), 2);
ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair());
}
}
} // namespace xgboost::obj
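A note on the hard-coded expectation of 3568 pairs in the topk branch above: with a single group of 128 documents and an effective truncation level of 32 (an assumption on my part, inferred from the count rather than stated in the test), the enumeration in MakePairs yields

  sum over i = 0..31 of (127 - i) = 32 * 127 - (31 * 32) / 2 = 3568,

which also matches the CUDAThreads() expectation in the GPU test below.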

View File

@@ -0,0 +1,138 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <cstdint> // for uint32_t
#include <vector> // for vector
#include "../../../src/common/cuda_context.cuh" // for CUDAContext
#include "../../../src/objective/lambdarank_obj.cuh"
#include "test_lambdarank_obj.h"
namespace xgboost::obj {
void TestGPUMakePair() {
Context ctx;
ctx.gpu_id = 0;
MetaInfo info;
HostDeviceVector<float> predt;
InitMakePairTest(&ctx, &info, &predt);
ltr::LambdaRankParam param;
auto make_args = [&](std::shared_ptr<ltr::RankingCache> p_cache, auto rank_idx,
common::Span<std::size_t const> y_sorted_idx) {
linalg::Vector<double> dummy;
auto d = dummy.View(ctx.gpu_id);
linalg::Vector<GradientPair> dgpair;
auto dg = dgpair.View(ctx.gpu_id);
cuda_impl::KernelInputs args{d,
d,
d,
d,
p_cache->DataGroupPtr(&ctx),
p_cache->CUDAThreadsGroupPtr(),
rank_idx,
info.labels.View(ctx.gpu_id),
predt.ConstDeviceSpan(),
{},
dg,
nullptr,
y_sorted_idx,
0};
return args;
};
{
param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}});
auto p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan());
ASSERT_EQ(p_cache->CUDAThreads(), 3568);
auto args = make_args(p_cache, rank_idx, {});
auto n_pairs = p_cache->Param().NumPair();
auto make_pair = cuda_impl::MakePairsOp<true>{args};
dh::LaunchN(p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(),
[=] XGBOOST_DEVICE(std::size_t idx) {
auto [i, j] = make_pair(idx, 0);
SPAN_CHECK(j > i);
SPAN_CHECK(i < n_pairs);
});
}
{
param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}});
auto p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan());
auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache);
ASSERT_FALSE(param.HasTruncation());
ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair());
auto args = make_args(p_cache, rank_idx, y_sorted_idx);
auto make_pair = cuda_impl::MakePairsOp<false>{args};
auto n_pairs = p_cache->Param().NumPair();
ASSERT_EQ(n_pairs, 1);
dh::LaunchN(
p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) {
idx = 97;
auto [i, j] = make_pair(idx, 0);
// Not in the same bucket
SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j]));
});
}
{
param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}});
auto p_cache = std::make_shared<ltr::NDCGCache>(&ctx, info, param);
auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan());
auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache);
auto args = make_args(p_cache, rank_idx, y_sorted_idx);
auto make_pair = cuda_impl::MakePairsOp<false>{args};
dh::LaunchN(
p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) {
auto [i, j] = make_pair(idx, 0);
// Not in the same bucket
SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j]));
});
ASSERT_EQ(param.NumPair(), 2);
ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair());
}
}
TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); }
template <typename CountFunctor>
void RankItemCountImpl(std::vector<std::uint32_t> const &sorted_items, CountFunctor f,
std::uint32_t find_val, std::uint32_t exp_val) {
EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end());
EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val);
}
TEST(LambdaRank, RankItemCountOnLeft) {
// Items sorted descendingly
std::vector<std::uint32_t> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheLeftOf(args...); };
RankItemCountImpl(sorted_items, wrapper, 10, static_cast<uint32_t>(0));
RankItemCountImpl(sorted_items, wrapper, 6, static_cast<uint32_t>(2));
RankItemCountImpl(sorted_items, wrapper, 4, static_cast<uint32_t>(3));
RankItemCountImpl(sorted_items, wrapper, 1, static_cast<uint32_t>(7));
RankItemCountImpl(sorted_items, wrapper, 0, static_cast<uint32_t>(12));
}
TEST(LambdaRank, RankItemCountOnRight) {
// Items sorted descendingly
std::vector<std::uint32_t> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheRightOf(args...); };
RankItemCountImpl(sorted_items, wrapper, 10, static_cast<uint32_t>(11));
RankItemCountImpl(sorted_items, wrapper, 6, static_cast<uint32_t>(10));
RankItemCountImpl(sorted_items, wrapper, 4, static_cast<uint32_t>(6));
RankItemCountImpl(sorted_items, wrapper, 1, static_cast<uint32_t>(1));
RankItemCountImpl(sorted_items, wrapper, 0, static_cast<uint32_t>(0));
}
} // namespace xgboost::obj

View File

@@ -0,0 +1,26 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_
#include <gtest/gtest.h>
#include <xgboost/data.h> // for MetaInfo
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for All
#include <xgboost/objective.h> // for ObjFunction
#include <memory> // for shared_ptr, make_shared
#include <numeric> // for iota
#include <vector> // for vector
#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache
#include "../../../src/objective/lambdarank_obj.h" // for MAPStat
#include "../helpers.h" // for EmptyDMatrix
namespace xgboost::obj {
/**
* \brief Initialize test data for make pair tests.
*/
void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector<float>* out_predt);
} // namespace xgboost::obj
#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_

View File

@@ -89,43 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) {
                                       5, 4, 6});
 }
-using CountFunctor = uint32_t (*)(const int *, uint32_t, int);
-void RankItemCountImpl(const std::vector<int> &sorted_items, CountFunctor f,
-                       int find_val, uint32_t exp_val) {
-  EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end());
-  EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val);
-}
-TEST(Objective, RankItemCountOnLeft) {
-  // Items sorted descendingly
-  std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
-                    10, static_cast<uint32_t>(0));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
-                    6, static_cast<uint32_t>(2));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
-                    4, static_cast<uint32_t>(3));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
-                    1, static_cast<uint32_t>(7));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf,
-                    0, static_cast<uint32_t>(12));
-}
-TEST(Objective, RankItemCountOnRight) {
-  // Items sorted descendingly
-  std::vector<int> sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0};
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
-                    10, static_cast<uint32_t>(11));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
-                    6, static_cast<uint32_t>(10));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
-                    4, static_cast<uint32_t>(6));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
-                    1, static_cast<uint32_t>(1));
-  RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf,
-                    0, static_cast<uint32_t>(0));
-}
 TEST(Objective, NDCGLambdaWeightComputerTest) {
   std::vector<float> hlabels = {3.1f, 1.2f, 2.3f, 4.4f,  // Labels
                                 7.8f, 5.01f, 6.96f,