Rework the MAP metric. (#8931)

- The new implementation is stricter: only binary labels are accepted. The previous implementation converted values greater than 1 to 1.
- Deterministic GPU implementation (no atomic add).
- Fix top-k handling.
- Use a precise definition of MAP (there are other variants of how to handle top-k).
- Refactor GPU ranking tests.
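
Since graded relevance is no longer clamped, callers must binarize labels before evaluating `map`. A minimal sketch of the new contract, mirroring the updated tests in the diff below (the threshold is the tests' choice for MQ2008, not a library default):

```python
import numpy as np

# MQ2008 ships graded relevance (0/1/2). The reworked metric rejects anything
# that is not 0 or 1, so binarize up front; the old implementation silently
# clamped values greater than 1 to 1.
y = np.array([0.0, 2.0, 1.0, 0.0, 2.0])
y[y <= 1] = 0.0
y[y > 1] = 1.0  # y is now [0, 1, 0, 0, 1]
```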
Jiaming Yuan 2023-03-22 17:45:20 +08:00 committed by GitHub
parent b240f055d3
commit 5891f752c8
18 changed files with 458 additions and 323 deletions

View File

@@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_

  The `average precision` is defined as:

  .. math::

     AP@l = \frac{1}{\min(l, N)} \sum^l_{k=1} P@k \cdot I_{(k)}

  where :math:`I_{(k)}` is an indicator function that equals :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.

- ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
- ``gamma-nloglik``: negative log-likelihood for gamma regression
- ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression
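
As a quick worked check of the :math:`AP@l` definition above: for a query whose ranked binary labels are :math:`(1, 0, 1)`, there are :math:`N = 2` relevant documents, so

.. math::

   AP@2 = \frac{1}{\min(2, 2)}\left(\frac{1}{1}\cdot 1 + \frac{1}{2}\cdot 0\right) = 0.5

and ``map@2`` is the weighted mean of this quantity over all queries.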

View File

@@ -14,6 +14,7 @@ import zipfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from platform import system
from typing import (
    Any,
@@ -443,7 +444,7 @@ def get_mq2008(
    from sklearn.datasets import load_svmlight_files

    src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
    target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
    if not os.path.exists(target):
        request.urlretrieve(url=src, filename=target)
@@ -462,9 +463,9 @@ def get_mq2008(
        qid_valid,
    ) = load_svmlight_files(
        (
            Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
            Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
            Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
        ),
        query_id=True,
        zero_based=False,

View File

@@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
    def neg_mse(*args: Any, **kwargs: Any) -> float:
        return -float(mean_squared_error(*args, **kwargs))

    ranker = xgb.XGBRanker(
        n_estimators=3,
        eval_metric=neg_mse,
        tree_method=tree_method,
        disable_default_eval_metric=True,
    )
    ranker.fit(df, y, eval_set=[(valid_df, y)])
    score = ranker.score(valid_df, y)
    assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])

View File

@@ -22,7 +22,7 @@ constexpr StringView LabelScoreSize() {
}

constexpr StringView InfInData() {
  return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
}
}  // namespace xgboost::error
#endif  // XGBOOST_COMMON_ERROR_MSG_H_

View File

@@ -1,13 +1,15 @@
/**
 * Copyright 2022-2023 by XGBoost contributors.
 */
#ifndef XGBOOST_COMMON_NUMERIC_H_
#define XGBOOST_COMMON_NUMERIC_H_

#include <dmlc/common.h>  // OMPException

#include <algorithm>  // for std::max
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <iterator>   // for iterator_traits
#include <vector>

#include "common.h"  // AssertGPUSupport
@@ -15,8 +17,7 @@
#include "xgboost/context.h"             // Context
#include "xgboost/host_device_vector.h"  // HostDeviceVector

namespace xgboost::common {

/**
 * \brief Run length encode on CPU, input must be sorted.
@@ -111,11 +112,11 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
namespace cpu_impl {
template <typename It, typename V = typename It::value_type>
V Reduce(Context const* ctx, It first, It second, V const& init) {
  std::size_t n = std::distance(first, second);
  auto n_threads = static_cast<std::size_t>(std::min(n, static_cast<std::size_t>(ctx->Threads())));
  common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(n_threads, init);
  common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
  auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init);
  return result;
}
}  // namespace cpu_impl
@@ -144,7 +145,6 @@ void Iota(Context const* ctx, It first, It last,
  });
}
}
}  // namespace xgboost::common
#endif  // XGBOOST_COMMON_NUMERIC_H_
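
The `Reduce` change above clamps the number of thread-local accumulators to `min(n, ctx->Threads())`, so no more partial slots are allocated and summed than there are items. A rough Python sketch of the pattern, with a sequential loop standing in for `ParallelFor`:

```python
import numpy as np

def reduce_parallel(values, n_threads, init=0.0):
    n = len(values)
    n_threads = min(n, n_threads)       # the clamp introduced in this commit
    partial = np.zeros(n_threads)       # one accumulator per worker
    for i, v in enumerate(values):      # ParallelFor spreads i across workers
        partial[i % n_threads] += v     # each worker touches only its own slot
    return init + float(partial.sum())  # combine the per-worker partials
```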

View File

@@ -114,6 +114,15 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS
DMLC_REGISTER_PARAMETER(LambdaRankParam);

void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
  auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
  CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
}

#if !defined(XGBOOST_USE_CUDA)
void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
#endif  // !defined(XGBOOST_USE_CUDA)

std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
  std::string out_name;
  if (!param.empty()) {

View File

@@ -204,4 +204,9 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
  dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
              [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
}

void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
  auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
  CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
}
}  // namespace xgboost::ltr

View File

@@ -358,6 +358,71 @@ void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView<float con
  }
}

template <typename AllOf>
bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
  auto s_label = label.Values();
  return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) {
    return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps;
  });
}

/**
 * \brief Validate label for MAP.
 *
 * \tparam AllOf Implementation of std::all_of. Specified as a parameter to reuse the check
 *               for both CPU and GPU.
 */
template <typename AllOf>
void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
  auto s_label = label.Values();
  auto is_binary = IsBinaryRel(label, all_of);
  CHECK(is_binary) << "MAP can only be used with binary labels.";
}

class MAPCache : public RankingCache {
  // Total number of relevant documents for each group
  HostDeviceVector<double> n_rel_;
  // \sum l_k/k
  HostDeviceVector<double> acc_;
  HostDeviceVector<double> map_;
  // Number of samples in this dataset.
  std::size_t n_samples_{0};

  void InitOnCPU(Context const* ctx, MetaInfo const& info);
  void InitOnCUDA(Context const* ctx, MetaInfo const& info);

 public:
  MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
      : RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
    if (ctx->IsCPU()) {
      this->InitOnCPU(ctx, info);
    } else {
      this->InitOnCUDA(ctx, info);
    }
  }

  common::Span<double> NumRelevant(Context const* ctx) {
    if (n_rel_.Empty()) {
      n_rel_.SetDevice(ctx->gpu_id);
      n_rel_.Resize(n_samples_);
    }
    return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
  }
  common::Span<double> Acc(Context const* ctx) {
    if (acc_.Empty()) {
      acc_.SetDevice(ctx->gpu_id);
      acc_.Resize(n_samples_);
    }
    return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
  }
  common::Span<double> Map(Context const* ctx) {
    if (map_.Empty()) {
      map_.SetDevice(ctx->gpu_id);
      map_.Resize(this->Groups());
    }
    return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
  }
};

/**
 * \brief Parse name for ranking metric given parameters.
 *

View File

@@ -8,9 +8,11 @@
#include <dmlc/omp.h>

#include <algorithm>
#include <cstdint>      // for int32_t
#include <cstdlib>      // for malloc, free
#include <limits>
#include <new>          // for bad_alloc
#include <type_traits>  // for is_signed
#include <vector>

#include "xgboost/logging.h"
@@ -266,7 +268,7 @@ class MemStackAllocator {
    if (MaxStackSize >= required_size_) {
      ptr_ = stack_mem_;
    } else {
      ptr_ = reinterpret_cast<T*>(std::malloc(required_size_ * sizeof(T)));
    }
    if (!ptr_) {
      throw std::bad_alloc{};
@@ -278,7 +280,7 @@ class MemStackAllocator {
  ~MemStackAllocator() {
    if (required_size_ > MaxStackSize) {
      std::free(ptr_);
    }
  }

  T& operator[](size_t i) { return ptr_[i]; }

View File

@@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank {
  }
};

/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank {
 public:
  explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {}

  double EvalGroup(PredIndPairContainer *recptr) const override {
    PredIndPairContainer &rec(*recptr);
    std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
    unsigned nhits = 0;
    double sumap = 0.0;
    for (size_t i = 0; i < rec.size(); ++i) {
      if (rec[i].second != 0) {
        nhits += 1;
        if (i < this->topn) {
          sumap += static_cast<double>(nhits) / (i + 1);
        }
      }
    }
    if (nhits != 0) {
      sumap /= nhits;
      return sumap;
    } else {
      if (this->minus) {
        return 0.0;
      } else {
        return 1.0;
      }
    }
  }
};

/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
 public:
@@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
    .describe("precision@k for rank.")
    .set_body([](const char* param) { return new EvalPrecision("pre", param); });

XGBOOST_REGISTER_METRIC(MAP, "map")
    .describe("map@k for rank.")
    .set_body([](const char* param) { return new EvalMAP("map", param); });

XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
    .describe("Negative log partial likelihood of Cox proportional hazards model.")
    .set_body([](const char*) { return new EvalCox(); });
@@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
  }
};

class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
 public:
  using EvalRankWithCache::EvalRankWithCache;
  const char* Name() const override { return name_.c_str(); }

  double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
              std::shared_ptr<ltr::MAPCache> p_cache) override {
    if (ctx_->IsCUDA()) {
      auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
      return Finalize(map.Residue(), map.Weights());
    }

    auto gptr = p_cache->DataGroupPtr(ctx_);
    auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
    auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());

    auto map_gloc = p_cache->Map(ctx_);
    std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
    auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());

    common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
      auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
      auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
      auto g_rank = rank_idx.subspan(gptr[g]);

      auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
      double n_hits{0.0};
      for (std::size_t i = 0; i < n; ++i) {
        auto p = g_label(g_rank[i]);
        n_hits += p;
        map_gloc[g] += n_hits / static_cast<double>((i + 1)) * p;
      }
      for (std::size_t i = n; i < g_label.Size(); ++i) {
        n_hits += g_label(g_rank[i]);
      }
      if (n_hits > 0.0) {
        map_gloc[g] /= std::min(n_hits, static_cast<double>(param_.TopK()));
      } else {
        map_gloc[g] = minus_ ? 0.0 : 1.0;
      }
    });

    auto sw = 0.0;
    auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
    if (!weight.Empty()) {
      CHECK_EQ(weight.weights.size(), p_cache->Groups());
    }
    for (std::size_t i = 0; i < map_gloc.size(); ++i) {
      map_gloc[i] = map_gloc[i] * weight[i];
      sw += weight[i];
    }
    auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
    return Finalize(sum, sw);
  }
};

XGBOOST_REGISTER_METRIC(EvalMAP, "map")
    .describe("map@k for ranking.")
    .set_body([](char const* param) {
      return new EvalMAPScore{"map", param};
    });

XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
    .describe("ndcg@k for ranking.")
    .set_body([](char const* param) {
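
A NumPy sketch mirroring the per-group CPU loop in `EvalMAPScore` above: accumulate `P@k` over the top-k positions, keep counting hits past the cutoff, then divide by `min(n_hits, topk)`. Names are illustrative and groups are assumed unweighted:

```python
import numpy as np

def map_score(labels, scores, gptr, topk=np.inf, minus=False):
    """Unweighted MAP over query groups delimited by the pointer array gptr."""
    maps = []
    for g in range(len(gptr) - 1):
        y = np.asarray(labels[gptr[g]:gptr[g + 1]], dtype=np.float64)
        s = np.asarray(scores[gptr[g]:gptr[g + 1]], dtype=np.float64)
        rank = np.argsort(-s)                    # SortedIdx within the group
        n = int(min(topk, len(y)))
        n_hits, ap = 0.0, 0.0
        for i in range(n):                       # sum P@(i+1) at relevant docs
            n_hits += y[rank[i]]
            ap += n_hits / (i + 1) * y[rank[i]]
        n_hits += y[rank[n:]].sum()              # hits beyond the cutoff still count
        if n_hits > 0:
            maps.append(ap / min(n_hits, topk))  # divide by min(N, topk)
        else:
            maps.append(0.0 if minus else 1.0)   # no positives: 1, or 0 with `map-`
    return float(np.mean(maps))

# Reproduces the 0.8611 expectation in the updated C++ test below:
print(map_score([1, 1, 1, 0, 1, 0], [0.1, 0.9, 0.2, 0.8, 0.4, 1.7], [0, 2, 5, 6]))
```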

View File

@@ -125,89 +125,10 @@ struct EvalPrecisionGpu {
};

/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAPGpu {
 public:
  static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
                           const float *dlabels,
                           const EvalRankConfig &ecfg) {
    // Group info on device
    const auto &dgroups = pred_sorter.GetGroupsSpan();
    const auto ngroups = pred_sorter.GetNumGroups();
    const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();

    // Original positions of the predictions after they have been sorted
    const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();

    // First, determine non zero labels in the dataset individually
    const auto nitems = pred_sorter.GetNumItems();
    dh::caching_device_vector<uint32_t> hits(nitems, 0);
    auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
      return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
    };  // NOLINT
    thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
                      thrust::make_counting_iterator(nitems),
                      hits.begin(),
                      DetermineNonTrivialLabelLambda);

    // Allocator to be used by sort for managing space overhead while performing prefix scans
    dh::XGBCachingDeviceAllocator<char> alloc;

    // Next, prefix scan the nontrivial labels that are segmented to accumulate them.
    // This is required for computing the metric sum
    // Data segmented into different groups...
    thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
                                  dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
                                  hits.begin(),   // Input value
                                  hits.begin());  // In-place scan

    // Find each group's metric sum
    dh::caching_device_vector<double> sumap(ngroups, 0);
    auto *dsumap = sumap.data().get();
    const auto *dhits = hits.data().get();

    int device_id = -1;
    dh::safe_cuda(cudaGetDevice(&device_id));
    // For each group item compute the aggregated precision
    dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
      if (DetermineNonTrivialLabelLambda(idx)) {
        const auto group_idx = dgroup_idx[idx];
        const auto group_begin = dgroups[group_idx];
        const auto ridx = idx - group_begin;
        if (ridx < ecfg.topn) {
          atomicAdd(&dsumap[group_idx],
                    static_cast<double>(dhits[idx]) / (ridx + 1));
        }
      }
    });

    // Aggregate the group's item precisions
    dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) {
      auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0;
      if (nhits != 0) {
        dsumap[gidx] /= nhits;
      } else {
        if (ecfg.minus) {
          dsumap[gidx] = 0;
        } else {
          dsumap[gidx] = 1;
        }
      }
    });
    return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end());
  }
};

XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
    .describe("precision@k for rank computed on GPU.")
    .set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });

XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map")
    .describe("map@k for rank computed on GPU.")
    .set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); });

namespace cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
                             HostDeviceVector<float> const &predt, bool minus,
@@ -245,5 +166,87 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
                              PackedReduceResult{0.0, 0.0});
  return pair;
}

PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
                            HostDeviceVector<float> const &predt, bool minus,
                            std::shared_ptr<ltr::MAPCache> p_cache) {
  auto d_group_ptr = p_cache->DataGroupPtr(ctx);
  auto n_groups = info.group_ptr_.size() - 1;

  auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
  predt.SetDevice(ctx->gpu_id);
  auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());

  auto key_it = dh::MakeTransformIterator<std::size_t>(
      thrust::make_counting_iterator(0ul),
      [=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); });
  auto get_label = [=] XGBOOST_DEVICE(std::size_t i) {
    auto g = key_it[i];
    auto g_begin = d_group_ptr[g];
    auto g_end = d_group_ptr[g + 1];
    i -= g_begin;
    auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
    auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
    return g_label(g_rank[i]);
  };
  auto it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), get_label);

  auto cuctx = ctx->CUDACtx();
  auto n_rel = p_cache->NumRelevant(ctx);
  thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + d_label.Size(), it, n_rel.data());

  double topk = p_cache->Param().TopK();
  auto map = p_cache->Map(ctx);
  thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0);
  {
    auto val_it = dh::MakeTransformIterator<double>(
        thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
          auto g = key_it[i];
          auto g_begin = d_group_ptr[g];
          auto g_end = d_group_ptr[g + 1];
          i -= g_begin;
          if (i >= topk) {
            return 0.0;
          }
          auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
          auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
          auto label = g_label(g_rank[i]);
          auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
          auto nhits = g_n_rel[i];
          return nhits / static_cast<double>(i + 1) * label;
        });
    std::size_t bytes;
    cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(),
                                    d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
    dh::TemporaryArray<char> temp(bytes);
    cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(),
                                    d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
  }

  PackedReduceResult result{0.0, 0.0};
  {
    auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
    if (!d_weight.Empty()) {
      CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
    }
    auto val_it = dh::MakeTransformIterator<PackedReduceResult>(
        thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) {
          auto g_begin = d_group_ptr[g];
          auto g_end = d_group_ptr[g + 1];
          auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
          if (!g_n_rel.empty() && g_n_rel.back() > 0.0) {
            return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk),
                                      static_cast<double>(d_weight[g])};
          }
          return PackedReduceResult{minus ? 0.0 : 1.0, static_cast<double>(d_weight[g])};
        });
    result =
        thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0});
  }
  return result;
}
}  // namespace cuda_impl
}  // namespace xgboost::metric
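
The new `MAPScore` is deterministic because the old per-item `atomicAdd` accumulation is replaced by a segmented inclusive scan (cumulative relevant counts per group) followed by a segmented reduction, so floating-point sums are evaluated in a fixed order. A NumPy sketch of that two-phase shape (illustrative names, unweighted groups assumed):

```python
import numpy as np

def map_score_scan(labels, scores, gptr, topk=np.inf, minus=False):
    """MAP via scan-then-reduce, mirroring the structure of the GPU path."""
    labels, scores = np.asarray(labels, float), np.asarray(scores, float)
    maps = []
    for g in range(len(gptr) - 1):
        y = labels[gptr[g]:gptr[g + 1]][np.argsort(-scores[gptr[g]:gptr[g + 1]])]
        # Phase 1: inclusive scan by key (thrust::inclusive_scan_by_key):
        # n_rel[i] = relevant docs at or above rank i within the group.
        n_rel = np.cumsum(y)
        # Phase 2: segmented reduce (cub::DeviceSegmentedReduce::Sum) over
        # P@(i+1) * label, with positions past the top-k cutoff contributing 0.
        k = np.arange(1, len(y) + 1)
        ap = np.where(k <= topk, n_rel / k * y, 0.0).sum()
        total = n_rel[-1] if len(n_rel) else 0.0
        maps.append(ap / min(total, topk) if total > 0 else (0.0 if minus else 1.0))
    return float(np.mean(maps))
```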

View File

@@ -6,7 +6,7 @@
#include <memory>  // for shared_ptr

#include "../common/common.h"         // for AssertGPUSupport
#include "../common/ranking_utils.h"  // for NDCGCache, MAPCache
#include "metric_common.h"            // for PackedReduceResult
#include "xgboost/context.h"          // for Context
#include "xgboost/data.h"             // for MetaInfo
@@ -19,6 +19,10 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
                             HostDeviceVector<float> const &predt, bool minus,
                             std::shared_ptr<ltr::NDCGCache> p_cache);

PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
                            HostDeviceVector<float> const &predt, bool minus,
                            std::shared_ptr<ltr::MAPCache> p_cache);

#if !defined(XGBOOST_USE_CUDA)
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
                                    HostDeviceVector<float> const &, bool,
@@ -26,6 +30,13 @@ inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
  common::AssertGPUSupport();
  return {};
}

inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
                                   HostDeviceVector<float> const &, bool,
                                   std::shared_ptr<ltr::MAPCache>) {
  common::AssertGPUSupport();
  return {};
}
#endif
}  // namespace cuda_impl
}  // namespace metric

View File

@@ -177,4 +177,36 @@ TEST(NDCGCache, InitFromCPU) {
  Context ctx;
  TestNDCGCache(&ctx);
}

void TestMAPCache(Context const* ctx) {
  auto p_fmat = EmptyDMatrix();
  MetaInfo& info = p_fmat->Info();
  LambdaRankParam param;
  param.UpdateAllowUnknown(Args{});

  std::vector<float> h_data(32);
  common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f);
  info.labels.Reshape(h_data.size());
  info.num_row_ = h_data.size();
  info.labels.Data()->HostVector() = std::move(h_data);

  auto fail = [&]() { std::make_shared<MAPCache>(ctx, info, param); };
  // binary label
  ASSERT_THROW(fail(), dmlc::Error);

  h_data = std::vector<float>(32, 0.0f);
  h_data[1] = 1.0f;
  info.labels.Data()->HostVector() = h_data;
  auto p_cache = std::make_shared<MAPCache>(ctx, info, param);

  ASSERT_EQ(p_cache->Acc(ctx).size(), info.num_row_);
  ASSERT_EQ(p_cache->NumRelevant(ctx).size(), info.num_row_);
}

TEST(MAPCache, InitFromCPU) {
  Context ctx;
  ctx.Init(Args{});
  TestMAPCache(&ctx);
}
}  // namespace xgboost::ltr

View File

@@ -95,4 +95,10 @@ TEST(NDCGCache, InitFromGPU) {
  ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
  TestNDCGCache(&ctx);
}

TEST(MAPCache, InitFromGPU) {
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
  TestMAPCache(&ctx);
}
}  // namespace xgboost::ltr

View File

@@ -6,4 +6,6 @@
namespace xgboost::ltr {
void TestNDCGCache(Context const* ctx);

void TestMAPCache(Context const* ctx);
}  // namespace xgboost::ltr

View File

@@ -141,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) {
  // Rank metric with group info
  EXPECT_NEAR(GetMetricEval(metric,
                            {0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f},
                            {1, 1, 1, 0, 1, 0},  // Labels
                            {},                  // Weights
                            {0, 2, 5, 6}),       // Group info
              0.8611f, 0.001f);

View File

@@ -1,194 +1,130 @@
import os
from typing import Dict

import numpy as np
import pytest

import xgboost
from xgboost import testing as tm

pytestmark = tm.timeout(30)


def comp_training_with_rank_objective(
    dtrain: xgboost.DMatrix,
    dtest: xgboost.DMatrix,
    rank_objective: str,
    metric_name: str,
    tolerance: float = 1e-02,
) -> None:
    """Internal method that trains the dataset using the rank objective on GPU and CPU,
    evaluates the metric and determines if the delta between the metric is within the
    tolerance level.

    """
    # specify validations set to watch performance
    watchlist = [(dtest, "eval"), (dtrain, "train")]

    params = {
        "booster": "gbtree",
        "tree_method": "gpu_hist",
        "gpu_id": 0,
        "predictor": "gpu_predictor",
    }

    num_trees = 100
    check_metric_improvement_rounds = 10

    evals_result: Dict[str, Dict] = {}
    params["objective"] = rank_objective
    params["eval_metric"] = metric_name
    bst = xgboost.train(
        params,
        dtrain,
        num_boost_round=num_trees,
        early_stopping_rounds=check_metric_improvement_rounds,
        evals=watchlist,
        evals_result=evals_result,
    )
    gpu_scores = evals_result["train"][metric_name][-1]

    evals_result = {}

    cpu_params = {
        "booster": "gbtree",
        "tree_method": "hist",
        "gpu_id": -1,
        "predictor": "cpu_predictor",
    }
    cpu_params["objective"] = rank_objective
    cpu_params["eval_metric"] = metric_name
    bstc = xgboost.train(
        cpu_params,
        dtrain,
        num_boost_round=num_trees,
        early_stopping_rounds=check_metric_improvement_rounds,
        evals=watchlist,
        evals_result=evals_result,
    )
    cpu_scores = evals_result["train"][metric_name][-1]

    info = (rank_objective, metric_name)
    assert np.allclose(gpu_scores, cpu_scores, tolerance, tolerance), info
    assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance), info

    evals_result_weighted: Dict[str, Dict] = {}
    dtest.set_weight(np.ones((dtest.get_group().size,)))
    dtrain.set_weight(np.ones((dtrain.get_group().size,)))
    watchlist = [(dtest, "eval"), (dtrain, "train")]
    bst_w = xgboost.train(
        params,
        dtrain,
        num_boost_round=num_trees,
        early_stopping_rounds=check_metric_improvement_rounds,
        evals=watchlist,
        evals_result=evals_result_weighted,
    )
    weighted_metric = evals_result_weighted["train"][metric_name][-1]

    tolerance = 1e-5
    assert np.allclose(bst_w.best_score, bst.best_score, tolerance, tolerance)
    assert np.allclose(weighted_metric, gpu_scores, tolerance, tolerance)


@pytest.mark.parametrize(
    "objective,metric",
    [
        ("rank:pairwise", "auc"),
        ("rank:pairwise", "ndcg"),
        ("rank:pairwise", "map"),
        ("rank:ndcg", "auc"),
        ("rank:ndcg", "ndcg"),
        ("rank:ndcg", "map"),
        ("rank:map", "auc"),
        ("rank:map", "ndcg"),
        ("rank:map", "map"),
    ],
)
def test_with_mq2008(objective, metric) -> None:
    (
        x_train,
        y_train,
        qid_train,
        x_test,
        y_test,
        qid_test,
        x_valid,
        y_valid,
        qid_valid,
    ) = tm.get_mq2008(os.path.join(tm.demo_dir(__file__), "rank"))

    if metric.find("map") != -1 or objective.find("map") != -1:
        y_train[y_train <= 1] = 0.0
        y_train[y_train > 1] = 1.0
        y_test[y_test <= 1] = 0.0
        y_test[y_test > 1] = 1.0

    dtrain = xgboost.DMatrix(x_train, y_train, qid=qid_train)
    dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test)

    comp_training_with_rank_objective(dtrain, dtest, objective, metric)

View File

@@ -128,12 +128,23 @@ def test_ranking():
    x_test = np.random.rand(100, 10)

    params = {
        "tree_method": "exact",
        "learning_rate": 0.1,
        "gamma": 1.0,
        "min_child_weight": 0.1,
        "max_depth": 6,
        "eval_metric": "ndcg",
        "n_estimators": 4,
    }
    model = xgb.sklearn.XGBRanker(**params)
    model.fit(
        x_train,
        y_train,
        group=train_group,
        eval_set=[(x_valid, y_valid)],
        eval_group=[valid_group],
    )
    assert model.evals_result()
    pred = model.predict(x_test)
@@ -145,11 +156,18 @@ def test_ranking():
    assert train_data.get_label().shape[0] == x_train.shape[0]
    valid_data.set_group(valid_group)

    params_orig = {
        "tree_method": "exact",
        "objective": "rank:pairwise",
        "eta": 0.1,
        "gamma": 1.0,
        "min_child_weight": 0.1,
        "max_depth": 6,
        "eval_metric": "ndcg",
    }
    xgb_model_orig = xgb.train(
        params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")]
    )
    pred_orig = xgb_model_orig.predict(test_data)

    np.testing.assert_almost_equal(pred, pred_orig)
@@ -165,7 +183,11 @@ def test_ranking_metric() -> None:
    # sklearn compares the number of mis-classified docs, while the one in xgboost
    # compares the number of mis-classified pairs.
    ltr = xgb.XGBRanker(
        eval_metric=roc_auc_score,
        n_estimators=10,
        tree_method="hist",
        max_depth=2,
        objective="rank:pairwise",
    )
    ltr.fit(
        X,