diff --git a/doc/parameter.rst b/doc/parameter.rst index e5cb13abf..abd5f39d5 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -500,7 +500,11 @@ These are parameters specific to learning to rank task. See :doc:`Learning to Ra It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``. -* ``lambdarank_unbiased`` [default = ``false``] +* ``lambdarank_normalization`` [default = ``true``] + + Whether to normalize the leaf value by lambda gradient. This can sometimes stagnate the training progress. + +* ``lambdarank_unbiased`` [default = ``false``] Specify whether do we need to debias input click data. diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index 015f736e0..bfc727ed7 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -48,7 +48,7 @@ Notice that the samples are sorted based on their query index in a non-decreasin import xgboost as xgb # Make a synthetic ranking dataset for demonstration - seed = 1994 + seed = 1994 X, y = make_classification(random_state=seed) rng = np.random.default_rng(seed) n_query_groups = 3 @@ -146,7 +146,8 @@ The consideration of effective pairs also applies to the choice of pair method ( When using the mean strategy for generating pairs, where the target metric (like ``NDCG``) is computed over the whole query list, users can specify how many pairs should be generated per each document, by setting the ``lambdarank_num_pair_per_sample``. XGBoost will randomly sample ``lambdarank_num_pair_per_sample`` pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pairsample`). Often, setting it to 1 can produce reasonable results. In cases where performance is inadequate due to insufficient number of effective pairs being generated, set ``lambdarank_num_pair_per_sample`` to a higher value. As more document pairs are generated, more effective pairs will be generated as well. -On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result. +On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result. Lastly, XGBoost employs additional regularization for learning to rank objectives, which can be disabled by setting the ``lambdarank_normalization`` to ``False``. + **Summary** If you have large amount of training data: diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py index a11eb3e03..72cf37aeb 100644 --- a/python-package/xgboost/testing/ranking.py +++ b/python-package/xgboost/testing/ranking.py @@ -100,3 +100,21 @@ def run_ranking_categorical(device: str) -> None: scores = cross_val_score(ltr, X, y) for s in scores: assert s > 0.7 + + +def run_normalization(device: str) -> None: + """Test normalization.""" + X, y, qid, _ = tm.make_ltr(2048, 4, 64, 3) + ltr = xgb.XGBRanker(objective="rank:pairwise", n_estimators=4, device=device) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e0 = ltr.evals_result() + + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_normalization=False, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e1 = ltr.evals_result() + assert e1["validation_0"]["ndcg@32"][-1] > e0["validation_0"]["ndcg@32"][-1] diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index e6b87ed4b..acba0feeb 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -78,6 +78,7 @@ struct LambdaRankParam : public XGBoostParameter { // unbiased bool lambdarank_unbiased{false}; + bool lambdarank_normalization{true}; double lambdarank_bias_norm{1.0}; // ndcg bool ndcg_exp_gain{true}; @@ -86,6 +87,7 @@ struct LambdaRankParam : public XGBoostParameter { return lambdarank_pair_method == that.lambdarank_pair_method && lambdarank_num_pair_per_sample == that.lambdarank_num_pair_per_sample && lambdarank_unbiased == that.lambdarank_unbiased && + lambdarank_normalization == that.lambdarank_normalization && lambdarank_bias_norm == that.lambdarank_bias_norm && ndcg_exp_gain == that.ndcg_exp_gain; } bool operator!=(LambdaRankParam const& that) const { return !(*this == that); } @@ -134,6 +136,9 @@ struct LambdaRankParam : public XGBoostParameter { DMLC_DECLARE_FIELD(lambdarank_unbiased) .set_default(false) .describe("Unbiased lambda mart. Use extended IPW to debias click position"); + DMLC_DECLARE_FIELD(lambdarank_normalization) + .set_default(true) + .describe("Whether to normalize the leaf value for lambda rank."); DMLC_DECLARE_FIELD(lambdarank_bias_norm) .set_default(1.0) .set_lower_bound(0.0) diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 0c9d1262a..b7e290d41 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -222,7 +222,7 @@ class LambdaRankObj : public FitIntercept { }; MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); - if (sum_lambda > 0.0) { + if (sum_lambda > 0.0 && param_.lambdarank_normalization) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(), g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; }); diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 30eba2fdc..25c5d138c 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -266,12 +266,13 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptrWeightNorm(); + auto norm = p_cache->Param().lambdarank_normalization; thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(), [=] XGBOOST_DEVICE(std::size_t i) mutable { auto g = dh::SegmentId(d_gptr, i); auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); // Normalization - if (sum_lambda > 0.0) { + if (sum_lambda > 0.0 && norm) { double norm = std::log2(1.0 + sum_lambda) / sum_lambda; d_gpair(i, 0) *= norm; } diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index 2579b17de..b7c5c3adb 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -6,6 +6,7 @@ import pytest import xgboost from xgboost import testing as tm +from xgboost.testing.ranking import run_normalization pytestmark = tm.timeout(30) @@ -126,3 +127,7 @@ def test_with_mq2008(objective, metric) -> None: dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test) comp_training_with_rank_objective(dtrain, dtest, objective, metric) + + +def test_normalization() -> None: + run_normalization("cuda") diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 8bdeb070f..f09ceceac 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -13,6 +13,7 @@ import xgboost from xgboost import testing as tm from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples from xgboost.testing.params import lambdarank_parameter_strategy +from xgboost.testing.ranking import run_normalization def test_ndcg_custom_gain(): @@ -188,6 +189,10 @@ def test_unbiased() -> None: assert df["ti+"].iloc[-1] < df["ti+"].iloc[0] +def test_normalization() -> None: + run_normalization("cpu") + + class TestRanking: @classmethod def setup_class(cls):