/** * Copyright 2023 XGBoost contributors */ #ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ #define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ #include // for lower_bound, upper_bound #include // for greater #include // for make_counting_iterator #include // for minstd_rand #include // for uniform_int_distribution #include // for cassert #include // for size_t #include // for int32_t #include // for make_tuple, tuple #include "../common/device_helpers.cuh" // for MakeTransformIterator #include "../common/ranking_utils.cuh" // for PairsForGroup #include "../common/ranking_utils.h" // for RankingCache #include "../common/threading_utils.cuh" // for UnravelTrapeziodIdx #include "xgboost/base.h" // for bst_group_t, GradientPair, XGBOOST_DEVICE #include "xgboost/data.h" // for MetaInfo #include "xgboost/linalg.h" // for VectorView, Range, UnravelIndex #include "xgboost/span.h" // for Span namespace xgboost::obj::cuda_impl { /** * \brief Find number of elements left to the label bucket */ template ::value_type> XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheLeftOf(It items, std::size_t n, T v) { return thrust::lower_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items; } /** * \brief Find number of elements right to the label bucket */ template ::value_type> XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheRightOf(It items, std::size_t n, T v) { return n - (thrust::upper_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items); } /** * \brief Sort labels according to rank list for making pairs. */ common::Span SortY(Context const *ctx, MetaInfo const &info, common::Span d_rank, std::shared_ptr p_cache); /** * \brief Parameters needed for calculating gradient */ struct KernelInputs { linalg::VectorView ti_plus; // input bias ratio linalg::VectorView tj_minus; // input bias ratio linalg::VectorView li; linalg::VectorView lj; common::Span d_group_ptr; common::Span d_threads_group_ptr; common::Span d_sorted_idx; linalg::MatrixView labels; common::Span predts; linalg::MatrixView gpairs; linalg::VectorView d_roundings; double const *d_cost_rounding; common::Span d_y_sorted_idx; std::int32_t iter; }; /** * \brief Functor for generating pairs */ template struct MakePairsOp { KernelInputs args; /** * \brief Make pair for the topk pair method. */ [[nodiscard]] XGBOOST_DEVICE std::tuple WithTruncation( std::size_t idx, bst_group_t g) const { auto thread_group_begin = args.d_threads_group_ptr[g]; auto idx_in_thread_group = idx - thread_group_begin; auto data_group_begin = static_cast(args.d_group_ptr[g]); std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; // obtain group segment data. auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); std::size_t i = 0, j = 0; common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j); std::size_t rank_high = i, rank_low = j; return std::make_tuple(rank_high, rank_low); } /** * \brief Make pair for the mean pair method */ XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx, bst_group_t g) const { std::size_t n_samples = args.labels.Size(); assert(n_samples == args.predts.size()); // Constructed from ranking cache. std::size_t n_pairs = ltr::cuda_impl::PairsForGroup(args.d_threads_group_ptr[g + 1] - args.d_threads_group_ptr[g], args.d_group_ptr[g + 1] - args.d_group_ptr[g]); assert(n_pairs > 0); auto [sample_idx, sample_pair_idx] = linalg::UnravelIndex(idx, {n_samples, n_pairs}); auto g_begin = static_cast(args.d_group_ptr[g]); std::size_t n_data = args.d_group_ptr[g + 1] - g_begin; auto g_label = args.labels.Slice(linalg::Range(g_begin, g_begin + n_data)); auto g_rank_idx = args.d_sorted_idx.subspan(args.d_group_ptr[g], n_data); auto g_y_sorted_idx = args.d_y_sorted_idx.subspan(g_begin, n_data); std::size_t const i = sample_idx - g_begin; assert(sample_pair_idx < n_samples); assert(i <= sample_idx); auto g_sorted_label = dh::MakeTransformIterator( thrust::make_counting_iterator(0ul), [&](std::size_t i) { return g_label(g_rank_idx[g_y_sorted_idx[i]]); }); // Are the labels diverse enough? If they are all the same, then there is nothing to pick // from another group - bail sooner if (g_label.Size() == 0 || g_sorted_label[0] == g_sorted_label[n_data - 1]) { auto z = static_cast(0ul); return std::make_tuple(z, z); } std::size_t n_lefts = CountNumItemsToTheLeftOf(g_sorted_label, i + 1, g_sorted_label[i]); std::size_t n_rights = CountNumItemsToTheRightOf(g_sorted_label + i, n_data - i, g_sorted_label[i]); // The index pointing to the first element of the next bucket std::size_t right_bound = n_data - n_rights; thrust::minstd_rand rng(args.iter); auto pair_idx = i; rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme thrust::uniform_int_distribution dist(0, n_lefts + n_rights - 1); auto ridx = dist(rng); SPAN_CHECK(ridx < n_lefts + n_rights); if (ridx >= n_lefts) { ridx = ridx - n_lefts + right_bound; // fixme } auto idx0 = g_y_sorted_idx[pair_idx]; auto idx1 = g_y_sorted_idx[ridx]; return std::make_tuple(idx0, idx1); } /** * \brief Generate a single pair. * * \param idx Pair index (CUDA thread index). * \param g Query group index. */ XGBOOST_DEVICE auto operator()(std::size_t idx, bst_group_t g) const { if (has_truncation) { return this->WithTruncation(idx, g); } else { return this->WithSampling(idx, g); } } }; } // namespace xgboost::obj::cuda_impl #endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_