diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 727f918f2..bc071c2d6 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -123,7 +123,7 @@ struct LambdaRankParam : public XGBoostParameter { DMLC_DECLARE_PARAMETER(LambdaRankParam) { DMLC_DECLARE_FIELD(lambdarank_pair_method) - .set_default(PairMethod::kMean) + .set_default(PairMethod::kTopK) .add_enum("mean", PairMethod::kMean) .add_enum("topk", PairMethod::kTopK) .describe("Method for constructing pairs."); diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 3a1416b0f..a84d0edb1 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -112,7 +112,6 @@ class PerGroupWeightPolicy { return info.GetWeight(group_id); } }; - } // anonymous namespace namespace xgboost::metric { diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc index 938ceb59d..834c052f5 100644 --- a/src/objective/init_estimation.cc +++ b/src/objective/init_estimation.cc @@ -14,8 +14,7 @@ #include "xgboost/linalg.h" // Tensor,Vector #include "xgboost/task.h" // ObjInfo -namespace xgboost { -namespace obj { +namespace xgboost::obj { void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const { if (this->Task().task == ObjInfo::kRegression) { CheckInitInputs(info); @@ -31,14 +30,13 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* b ObjFunction::Create(get(config["name"]), this->ctx_)}; new_obj->LoadConfig(config); new_obj->GetGradient(dummy_predt, info, 0, &gpair); + bst_target_t n_targets = this->Targets(info); linalg::Vector leaf_weight; tree::FitStump(this->ctx_, info, gpair, n_targets, &leaf_weight); - // workaround, we don't support multi-target due to binary model serialization for // base margin. common::Mean(this->ctx_, leaf_weight, base_score); this->PredTransform(base_score->Data()); } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/src/objective/init_estimation.h b/src/objective/init_estimation.h index b0a91d8c3..0ac5c5206 100644 --- a/src/objective/init_estimation.h +++ b/src/objective/init_estimation.h @@ -7,8 +7,7 @@ #include "xgboost/linalg.h" // Tensor #include "xgboost/objective.h" // ObjFunction -namespace xgboost { -namespace obj { +namespace xgboost::obj { class FitIntercept : public ObjFunction { void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const override; }; @@ -20,6 +19,5 @@ inline void CheckInitInputs(MetaInfo const& info) { << "Number of weights should be equal to number of data points."; } } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj #endif // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_ diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu new file mode 100644 index 000000000..eb82b17b4 --- /dev/null +++ b/src/objective/lambdarank_obj.cu @@ -0,0 +1,62 @@ +/** + * Copyright 2015-2023 by XGBoost contributors + * + * \brief CUDA implementation of lambdarank. 
+ */ +#include // for fill_n +#include // for for_each_n +#include // for make_counting_iterator +#include // for make_zip_iterator +#include // for make_tuple, tuple, tie, get + +#include // for min +#include // for assert +#include // for abs, log2, isinf +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include + +#include "../common/algorithm.cuh" // for SegmentedArgSort +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/deterministic.cuh" // for CreateRoundingFactor, TruncateWithRounding +#include "../common/device_helpers.cuh" // for SegmentId, TemporaryArray, AtomicAddGpair +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.h" // for NDCGCache, LambdaRankParam, rel_degree_t +#include "lambdarank_obj.cuh" +#include "lambdarank_obj.h" +#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE, GradientPair +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Range, Vector +#include "xgboost/logging.h" +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu); + +namespace cuda_impl { +common::Span SortY(Context const* ctx, MetaInfo const& info, + common::Span d_rank, + std::shared_ptr p_cache) { + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto label = info.labels.View(ctx->gpu_id); + // The buffer for ranked y is necessary as cub segmented sort accepts only pointer. + auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_group_ptr, i); + auto g_label = + label.Slice(linalg::Range(d_group_ptr[g], d_group_ptr[g + 1]), 0); + auto g_rank_idx = d_rank.subspan(d_group_ptr[g], g_label.Size()); + i -= d_group_ptr[g]; + auto g_y_ranked = d_y_ranked.subspan(d_group_ptr[g], g_label.Size()); + g_y_ranked[i] = g_label(g_rank_idx[i]); + }); + auto d_y_sorted_idx = p_cache->SortedIdxY(ctx, info.num_row_); + common::SegmentedArgSort(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx); + return d_y_sorted_idx; +} +} // namespace cuda_impl +} // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh new file mode 100644 index 000000000..be9f479ce --- /dev/null +++ b/src/objective/lambdarank_obj.cuh @@ -0,0 +1,172 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ + +#include // for lower_bound, upper_bound +#include // for greater +#include // for make_counting_iterator +#include // for minstd_rand +#include // for uniform_int_distribution + +#include // for cassert +#include // for size_t +#include // for int32_t +#include // for make_tuple, tuple + +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "../common/ranking_utils.cuh" // for PairsForGroup +#include "../common/ranking_utils.h" // for RankingCache +#include "../common/threading_utils.cuh" // for UnravelTrapeziodIdx +#include "xgboost/base.h" // for bst_group_t, GradientPair, XGBOOST_DEVICE +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/linalg.h" // for VectorView, Range, UnravelIndex +#include "xgboost/span.h" // for Span + +namespace xgboost::obj::cuda_impl { +/** + * \brief Find 
number of elements left to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheLeftOf(It items, std::size_t n, T v) { + return thrust::lower_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items; +} +/** + * \brief Find number of elements right to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheRightOf(It items, std::size_t n, T v) { + return n - (thrust::upper_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items); +} +/** + * \brief Sort labels according to rank list for making pairs. + */ +common::Span SortY(Context const *ctx, MetaInfo const &info, + common::Span d_rank, + std::shared_ptr p_cache); + +/** + * \brief Parameters needed for calculating gradient + */ +struct KernelInputs { + linalg::VectorView ti_plus; // input bias ratio + linalg::VectorView tj_minus; // input bias ratio + linalg::VectorView li; + linalg::VectorView lj; + + common::Span d_group_ptr; + common::Span d_threads_group_ptr; + common::Span d_sorted_idx; + + linalg::MatrixView labels; + common::Span predts; + common::Span gpairs; + + linalg::VectorView d_roundings; + double const *d_cost_rounding; + + common::Span d_y_sorted_idx; + + std::int32_t iter; +}; +/** + * \brief Functor for generating pairs + */ +template +struct MakePairsOp { + KernelInputs args; + /** + * \brief Make pair for the topk pair method. + */ + XGBOOST_DEVICE std::tuple WithTruncation(std::size_t idx, + bst_group_t g) const { + auto thread_group_begin = args.d_threads_group_ptr[g]; + auto idx_in_thread_group = idx - thread_group_begin; + + auto data_group_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; + // obtain group segment data. + auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); + auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); + + std::size_t i = 0, j = 0; + common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j); + + std::size_t rank_high = i, rank_low = j; + return std::make_tuple(rank_high, rank_low); + } + /** + * \brief Make pair for the mean pair method + */ + XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx, + bst_group_t g) const { + std::size_t n_samples = args.labels.Size(); + assert(n_samples == args.predts.size()); + // Constructed from ranking cache. + std::size_t n_pairs = + ltr::cuda_impl::PairsForGroup(args.d_threads_group_ptr[g + 1] - args.d_threads_group_ptr[g], + args.d_group_ptr[g + 1] - args.d_group_ptr[g]); + + assert(n_pairs > 0); + auto [sample_idx, sample_pair_idx] = linalg::UnravelIndex(idx, {n_samples, n_pairs}); + + auto g_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - g_begin; + + auto g_label = args.labels.Slice(linalg::Range(g_begin, g_begin + n_data)); + auto g_rank_idx = args.d_sorted_idx.subspan(args.d_group_ptr[g], n_data); + auto g_y_sorted_idx = args.d_y_sorted_idx.subspan(g_begin, n_data); + + std::size_t const i = sample_idx - g_begin; + assert(sample_pair_idx < n_samples); + assert(i <= sample_idx); + + auto g_sorted_label = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [&](std::size_t i) { return g_label(g_rank_idx[g_y_sorted_idx[i]]); }); + + // Are the labels diverse enough? 
If they are all the same, then there is nothing to pick + // from another group - bail sooner + if (g_label.Size() == 0 || g_sorted_label[0] == g_sorted_label[n_data - 1]) { + auto z = static_cast(0ul); + return std::make_tuple(z, z); + } + + std::size_t n_lefts = CountNumItemsToTheLeftOf(g_sorted_label, i + 1, g_sorted_label[i]); + std::size_t n_rights = + CountNumItemsToTheRightOf(g_sorted_label + i, n_data - i, g_sorted_label[i]); + // The index pointing to the first element of the next bucket + std::size_t right_bound = n_data - n_rights; + + thrust::minstd_rand rng(args.iter); + auto pair_idx = i; + rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme + thrust::uniform_int_distribution dist(0, n_lefts + n_rights - 1); + auto ridx = dist(rng); + SPAN_CHECK(ridx < n_lefts + n_rights); + if (ridx >= n_lefts) { + ridx = ridx - n_lefts + right_bound; // fixme + } + + auto idx0 = g_y_sorted_idx[pair_idx]; + auto idx1 = g_y_sorted_idx[ridx]; + + return std::make_tuple(idx0, idx1); + } + /** + * \brief Generate a single pair. + * + * \param idx Pair index (CUDA thread index). + * \param g Query group index. + */ + XGBOOST_DEVICE auto operator()(std::size_t idx, bst_group_t g) const { + if (has_truncation) { + return this->WithTruncation(idx, g); + } else { + return this->WithSampling(idx, g); + } + } +}; +} // namespace xgboost::obj::cuda_impl +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h new file mode 100644 index 000000000..3adb27a2e --- /dev/null +++ b/src/objective/lambdarank_obj.h @@ -0,0 +1,260 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#include // for min, max +#include // for assert +#include // for log, abs +#include // for size_t +#include // for greater +#include // for shared_ptr +#include // for minstd_rand, uniform_int_distribution +#include // for vector + +#include "../common/algorithm.h" // for ArgSort +#include "../common/math.h" // for Sigmoid +#include "../common/ranking_utils.h" // for CalcDCGGain +#include "../common/transform_iterator.h" // for MakeIndexTransformIter +#include "xgboost/base.h" // for GradientPair, XGBOOST_DEVICE, kRtEps +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Vector +#include "xgboost/logging.h" // for CHECK_EQ +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +template +XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low, + double inv_IDCG, common::Span discount) { + double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high; + double discount_high = discount[r_high]; + + double gain_low = exp ? 
ltr::CalcDCGGain(y_low) : y_low; + double discount_low = discount[r_low]; + + double original = gain_high * discount_high + gain_low * discount_low; + double changed = gain_low * discount_high + gain_high * discount_low; + + double delta_NDCG = (original - changed) * inv_IDCG; + assert(delta_NDCG >= -1.0); + assert(delta_NDCG <= 1.0); + return delta_NDCG; +} + +XGBOOST_DEVICE inline double DeltaMAP(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, common::Span n_rel, + common::Span acc) { + double r_h = static_cast(rank_high) + 1.0; + double r_l = static_cast(rank_low) + 1.0; + double delta{0.0}; + double n_total_relevances = n_rel.back(); + assert(n_total_relevances > 0.0); + auto m = n_rel[rank_low]; + double n = n_rel[rank_high]; + + if (y_high < y_low) { + auto a = m / r_l - (n + 1.0) / r_h; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a - b) / n_total_relevances; + } else { + auto a = n / r_h - m / r_l; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a + b) / n_total_relevances; + } + return delta; +} + +template +XGBOOST_DEVICE GradientPair +LambdaGrad(linalg::VectorView labels, common::Span predts, + common::Span sorted_idx, + std::size_t rank_high, // cordiniate + std::size_t rank_low, // cordiniate + Delta delta, // delta score + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + double* p_cost) { + assert(sorted_idx.size() > 0 && "Empty sorted idx for a group."); + std::size_t idx_high = sorted_idx[rank_high]; + std::size_t idx_low = sorted_idx[rank_low]; + + if (labels(idx_high) == labels(idx_low)) { + *p_cost = 0; + return {0.0f, 0.0f}; + } + + auto best_score = predts[sorted_idx.front()]; + auto worst_score = predts[sorted_idx.back()]; + + auto y_high = labels(idx_high); + float s_high = predts[idx_high]; + auto y_low = labels(idx_low); + float s_low = predts[idx_low]; + + // Use double whenever possible as we are working on the exp space. + double delta_score = std::abs(s_high - s_low); + double sigmoid = common::Sigmoid(s_high - s_low); + // Change in metric score like \delta NDCG or \delta MAP + double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low)); + + if (best_score != worst_score) { + delta_metric /= (delta_score + kRtEps); + } + + if (unbiased) { + *p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric; + } + + constexpr double kEps = 1e-16; + auto lambda_ij = (sigmoid - 1.0) * delta_metric; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0; + + auto k = t_plus.Size(); + assert(t_minus.Size() == k && "Invalid size of position bias"); + + if (unbiased && idx_high < k && idx_low < k) { + lambda_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); + hessian_ij /= (t_minus(idx_low) * t_plus(idx_high) + kRtEps); + } + + auto pg = GradientPair{static_cast(lambda_ij), static_cast(hessian_ij)}; + return pg; +} + +XGBOOST_DEVICE inline GradientPair Repulse(GradientPair pg) { + auto ng = GradientPair{-pg.GetGrad(), pg.GetHess()}; + return ng; +} + +namespace cuda_impl { +void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, + HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. 
+ */ +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache); + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, linalg::Vector* p_li, + linalg::Vector* p_lj, + std::shared_ptr p_cache); +} // namespace cuda_impl + +namespace cpu_impl { +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + * + * \param label Ground truth relevance label. + * \param rank_idx Sorted index of prediction. + * \param p_cache An initialized MAPCache. + */ +void MAPStat(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache); +} // namespace cpu_impl + +/** + * \param Construct pairs on CPU + * + * \tparam Op Functor for upgrading a pair of gradients. + * + * \param ctx The global context. + * \param iter The boosting iteration. + * \param cache ltr cache. + * \param g The current query group + * \param g_label label The labels for the current query group + * \param g_rank Sorted index of model scores for the current query group. + * \param op A callable that accepts two index for a pair of documents. The index is for + * the ranked list (labels sorted according to model scores). + */ +template +void MakePairs(Context const* ctx, std::int32_t iter, + std::shared_ptr const cache, bst_group_t g, + linalg::VectorView g_label, common::Span g_rank, + Op op) { + auto group_ptr = cache->DataGroupPtr(ctx); + ltr::position_t cnt = group_ptr[g + 1] - group_ptr[g]; + + if (cache->Param().HasTruncation()) { + for (std::size_t i = 0; i < std::min(cnt, cache->Param().NumPair()); ++i) { + for (std::size_t j = i + 1; j < cnt; ++j) { + op(i, j); + } + } + } else { + CHECK_EQ(g_rank.size(), g_label.Size()); + std::minstd_rand rnd(iter); + rnd.discard(g); // fixme(jiamingy): honor the global seed + // sort label according to the rank list + auto it = common::MakeIndexTransformIter( + [&g_rank, &g_label](std::size_t idx) { return g_label(g_rank[idx]); }); + std::vector y_sorted_idx = + common::ArgSort(ctx, it, it + cnt, std::greater<>{}); + // permutation iterator to get the original label + auto rev_it = common::MakeIndexTransformIter( + [&](std::size_t idx) { return g_label(g_rank[y_sorted_idx[idx]]); }); + + for (std::size_t i = 0; i < cnt;) { + std::size_t j = i + 1; + // find the bucket boundary + while (j < cnt && rev_it[i] == rev_it[j]) { + ++j; + } + // Bucket [i,j), construct n_samples pairs for each sample inside the bucket with + // another sample outside the bucket. 
+ // + // n elements left to the bucket, and n elements right to the bucket + std::size_t n_lefts = i, n_rights = static_cast(cnt - j); + if (n_lefts + n_rights == 0) { + i = j; + continue; + } + + auto n_samples = cache->Param().NumPair(); + // for each pair specifed by the user + while (n_samples--) { + // for each sample in the bucket + for (std::size_t pair_idx = i; pair_idx < j; ++pair_idx) { + std::size_t ridx = std::uniform_int_distribution( + static_cast(0), n_lefts + n_rights - 1)(rnd); + if (ridx >= n_lefts) { + ridx = ridx - i + j; // shift to the right of the bucket + } + // index that points to the rank list. + auto idx0 = y_sorted_idx[pair_idx]; + auto idx1 = y_sorted_idx[ridx]; + op(idx0, idx1); + } + } + i = j; + } + } +} +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc new file mode 100644 index 000000000..11cbf6bec --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include "test_lambdarank_obj.h" + +#include // for Test, Message, TestPartResult, CmpHel... + +#include // for size_t +#include // for initializer_list +#include // for map +#include // for unique_ptr, shared_ptr, make_shared +#include // for iota +#include // for char_traits, basic_string, string +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam +#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam +#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload +#include "xgboost/base.h" // for GradientPair, bst_group_t, Args +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo, DMatrix +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, All, TensorView +#include "xgboost/objective.h" // for ObjFunction +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt) { + out_predt->SetDevice(ctx->gpu_id); + MetaInfo& info = *out_info; + info.num_row_ = 128; + info.labels.ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + shape[0] = info.num_row_; + shape[1] = 1; + auto& h_data = data->HostVector(); + h_data.resize(shape[0]); + for (std::size_t i = 0; i < h_data.size(); ++i) { + h_data[i] = i % 2; + } + }); + std::vector predt(info.num_row_); + std::iota(predt.rbegin(), predt.rend(), 0.0f); + out_predt->HostVector() = predt; +} + +TEST(LambdaRank, MakePair) { + Context ctx; + MetaInfo info; + HostDeviceVector predt; + + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + ASSERT_TRUE(param.HasTruncation()); + + std::shared_ptr p_cache = std::make_shared(&ctx, info, param); + auto const& h_predt = predt.ConstHostVector(); + { + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + for (std::size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_EQ(rank_idx[i], static_cast(*(h_predt.crbegin() + i))); + } + std::int32_t n_pairs{0}; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ASSERT_GT(j, i); + ASSERT_LT(i, p_cache->Param().NumPair()); + ++n_pairs; + }); + ASSERT_EQ(n_pairs, 3568); + } + + auto const h_label = info.labels.HostView(); + + { + 
param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + ASSERT_FALSE(param.HasTruncation()); + std::int32_t n_pairs = 0; + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } + + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + std::int32_t n_pairs = 0; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu new file mode 100644 index 000000000..03ccdef8b --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Context + +#include // for uint32_t +#include // for vector + +#include "../../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../../src/objective/lambdarank_obj.cuh" +#include "test_lambdarank_obj.h" + +namespace xgboost::obj { +void TestGPUMakePair() { + Context ctx; + ctx.gpu_id = 0; + + MetaInfo info; + HostDeviceVector predt; + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + + auto make_args = [&](std::shared_ptr p_cache, auto rank_idx, + common::Span y_sorted_idx) { + linalg::Vector dummy; + auto d = dummy.View(ctx.gpu_id); + linalg::Vector dgpair; + auto dg = dgpair.View(ctx.gpu_id); + cuda_impl::KernelInputs args{d, + d, + d, + d, + p_cache->DataGroupPtr(&ctx), + p_cache->CUDAThreadsGroupPtr(), + rank_idx, + info.labels.View(ctx.gpu_id), + predt.ConstDeviceSpan(), + {}, + dg, + nullptr, + y_sorted_idx, + 0}; + return args; + }; + + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + + ASSERT_EQ(p_cache->CUDAThreads(), 3568); + + auto args = make_args(p_cache, rank_idx, {}); + auto n_pairs = p_cache->Param().NumPair(); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN(p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), + [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + SPAN_CHECK(j > i); + SPAN_CHECK(i < n_pairs); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + ASSERT_FALSE(param.HasTruncation()); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + auto n_pairs = p_cache->Param().NumPair(); + ASSERT_EQ(n_pairs, 1); + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + idx = 97; + auto [i, j] 
= make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + } +} + +TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); } + +template +void RankItemCountImpl(std::vector const &sorted_items, CountFunctor f, + std::uint32_t find_val, std::uint32_t exp_val) { + EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); + EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); +} + +TEST(LambdaRank, RankItemCountOnLeft) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheLeftOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(0)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(2)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(3)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(7)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(12)); +} + +TEST(LambdaRank, RankItemCountOnRight) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheRightOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(11)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(10)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(6)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(1)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(0)); +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h new file mode 100644 index 000000000..8dd238d2b --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023, XGBoost Contributors + */ +#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#include +#include // for MetaInfo +#include // for HostDeviceVector +#include // for All +#include // for ObjFunction + +#include // for shared_ptr, make_shared +#include // for iota +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache +#include "../../../src/objective/lambdarank_obj.h" // for MAPStat +#include "../helpers.h" // for EmptyDMatrix + +namespace xgboost::obj { +/** + * \brief Initialize test data for make pair tests. 
+ */ +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt); +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu index 02286ab46..540560c1f 100644 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ b/tests/cpp/objective/test_ranking_obj_gpu.cu @@ -89,43 +89,6 @@ TEST(Objective, RankSegmentSorterAscendingTest) { 5, 4, 6}); } -using CountFunctor = uint32_t (*)(const int *, uint32_t, int); -void RankItemCountImpl(const std::vector &sorted_items, CountFunctor f, - int find_val, uint32_t exp_val) { - EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); - EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); -} - -TEST(Objective, RankItemCountOnLeft) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 10, static_cast(0)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 6, static_cast(2)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 4, static_cast(3)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 1, static_cast(7)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 0, static_cast(12)); -} - -TEST(Objective, RankItemCountOnRight) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 10, static_cast(11)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 6, static_cast(10)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 4, static_cast(6)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 1, static_cast(1)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 0, static_cast(0)); -} - TEST(Objective, NDCGLambdaWeightComputerTest) { std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels 7.8f, 5.01f, 6.96f,
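Note for reviewers: the following is a standalone sketch, not part of the patch. It mirrors the "mean" pair-construction scheme implemented above in MakePairs / MakePairsOp::WithSampling using only the standard library: labels are kept in descending order, each equal-label bucket [i, j) is located, and for every sample in the bucket a partner is drawn uniformly from the elements outside the bucket, so the two labels of a pair always differ. All names here (MakeMeanPairs, n_pair_per_sample) are illustrative and do not exist in XGBoost.

// Minimal sketch of the "mean" pair sampling strategy, assuming labels sorted descending.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

// For each sample, draw `n_pair_per_sample` partners whose label differs from its own.
// `sorted_labels` must be sorted in descending order (i.e. already ranked).
std::vector<std::pair<std::size_t, std::size_t>> MakeMeanPairs(
    std::vector<float> const& sorted_labels, std::size_t n_pair_per_sample,
    std::uint32_t seed) {
  std::vector<std::pair<std::size_t, std::size_t>> pairs;
  std::minstd_rand rng(seed);
  std::size_t const n = sorted_labels.size();
  for (std::size_t i = 0; i < n;) {
    // Bucket [i, j) holds all samples sharing the same label.
    std::size_t j = i + 1;
    while (j < n && sorted_labels[j] == sorted_labels[i]) {
      ++j;
    }
    std::size_t n_lefts = i;       // elements with a larger label, left of the bucket
    std::size_t n_rights = n - j;  // elements with a smaller label, right of the bucket
    if (n_lefts + n_rights == 0) {  // all labels identical, nothing to pair with
      i = j;
      continue;
    }
    std::uniform_int_distribution<std::size_t> dist(0, n_lefts + n_rights - 1);
    for (std::size_t k = i; k < j; ++k) {
      for (std::size_t s = 0; s < n_pair_per_sample; ++s) {
        std::size_t ridx = dist(rng);
        if (ridx >= n_lefts) {
          ridx = ridx - n_lefts + j;  // shift past the bucket to the right side
        }
        pairs.emplace_back(k, ridx);  // labels of k and ridx necessarily differ
      }
    }
    i = j;
  }
  return pairs;
}

int main() {
  // Descending labels, mirroring the layout assumed by the kernels in this patch.
  std::vector<float> labels{3, 3, 2, 2, 2, 1, 0};
  for (auto [a, b] : MakeMeanPairs(labels, /*n_pair_per_sample=*/1, /*seed=*/0)) {
    std::cout << a << " vs " << b << " (labels " << labels[a] << ", " << labels[b] << ")\n";
  }
  return 0;
}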