Optional normalization for learning to rank. (#10094)

Jiaming Yuan, 2024-03-08 12:41:21 +08:00 (committed by GitHub)
parent bc516198dc
commit e14c3b9325

8 changed files with 44 additions and 5 deletions
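The change turns the LambdaRank gradient normalization, previously unconditional, into an opt-out controlled by a new lambdarank_normalization training parameter: gradients in a query group are scaled by log2(1 + sum_lambda) / sum_lambda only when the flag is set. Below is a minimal sketch of how the flag might be toggled, assuming it is exposed like any other training parameter through the XGBoost C API; the parameter name comes from this diff, the default is not shown in the excerpt (presumably true, preserving the old behaviour), and the surrounding boilerplate is ordinary C API usage rather than part of the commit.

#include <cstdio>
#include <xgboost/c_api.h>

int main() {
  BoosterHandle booster = nullptr;
  // No DMatrix is attached; this only demonstrates parameter setup.
  if (XGBoosterCreate(nullptr, 0, &booster) != 0) {
    std::fprintf(stderr, "%s\n", XGBGetLastError());
    return 1;
  }
  XGBoosterSetParam(booster, "objective", "rank:ndcg");
  // The flag introduced by this commit; assumption: it accepts the usual
  // "true"/"false" strings like other boolean training parameters.
  XGBoosterSetParam(booster, "lambdarank_normalization", "false");
  XGBoosterFree(booster);
  return 0;
}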


@@ -222,7 +222,7 @@ class LambdaRankObj : public FitIntercept {
     };
     MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop);
-    if (sum_lambda > 0.0) {
+    if (sum_lambda > 0.0 && param_.lambdarank_normalization) {
       double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
       std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(),
                      g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; });
     }
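The CPU-side hunk above simply adds param_.lambdarank_normalization to the guard. As a standalone illustration of the arithmetic, the sketch below applies the same log2(1 + sum_lambda) / sum_lambda scaling to a plain vector of gradient pairs; GradientPair is reduced to a minimal struct here, so this is an analogue of the diff, not XGBoost's actual types.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Minimal stand-in for xgboost::GradientPair: just gradient and hessian.
struct GradientPair {
  float grad;
  float hess;
};

int main() {
  std::vector<GradientPair> gpair{{0.5f, 1.0f}, {-0.25f, 0.8f}};
  double sum_lambda = 3.0;               // accumulated lambda for the query group
  bool lambdarank_normalization = true;  // the flag introduced by this commit

  if (sum_lambda > 0.0 && lambdarank_normalization) {
    // Shrinks gradients when sum_lambda grows: log2(1 + x) / x < 1 for x > 1.
    double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
    std::transform(gpair.begin(), gpair.end(), gpair.begin(),
                   [norm](GradientPair const& g) {
                     return GradientPair{static_cast<float>(g.grad * norm),
                                         static_cast<float>(g.hess * norm)};
                   });
  }
  for (auto const& g : gpair) {
    std::printf("grad=%g hess=%g\n", g.grad, g.hess);
  }
  return 0;
}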


@@ -266,12 +266,13 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr<ltr::Ran
    */
   auto d_weights = common::MakeOptionalWeights(ctx, info.weights_);
   auto w_norm = p_cache->WeightNorm();
+  auto norm = p_cache->Param().lambdarank_normalization;
   thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(),
                      [=] XGBOOST_DEVICE(std::size_t i) mutable {
                        auto g = dh::SegmentId(d_gptr, i);
                        auto sum_lambda = thrust::get<2>(d_max_lambdas[g]);
                        // Normalization
-                       if (sum_lambda > 0.0) {
+                       if (sum_lambda > 0.0 && norm) {
                          double norm = std::log2(1.0 + sum_lambda) / sum_lambda;
                          d_gpair(i, 0) *= norm;
                        }
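The CUDA hunk applies the same factor per query group: dh::SegmentId maps a gradient index back to its group through the group pointer array, and that group's accumulated sum_lambda decides the scale. A rough host-side analogue of the segmented lookup, using std::upper_bound in place of dh::SegmentId, might look like the following; the scale factor is named scale here to avoid the shadowing of the captured norm flag seen in the hunk, and all data is made up for illustration.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  // CSR-style group pointers: group g covers indices [gptr[g], gptr[g + 1]).
  std::vector<std::size_t> gptr{0, 3, 5};
  std::vector<double> sum_lambda_per_group{4.0, 0.0};  // per-group accumulators
  std::vector<float> grad{0.1f, -0.2f, 0.3f, 0.4f, -0.5f};
  bool norm = true;  // the lambdarank_normalization flag

  for (std::size_t i = 0; i < grad.size(); ++i) {
    // Host analogue of dh::SegmentId: find the group containing index i.
    auto g = static_cast<std::size_t>(
        std::upper_bound(gptr.begin(), gptr.end(), i) - gptr.begin() - 1);
    double sum_lambda = sum_lambda_per_group[g];
    if (sum_lambda > 0.0 && norm) {
      double scale = std::log2(1.0 + sum_lambda) / sum_lambda;
      grad[i] = static_cast<float>(grad[i] * scale);
    }
  }
  return 0;
}

In the kernel itself this loop runs as a thrust::for_each_n over every gradient entry, with the flag captured by value into the device lambda, so opting out of normalization costs only a predicated branch per element.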