Rework the MAP metric. (#8931)

- The new implementation is more strict as only binary labels are accepted. The previous implementation converts values greater than 1 to 1.
- Deterministic GPU. (no atomic add).
- Fix top-k handling.
- Precise definition of MAP. (There are other variants on how to handle top-k).
- Refactor GPU ranking tests.
This commit is contained in:
Jiaming Yuan
2023-03-22 17:45:20 +08:00
committed by GitHub
parent b240f055d3
commit 5891f752c8
18 changed files with 458 additions and 323 deletions

View File

@@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank {
}
};
/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank {
public:
explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {}
double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
if (rec[i].second != 0) {
nhits += 1;
if (i < this->topn) {
sumap += static_cast<double>(nhits) / (i + 1);
}
}
}
if (nhits != 0) {
sumap /= nhits;
return sumap;
} else {
if (this->minus) {
return 0.0;
} else {
return 1.0;
}
}
}
};
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
public:
@@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP("map", param); });
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); });
@@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
}
};
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
return Finalize(map.Residue(), map.Weights());
}
auto gptr = p_cache->DataGroupPtr(ctx_);
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
auto map_gloc = p_cache->Map(ctx_);
std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_rank = rank_idx.subspan(gptr[g]);
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
double n_hits{0.0};
for (std::size_t i = 0; i < n; ++i) {
auto p = g_label(g_rank[i]);
n_hits += p;
map_gloc[g] += n_hits / static_cast<double>((i + 1)) * p;
}
for (std::size_t i = n; i < g_label.Size(); ++i) {
n_hits += g_label(g_rank[i]);
}
if (n_hits > 0.0) {
map_gloc[g] /= std::min(n_hits, static_cast<double>(param_.TopK()));
} else {
map_gloc[g] = minus_ ? 0.0 : 1.0;
}
});
auto sw = 0.0;
auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
if (!weight.Empty()) {
CHECK_EQ(weight.weights.size(), p_cache->Groups());
}
for (std::size_t i = 0; i < map_gloc.size(); ++i) {
map_gloc[i] = map_gloc[i] * weight[i];
sw += weight[i];
}
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
return Finalize(sum, sw);
}
};
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
.describe("map@k for ranking.")
.set_body([](char const* param) {
return new EvalMAPScore{"map", param};
});
XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
.describe("ndcg@k for ranking.")
.set_body([](char const* param) {

View File

@@ -125,89 +125,10 @@ struct EvalPrecisionGpu {
};
/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAPGpu {
public:
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
const float *dlabels,
const EvalRankConfig &ecfg) {
// Group info on device
const auto &dgroups = pred_sorter.GetGroupsSpan();
const auto ngroups = pred_sorter.GetNumGroups();
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
// Original positions of the predictions after they have been sorted
const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
// First, determine non zero labels in the dataset individually
const auto nitems = pred_sorter.GetNumItems();
dh::caching_device_vector<uint32_t> hits(nitems, 0);
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(nitems),
hits.begin(),
DetermineNonTrivialLabelLambda);
// Allocator to be used by sort for managing space overhead while performing prefix scans
dh::XGBCachingDeviceAllocator<char> alloc;
// Next, prefix scan the nontrivial labels that are segmented to accumulate them.
// This is required for computing the metric sum
// Data segmented into different groups...
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
hits.begin(), // Input value
hits.begin()); // In-place scan
// Find each group's metric sum
dh::caching_device_vector<double> sumap(ngroups, 0);
auto *dsumap = sumap.data().get();
const auto *dhits = hits.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each group item compute the aggregated precision
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
if (DetermineNonTrivialLabelLambda(idx)) {
const auto group_idx = dgroup_idx[idx];
const auto group_begin = dgroups[group_idx];
const auto ridx = idx - group_begin;
if (ridx < ecfg.topn) {
atomicAdd(&dsumap[group_idx],
static_cast<double>(dhits[idx]) / (ridx + 1));
}
}
});
// Aggregate the group's item precisions
dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) {
auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0;
if (nhits != 0) {
dsumap[gidx] /= nhits;
} else {
if (ecfg.minus) {
dsumap[gidx] = 0;
} else {
dsumap[gidx] = 1;
}
}
});
return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end());
}
};
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
.describe("precision@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map")
.describe("map@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); });
namespace cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
@@ -245,5 +166,87 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
PackedReduceResult{0.0, 0.0});
return pair;
}
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache) {
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = info.group_ptr_.size() - 1;
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); });
auto get_label = [=] XGBOOST_DEVICE(std::size_t i) {
auto g = key_it[i];
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
i -= g_begin;
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
return g_label(g_rank[i]);
};
auto it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), get_label);
auto cuctx = ctx->CUDACtx();
auto n_rel = p_cache->NumRelevant(ctx);
thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + d_label.Size(), it, n_rel.data());
double topk = p_cache->Param().TopK();
auto map = p_cache->Map(ctx);
thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0);
{
auto val_it = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
auto g = key_it[i];
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
i -= g_begin;
if (i >= topk) {
return 0.0;
}
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
auto label = g_label(g_rank[i]);
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
auto nhits = g_n_rel[i];
return nhits / static_cast<double>(i + 1) * label;
});
std::size_t bytes;
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(),
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(),
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
}
PackedReduceResult result{0.0, 0.0};
{
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
if (!d_weight.Empty()) {
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
}
auto val_it = dh::MakeTransformIterator<PackedReduceResult>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) {
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
if (!g_n_rel.empty() && g_n_rel.back() > 0.0) {
return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk),
static_cast<double>(d_weight[g])};
}
return PackedReduceResult{minus ? 0.0 : 1.0, static_cast<double>(d_weight[g])};
});
result =
thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0});
}
return result;
}
} // namespace cuda_impl
} // namespace xgboost::metric

View File

@@ -6,7 +6,7 @@
#include <memory> // for shared_ptr
#include "../common/common.h" // for AssertGPUSupport
#include "../common/ranking_utils.h" // for NDCGCache
#include "../common/ranking_utils.h" // for NDCGCache, MAPCache
#include "metric_common.h" // for PackedReduceResult
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
@@ -19,6 +19,10 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache);
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache);
#if !defined(XGBOOST_USE_CUDA)
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
@@ -26,6 +30,13 @@ inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
common::AssertGPUSupport();
return {};
}
inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
std::shared_ptr<ltr::MAPCache>) {
common::AssertGPUSupport();
return {};
}
#endif
} // namespace cuda_impl
} // namespace metric