Rework MAP and Pairwise for LTR. (#9075)

This commit is contained in:
Jiaming Yuan
2023-04-28 02:39:12 +08:00
committed by GitHub
parent 0e470ef606
commit e206b899ef
19 changed files with 612 additions and 1135 deletions

View File

@@ -223,4 +223,125 @@ TEST(LambdaRank, MakePair) {
ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair());
}
}
// Checks the statistics that `MAPStat` writes into `ltr::MAPCache`.
//
// Two scenarios are exercised on an empty DMatrix with default ranking parameters:
//   1. Six samples without an explicit `group_ptr_`, labels {1, 1, 0, 1, 1, 1}.
//   2. Sixteen samples split into two groups of eight; the first group is labelled 1
//      everywhere, the second group 0 everywhere.
// In both scenarios the predictions decrease with the sample index inside a group.
//
// \param ctx Context selecting the CPU or the CUDA implementation of `MAPStat`.
void TestMAPStat(Context const* ctx) {
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  ltr::LambdaRankParam param;
  param.UpdateAllowUnknown(Args{});

  // Shared by both scenarios: move the predictions to the context's device, obtain the
  // sorted index from the cache and run the `MAPStat` implementation matching `ctx`.
  auto compute_stat = [ctx, &info](HostDeviceVector<float>* p_predt,
                                   std::shared_ptr<ltr::MAPCache> const& p_cache) {
    p_predt->SetDevice(ctx->gpu_id);
    auto rank_idx = p_cache->SortedIdx(
        ctx, ctx->IsCPU() ? p_predt->ConstHostSpan() : p_predt->ConstDeviceSpan());
    if (ctx->IsCPU()) {
      obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx,
                             p_cache);
    } else {
      obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache);
    }
  };

  // Scenario 1: six samples, no `group_ptr_` assigned.
  {
    std::vector<float> h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f};
    info.labels.Reshape(h_data.size(), 1);
    info.labels.Data()->HostVector() = h_data;
    info.num_row_ = h_data.size();

    HostDeviceVector<float> predt;
    auto& h_predt = predt.HostVector();
    h_predt.resize(h_data.size());
    // Filling through reverse iterators yields predictions 5, 4, ..., 0 in index order.
    std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f);

    auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
    compute_stat(&predt, p_cache);

    // The statistics are read back through a separate default-constructed context.
    Context cpu_ctx;
    auto n_rel = p_cache->NumRelevant(&cpu_ctx);
    auto acc = p_cache->Acc(&cpu_ctx);

    ASSERT_EQ(n_rel[0], 1.0);
    ASSERT_EQ(acc[0], 1.0);
    // Five of the six labels are non-zero.
    ASSERT_EQ(n_rel.back(), h_data.size() - 1.0);
    ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps);
  }

  // Scenario 2: two groups of eight samples each.
  {
    info.labels.Reshape(16);
    auto& h_label = info.labels.Data()->HostVector();
    info.group_ptr_ = {0, 8, 16};
    info.num_row_ = info.labels.Shape(0);
    // First group: every label is 1. Second group: every label is 0.
    std::fill(h_label.begin(), h_label.begin() + 8, 1.0f);
    std::fill(h_label.begin() + 8, h_label.end(), 0.0f);

    HostDeviceVector<float> predt;
    auto& h_predt = predt.HostVector();
    h_predt.resize(h_label.size());
    // Within each group the predictions run 7, 6, ..., 0 in index order.
    std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f);
    std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f);

    auto p_cache = std::make_shared<ltr::MAPCache>(ctx, info, param);
    compute_stat(&predt, p_cache);

    Context cpu_ctx;
    auto n_rel = p_cache->NumRelevant(&cpu_ctx);
    ASSERT_EQ(n_rel[7], 8);      // first group
    ASSERT_EQ(n_rel.back(), 0);  // second group
  }
}
// Runs the MAP statistics check with a default-constructed context.
TEST(LambdaRank, MAPStat) {
  Context default_ctx;
  TestMAPStat(&default_ctx);
}
// Checks the gradient pairs produced by the `rank:map` objective.
//
// The objective is created with default arguments, its configuration is checked with
// `CheckConfigReload`, and two small ranking problems with two query groups of two
// documents each are checked against hard-coded gradients and hessians.
//
// \param ctx Context used to create the objective.
void TestMAPGPair(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:map", ctx)};
  obj->Configure(Args{});
  CheckConfigReload(obj, "rank:map");

  // Expected magnitudes shared by both checks below.
  constexpr float kGrad = 1.2054923f;
  constexpr float kHess = 1.2657166f;

  // Both query groups carry the same weight.
  CheckRankingObjFunction(obj,                             // obj
                          {0.0f, 0.1f, 0.0f, 0.1f},        // score
                          {0, 1, 0, 1},                    // label
                          {2.0f, 2.0f},                    // weight
                          {0, 2, 4},                       // group
                          {kGrad, -kGrad, kGrad, -kGrad},  // out grad
                          {kHess, kHess, kHess, kHess});   // out hess

  // disable the second query group with 0 weight
  CheckRankingObjFunction(obj,                            // obj
                          {0.0f, 0.1f, 0.0f, 0.1f},       // score
                          {0, 1, 0, 1},                   // label
                          {2.0f, 0.0f},                   // weight
                          {0, 2, 4},                      // group
                          {kGrad, -kGrad, 0.0f, 0.0f},    // out grad
                          {kHess, kHess, 0.0f, 0.0f});    // out hess
}
// Runs the `rank:map` gradient check with a default-constructed context.
TEST(LambdaRank, MAPGPair) {
  Context default_ctx;
  TestMAPGPair(&default_ctx);
}
// Configuration smoke test for the `rank:pairwise` objective.
//
// The objective is configured twice: first with no arguments, then with
// `lambdarank_unbiased` enabled. After each configuration `CheckConfigReload` is run on
// it, mirroring what `TestMAPGPair` does for `rank:map`.
//
// \param ctx Context used to create the objective.
void TestPairWiseGPair(Context const* ctx) {
  std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)};

  Args args;
  obj->Configure(args);
  CheckConfigReload(obj, "rank:pairwise");

  // An argument only takes effect if it is in `args` when `Configure` is called, so the
  // objective is configured again after the option has been appended.
  args.emplace_back("lambdarank_unbiased", "true");
  obj->Configure(args);
  CheckConfigReload(obj, "rank:pairwise");
}
// Runs the `rank:pairwise` check with a default-constructed context.
TEST(LambdaRank, Pairwise) {
  Context default_ctx;
  TestPairWiseGPair(&default_ctx);
}
} // namespace xgboost::obj