Rework the MAP metric. (#8931)

- The new implementation is stricter: only binary labels are accepted. The previous implementation converted values greater than 1 to 1.
- Deterministic GPU implementation (no atomic add).
- Fix top-k handling.
- Precise definition of MAP (there are other variants of how to handle top-k).
- Refactor GPU ranking tests.
This commit is contained in:
Jiaming Yuan
2023-03-22 17:45:20 +08:00
committed by GitHub
parent b240f055d3
commit 5891f752c8
18 changed files with 458 additions and 323 deletions

View File

@@ -177,4 +177,36 @@ TEST(NDCGCache, InitFromCPU) {
Context ctx;
TestNDCGCache(&ctx);
}
// Shared MAPCache construction check, executed with the caller-supplied context.
//
// First verifies that building a MAPCache over non-binary labels raises
// dmlc::Error, then that building it over binary labels succeeds and that the
// Acc and NumRelevant buffers each have `info.num_row_` elements.
void TestMAPCache(Context const* ctx) {
  auto dmat = EmptyDMatrix();
  MetaInfo& info = dmat->Info();

  LambdaRankParam param;
  param.UpdateAllowUnknown(Args{});

  // Iota seeded with 0.0f -- presumably fills 0, 1, ..., 31, i.e. labels that
  // are not restricted to {0, 1}.
  std::vector<float> non_binary_labels(32);
  common::Iota(ctx, non_binary_labels.begin(), non_binary_labels.end(), 0.0f);
  info.labels.Reshape(non_binary_labels.size());
  info.num_row_ = non_binary_labels.size();
  info.labels.Data()->HostVector() = std::move(non_binary_labels);

  // MAP only accepts binary labels, so construction must be rejected here.
  auto construct_cache = [&]() { std::make_shared<MAPCache>(ctx, info, param); };
  ASSERT_THROW(construct_cache(), dmlc::Error);

  // Swap in binary labels: all zeros except a single 1 at index 1.
  std::vector<float> binary_labels(32, 0.0f);
  binary_labels[1] = 1.0f;
  info.labels.Data()->HostVector() = binary_labels;

  auto cache = std::make_shared<MAPCache>(ctx, info, param);
  ASSERT_EQ(cache->Acc(ctx).size(), info.num_row_);
  ASSERT_EQ(cache->NumRelevant(ctx).size(), info.num_row_);
}
// Runs the shared MAPCache check using a context initialized with no arguments.
TEST(MAPCache, InitFromCPU) {
  Context cpu_ctx;
  cpu_ctx.Init(Args{});
  TestMAPCache(&cpu_ctx);
}
} // namespace xgboost::ltr

View File

@@ -95,4 +95,10 @@ TEST(NDCGCache, InitFromGPU) {
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
TestNDCGCache(&ctx);
}
// Runs the shared MAPCache check using a context configured with gpu_id = 0.
TEST(MAPCache, InitFromGPU) {
  Context gpu_ctx;
  gpu_ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
  TestMAPCache(&gpu_ctx);
}
} // namespace xgboost::ltr

View File

@@ -6,4 +6,6 @@
namespace xgboost::ltr {
// Shared NDCG cache test body; the calling test supplies the Context to run with.
void TestNDCGCache(Context const* ctx);
// Shared MAP cache test body; the calling test supplies the Context to run with.
void TestMAPCache(Context const* ctx);
} // namespace xgboost::ltr

View File

@@ -141,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) {
// Rank metric with group info
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f},
{2, 7, 1, 0, 5, 0}, // Labels
{1, 1, 1, 0, 1, 0}, // Labels
{}, // Weights
{0, 2, 5, 6}), // Group info
0.8611f, 0.001f);