Rework the MAP metric. (#8931)

- The new implementation is stricter: only binary labels are accepted, whereas the previous implementation converted values greater than 1 to 1.
- Deterministic GPU implementation (no atomic add).
- Fix top-k handling.
- Use a precise definition of MAP (there are other variants of how to handle top-k).
- Refactor GPU ranking tests.
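For users, the practical consequence of the stricter check is that graded relevance labels must be binarized before the `map` metric (or the `rank:map` objective) is used. A minimal sketch with synthetic data (the data, threshold, and parameter values are illustrative assumptions, not part of this change):

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))
y_graded = rng.integers(0, 4, size=200).astype(np.float64)  # graded relevance 0..3
qid = np.sort(rng.integers(0, 20, size=200))  # query ids must be non-decreasing

# The reworked metric accepts only 0/1 labels; previously values > 1 were
# silently treated as 1, so binarize explicitly instead.
y = (y_graded > 1).astype(np.float64)

ranker = xgb.XGBRanker(
    n_estimators=10,
    tree_method="hist",
    objective="rank:map",
    eval_metric="map",
)
ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid])
print(ranker.evals_result()["validation_0"]["map"][-1])
```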
Jiaming Yuan 2023-03-22 17:45:20 +08:00 committed by GitHub
parent b240f055d3
commit 5891f752c8
18 changed files with 458 additions and 323 deletions

View File

@@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
- ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
The `average precision` is defined as:
.. math::

    AP@l = \frac{1}{\min(l, N)} \sum^l_{k=1} P@k \cdot I_{(k)}
where :math:`I_{(k)}` is an indicator function that equals :math:`1` when the document at position :math:`k` is relevant and :math:`0` otherwise, :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries. An illustrative sketch of this computation is shown after the metric list below.
- ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
- ``gamma-nloglik``: negative log-likelihood for gamma regression
- ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression
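As an illustration only (not part of the XGBoost API), the definition above can be written out for a single query with binary ``labels`` and model ``scores`` as follows:

.. code-block:: python

    import numpy as np

    def average_precision_at(labels: np.ndarray, scores: np.ndarray, l: int) -> float:
        # AP@l = 1 / min(l, N) * sum_{k <= l} P@k * I(k), for binary labels.
        order = np.argsort(-scores)        # rank documents by descending score
        rel = labels[order].astype(float)  # I(k): relevance of the document at rank k
        n_total_rel = rel.sum()            # N: total number of relevant documents
        if n_total_rel == 0:
            return 1.0                     # 0.0 when the "-" suffix variant is used
        top = rel[:l]
        p_at_k = np.cumsum(top) / (np.arange(top.size) + 1.0)  # P@k for k = 1..l
        return float((p_at_k * top).sum() / min(l, n_total_rel))

    labels = np.array([1.0, 0.0, 1.0, 0.0, 0.0])
    scores = np.array([0.9, 0.8, 0.3, 0.2, 0.1])
    print(average_precision_at(labels, scores, l=3))  # 1/min(3, 2) * (1/1 + 2/3) = 0.8333...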

View File

@@ -14,6 +14,7 @@ import zipfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from platform import system
from typing import (
Any,
@@ -443,7 +444,7 @@ def get_mq2008(
from sklearn.datasets import load_svmlight_files
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
target = dpath + "/MQ2008.zip"
target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
if not os.path.exists(target):
request.urlretrieve(url=src, filename=target)
@@ -462,9 +463,9 @@ def get_mq2008(
qid_valid,
) = load_svmlight_files(
(
dpath + "MQ2008/Fold1/train.txt",
dpath + "MQ2008/Fold1/test.txt",
dpath + "MQ2008/Fold1/vali.txt",
Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
),
query_id=True,
zero_based=False,

View File

@@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
def neg_mse(*args: Any, **kwargs: Any) -> float:
return -float(mean_squared_error(*args, **kwargs))
ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
ranker = xgb.XGBRanker(
n_estimators=3,
eval_metric=neg_mse,
tree_method=tree_method,
disable_default_eval_metric=True,
)
ranker.fit(df, y, eval_set=[(valid_df, y)])
score = ranker.score(valid_df, y)
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])

View File

@@ -22,7 +22,7 @@ constexpr StringView LabelScoreSize() {
}
constexpr StringView InfInData() {
return "Input data contains `inf` while `missing` is not set to `inf`";
return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
}
} // namespace xgboost::error
#endif // XGBOOST_COMMON_ERROR_MSG_H_

View File

@@ -1,13 +1,15 @@
/*!
* Copyright 2022, XGBoost contributors.
/**
* Copyright 2022-2023 by XGBoost contributors.
*/
#ifndef XGBOOST_COMMON_NUMERIC_H_
#define XGBOOST_COMMON_NUMERIC_H_
#include <dmlc/common.h> // OMPException
#include <algorithm> // std::max
#include <iterator> // std::iterator_traits
#include <algorithm> // for std::max
#include <cstddef> // for size_t
#include <cstdint> // for int32_t
#include <iterator> // for iterator_traits
#include <vector>
#include "common.h" // AssertGPUSupport
@@ -15,8 +17,7 @@
#include "xgboost/context.h" // Context
#include "xgboost/host_device_vector.h" // HostDeviceVector
namespace xgboost {
namespace common {
namespace xgboost::common {
/**
* \brief Run length encode on CPU, input must be sorted.
@@ -111,11 +112,11 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
namespace cpu_impl {
template <typename It, typename V = typename It::value_type>
V Reduce(Context const* ctx, It first, It second, V const& init) {
size_t n = std::distance(first, second);
common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(ctx->Threads(), init);
common::ParallelFor(n, ctx->Threads(),
[&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init);
std::size_t n = std::distance(first, second);
auto n_threads = static_cast<std::size_t>(std::min(n, static_cast<std::size_t>(ctx->Threads())));
common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(n_threads, init);
common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init);
return result;
}
} // namespace cpu_impl
@@ -144,7 +145,6 @@ void Iota(Context const* ctx, It first, It last,
});
}
}
} // namespace common
} // namespace xgboost
} // namespace xgboost::common
#endif // XGBOOST_COMMON_NUMERIC_H_

View File

@@ -114,6 +114,15 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS
DMLC_REGISTER_PARAMETER(LambdaRankParam);
void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
}
#if !defined(XGBOOST_USE_CUDA)
void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)
std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
std::string out_name;
if (!param.empty()) {

View File

@@ -204,4 +204,9 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
[=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
}
void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
}
} // namespace xgboost::ltr

View File

@@ -358,6 +358,71 @@ void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView<float con
}
}
template <typename AllOf>
bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
auto s_label = label.Values();
return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) {
return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps;
});
}
/**
* \brief Validate label for MAP
*
* \tparam AllOf Implementation of std::all_of. Specified as a parameter to reuse the check for
* both CPU and GPU.
*/
template <typename AllOf>
void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
auto s_label = label.Values();
auto is_binary = IsBinaryRel(label, all_of);
CHECK(is_binary) << "MAP can only be used with binary labels.";
}
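For intuition, a rough NumPy equivalent of the check above (illustrative only; the tolerance constant stands in for `kRtEps` and its value here is an assumption of this sketch):

```python
import numpy as np

K_RT_EPS = 1e-6  # assumed stand-in for xgboost's kRtEps tolerance

def check_map_labels(labels: np.ndarray) -> None:
    # Every label must be numerically 0 or 1, within a small tolerance.
    is_binary = np.all(
        (np.abs(labels - 1.0) < K_RT_EPS) | (np.abs(labels - 0.0) < K_RT_EPS)
    )
    if not is_binary:
        raise ValueError("MAP can only be used with binary labels.")

check_map_labels(np.array([0.0, 1.0, 1.0]))  # passes
# check_map_labels(np.array([0.0, 2.0]))     # would raise ValueError
```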
class MAPCache : public RankingCache {
// Total number of relevant documents for each group
HostDeviceVector<double> n_rel_;
// \sum l_k/k
HostDeviceVector<double> acc_;
HostDeviceVector<double> map_;
// Number of samples in this dataset.
std::size_t n_samples_{0};
void InitOnCPU(Context const* ctx, MetaInfo const& info);
void InitOnCUDA(Context const* ctx, MetaInfo const& info);
public:
MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
: RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
if (ctx->IsCPU()) {
this->InitOnCPU(ctx, info);
} else {
this->InitOnCUDA(ctx, info);
}
}
common::Span<double> NumRelevant(Context const* ctx) {
if (n_rel_.Empty()) {
n_rel_.SetDevice(ctx->gpu_id);
n_rel_.Resize(n_samples_);
}
return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
}
common::Span<double> Acc(Context const* ctx) {
if (acc_.Empty()) {
acc_.SetDevice(ctx->gpu_id);
acc_.Resize(n_samples_);
}
return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
}
common::Span<double> Map(Context const* ctx) {
if (map_.Empty()) {
map_.SetDevice(ctx->gpu_id);
map_.Resize(this->Groups());
}
return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
}
};
/**
* \brief Parse name for ranking metric given parameters.
*

View File

@@ -8,9 +8,11 @@
#include <dmlc/omp.h>
#include <algorithm>
#include <cstdint> // std::int32_t
#include <cstdint> // for int32_t
#include <cstdlib> // for malloc, free
#include <limits>
#include <type_traits> // std::is_signed
#include <new> // for bad_alloc
#include <type_traits> // for is_signed
#include <vector>
#include "xgboost/logging.h"
@@ -266,7 +268,7 @@ class MemStackAllocator {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
ptr_ = reinterpret_cast<T*>(std::malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
@@ -278,7 +280,7 @@
~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
free(ptr_);
std::free(ptr_);
}
}
T& operator[](size_t i) { return ptr_[i]; }

View File

@@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank {
}
};
/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank {
public:
explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {}
double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhits = 0;
double sumap = 0.0;
for (size_t i = 0; i < rec.size(); ++i) {
if (rec[i].second != 0) {
nhits += 1;
if (i < this->topn) {
sumap += static_cast<double>(nhits) / (i + 1);
}
}
}
if (nhits != 0) {
sumap /= nhits;
return sumap;
} else {
if (this->minus) {
return 0.0;
} else {
return 1.0;
}
}
}
};
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
public:
@@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP("map", param); });
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); });
@@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
}
};
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
return Finalize(map.Residue(), map.Weights());
}
auto gptr = p_cache->DataGroupPtr(ctx_);
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
auto map_gloc = p_cache->Map(ctx_);
std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_rank = rank_idx.subspan(gptr[g]);
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
double n_hits{0.0};
for (std::size_t i = 0; i < n; ++i) {
auto p = g_label(g_rank[i]);
n_hits += p;
map_gloc[g] += n_hits / static_cast<double>((i + 1)) * p;
}
for (std::size_t i = n; i < g_label.Size(); ++i) {
n_hits += g_label(g_rank[i]);
}
if (n_hits > 0.0) {
map_gloc[g] /= std::min(n_hits, static_cast<double>(param_.TopK()));
} else {
map_gloc[g] = minus_ ? 0.0 : 1.0;
}
});
auto sw = 0.0;
auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
if (!weight.Empty()) {
CHECK_EQ(weight.weights.size(), p_cache->Groups());
}
for (std::size_t i = 0; i < map_gloc.size(); ++i) {
map_gloc[i] = map_gloc[i] * weight[i];
sw += weight[i];
}
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
return Finalize(sum, sw);
}
};
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
.describe("map@k for ranking.")
.set_body([](char const* param) {
return new EvalMAPScore{"map", param};
});
XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
.describe("ndcg@k for ranking.")
.set_body([](char const* param) {

View File

@@ -125,89 +125,10 @@ struct EvalPrecisionGpu {
};
/*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAPGpu {
public:
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
const float *dlabels,
const EvalRankConfig &ecfg) {
// Group info on device
const auto &dgroups = pred_sorter.GetGroupsSpan();
const auto ngroups = pred_sorter.GetNumGroups();
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
// Original positions of the predictions after they have been sorted
const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
// First, determine non zero labels in the dataset individually
const auto nitems = pred_sorter.GetNumItems();
dh::caching_device_vector<uint32_t> hits(nitems, 0);
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
}; // NOLINT
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(nitems),
hits.begin(),
DetermineNonTrivialLabelLambda);
// Allocator to be used by sort for managing space overhead while performing prefix scans
dh::XGBCachingDeviceAllocator<char> alloc;
// Next, prefix scan the nontrivial labels that are segmented to accumulate them.
// This is required for computing the metric sum
// Data segmented into different groups...
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
hits.begin(), // Input value
hits.begin()); // In-place scan
// Find each group's metric sum
dh::caching_device_vector<double> sumap(ngroups, 0);
auto *dsumap = sumap.data().get();
const auto *dhits = hits.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each group item compute the aggregated precision
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
if (DetermineNonTrivialLabelLambda(idx)) {
const auto group_idx = dgroup_idx[idx];
const auto group_begin = dgroups[group_idx];
const auto ridx = idx - group_begin;
if (ridx < ecfg.topn) {
atomicAdd(&dsumap[group_idx],
static_cast<double>(dhits[idx]) / (ridx + 1));
}
}
});
// Aggregate the group's item precisions
dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) {
auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0;
if (nhits != 0) {
dsumap[gidx] /= nhits;
} else {
if (ecfg.minus) {
dsumap[gidx] = 0;
} else {
dsumap[gidx] = 1;
}
}
});
return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end());
}
};
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
.describe("precision@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map")
.describe("map@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); });
namespace cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
@@ -245,5 +166,87 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
PackedReduceResult{0.0, 0.0});
return pair;
}
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache) {
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = info.group_ptr_.size() - 1;
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
auto key_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); });
auto get_label = [=] XGBOOST_DEVICE(std::size_t i) {
auto g = key_it[i];
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
i -= g_begin;
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
return g_label(g_rank[i]);
};
auto it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), get_label);
auto cuctx = ctx->CUDACtx();
auto n_rel = p_cache->NumRelevant(ctx);
thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + d_label.Size(), it, n_rel.data());
double topk = p_cache->Param().TopK();
auto map = p_cache->Map(ctx);
thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0);
{
auto val_it = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
auto g = key_it[i];
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
i -= g_begin;
if (i >= topk) {
return 0.0;
}
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
auto label = g_label(g_rank[i]);
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
auto nhits = g_n_rel[i];
return nhits / static_cast<double>(i + 1) * label;
});
std::size_t bytes;
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(),
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(),
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
}
PackedReduceResult result{0.0, 0.0};
{
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
if (!d_weight.Empty()) {
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
}
auto val_it = dh::MakeTransformIterator<PackedReduceResult>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) {
auto g_begin = d_group_ptr[g];
auto g_end = d_group_ptr[g + 1];
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
if (!g_n_rel.empty() && g_n_rel.back() > 0.0) {
return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk),
static_cast<double>(d_weight[g])};
}
return PackedReduceResult{minus ? 0.0 : 1.0, static_cast<double>(d_weight[g])};
});
result =
thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0});
}
return result;
}
} // namespace cuda_impl
} // namespace xgboost::metric
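The GPU path above is deterministic because the per-position hit counts come from a segmented inclusive scan and the per-group sums from a segmented reduction, instead of the `atomicAdd` accumulation used by the removed `EvalMAPGpu`. A rough NumPy sketch of the same computation (the group boundaries, top-k value, and unit query weights are illustrative assumptions):

```python
import numpy as np

def map_score(labels, scores, group_ptr, topk, minus=False):
    # Deterministic MAP: cumulative hit counts per group (the scan) followed by
    # a per-group sum (the segmented reduction); no atomic updates are needed.
    per_group = []
    for g in range(len(group_ptr) - 1):
        beg, end = group_ptr[g], group_ptr[g + 1]
        rank = np.argsort(-scores[beg:end])        # SortedIdx within the group
        rel = labels[beg:end][rank].astype(float)  # binary relevance in ranked order
        n_rel = np.cumsum(rel)                     # inclusive_scan_by_key equivalent
        k = np.arange(rel.size) + 1.0
        contrib = np.where(k <= topk, n_rel / k * rel, 0.0)  # ignore positions past top-k
        total_rel = n_rel[-1] if rel.size else 0.0
        if total_rel > 0.0:
            per_group.append(contrib.sum() / min(total_rel, float(topk)))
        else:
            per_group.append(0.0 if minus else 1.0)  # no relevant docs: the "-" suffix rule
    return float(np.mean(per_group))                 # unit query weights assumed

group_ptr = np.array([0, 3, 6])  # two queries, three documents each
labels = np.array([1.0, 0.0, 1.0, 0.0, 0.0, 1.0])
scores = np.array([0.9, 0.1, 0.5, 0.8, 0.7, 0.6])
print(map_score(labels, scores, group_ptr, topk=2))  # 0.5
```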

View File

@@ -6,7 +6,7 @@
#include <memory> // for shared_ptr
#include "../common/common.h" // for AssertGPUSupport
#include "../common/ranking_utils.h" // for NDCGCache
#include "../common/ranking_utils.h" // for NDCGCache, MAPCache
#include "metric_common.h" // for PackedReduceResult
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
@@ -19,6 +19,10 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache);
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache);
#if !defined(XGBOOST_USE_CUDA)
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
@@ -26,6 +30,13 @@ inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
common::AssertGPUSupport();
return {};
}
inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
std::shared_ptr<ltr::MAPCache>) {
common::AssertGPUSupport();
return {};
}
#endif
} // namespace cuda_impl
} // namespace metric

View File

@@ -177,4 +177,36 @@ TEST(NDCGCache, InitFromCPU) {
Context ctx;
TestNDCGCache(&ctx);
}
void TestMAPCache(Context const* ctx) {
auto p_fmat = EmptyDMatrix();
MetaInfo& info = p_fmat->Info();
LambdaRankParam param;
param.UpdateAllowUnknown(Args{});
std::vector<float> h_data(32);
common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f);
info.labels.Reshape(h_data.size());
info.num_row_ = h_data.size();
info.labels.Data()->HostVector() = std::move(h_data);
auto fail = [&]() { std::make_shared<MAPCache>(ctx, info, param); };
// binary label
ASSERT_THROW(fail(), dmlc::Error);
h_data = std::vector<float>(32, 0.0f);
h_data[1] = 1.0f;
info.labels.Data()->HostVector() = h_data;
auto p_cache = std::make_shared<MAPCache>(ctx, info, param);
ASSERT_EQ(p_cache->Acc(ctx).size(), info.num_row_);
ASSERT_EQ(p_cache->NumRelevant(ctx).size(), info.num_row_);
}
TEST(MAPCache, InitFromCPU) {
Context ctx;
ctx.Init(Args{});
TestMAPCache(&ctx);
}
} // namespace xgboost::ltr

View File

@@ -95,4 +95,10 @@ TEST(NDCGCache, InitFromGPU) {
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
TestNDCGCache(&ctx);
}
TEST(MAPCache, InitFromGPU) {
Context ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
TestMAPCache(&ctx);
}
} // namespace xgboost::ltr

View File

@@ -6,4 +6,6 @@
namespace xgboost::ltr {
void TestNDCGCache(Context const* ctx);
void TestMAPCache(Context const* ctx);
} // namespace xgboost::ltr

View File

@@ -141,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) {
// Rank metric with group info
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f},
{2, 7, 1, 0, 5, 0}, // Labels
{1, 1, 1, 0, 1, 0}, // Labels
{}, // Weights
{0, 2, 5, 6}), // Group info
0.8611f, 0.001f);

View File

@@ -1,194 +1,130 @@
import itertools
import os
import shutil
import urllib.request
import zipfile
from typing import Dict
import numpy as np
import pytest
import xgboost
from xgboost import testing as tm
pytestmark = tm.timeout(10)
pytestmark = tm.timeout(30)
class TestRanking:
@classmethod
def setup_class(cls):
"""
Download and setup the test fixtures
"""
from sklearn.datasets import load_svmlight_files
def comp_training_with_rank_objective(
dtrain: xgboost.DMatrix,
dtest: xgboost.DMatrix,
rank_objective: str,
metric_name: str,
tolerance: float = 1e-02,
) -> None:
"""Internal method that trains the dataset using the rank objective on GPU and CPU,
evaluates the metric and determines if the delta between the metric is within the
tolerance level.
# download the test data
cls.dpath = os.path.join(tm.demo_dir(__file__), "rank/")
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = os.path.join(cls.dpath, "MQ2008.zip")
"""
# specify validations set to watch performance
watchlist = [(dtest, "eval"), (dtrain, "train")]
if os.path.exists(cls.dpath) and os.path.exists(target):
print("Skipping dataset download...")
else:
urllib.request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, 'r') as f:
f.extractall(path=cls.dpath)
params = {
"booster": "gbtree",
"tree_method": "gpu_hist",
"gpu_id": 0,
"predictor": "gpu_predictor",
}
(x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid) = load_svmlight_files(
(cls.dpath + "MQ2008/Fold1/train.txt",
cls.dpath + "MQ2008/Fold1/test.txt",
cls.dpath + "MQ2008/Fold1/vali.txt"),
query_id=True, zero_based=False)
# instantiate the matrices
cls.dtrain = xgboost.DMatrix(x_train, y_train)
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
cls.dtest = xgboost.DMatrix(x_test, y_test)
# set the group counts from the query IDs
cls.dtrain.set_group([len(list(items))
for _key, items in itertools.groupby(qid_train)])
cls.dtest.set_group([len(list(items))
for _key, items in itertools.groupby(qid_test)])
cls.dvalid.set_group([len(list(items))
for _key, items in itertools.groupby(qid_valid)])
# save the query IDs for testing
cls.qid_train = qid_train
cls.qid_test = qid_test
cls.qid_valid = qid_valid
num_trees = 100
check_metric_improvement_rounds = 10
def setup_weighted(x, y, groups):
# Setup weighted data
data = xgboost.DMatrix(x, y)
groups_segment = [len(list(items))
for _key, items in itertools.groupby(groups)]
data.set_group(groups_segment)
n_groups = len(groups_segment)
weights = np.ones((n_groups,))
data.set_weight(weights)
return data
evals_result: Dict[str, Dict] = {}
params["objective"] = rank_objective
params["eval_metric"] = metric_name
bst = xgboost.train(
params,
dtrain,
num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist,
evals_result=evals_result,
)
gpu_scores = evals_result["train"][metric_name][-1]
cls.dtrain_w = setup_weighted(x_train, y_train, qid_train)
cls.dtest_w = setup_weighted(x_test, y_test, qid_test)
cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid)
evals_result = {}
# model training parameters
cls.params = {'booster': 'gbtree',
'tree_method': 'gpu_hist',
'gpu_id': 0,
'predictor': 'gpu_predictor'}
cls.cpu_params = {'booster': 'gbtree',
'tree_method': 'hist',
'gpu_id': -1,
'predictor': 'cpu_predictor'}
cpu_params = {
"booster": "gbtree",
"tree_method": "hist",
"gpu_id": -1,
"predictor": "cpu_predictor",
}
cpu_params["objective"] = rank_objective
cpu_params["eval_metric"] = metric_name
bstc = xgboost.train(
cpu_params,
dtrain,
num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist,
evals_result=evals_result,
)
cpu_scores = evals_result["train"][metric_name][-1]
@classmethod
def teardown_class(cls):
"""
Cleanup test artifacts from download and unpacking
:return:
"""
os.remove(os.path.join(cls.dpath, "MQ2008.zip"))
shutil.rmtree(os.path.join(cls.dpath, "MQ2008"))
info = (rank_objective, metric_name)
assert np.allclose(gpu_scores, cpu_scores, tolerance, tolerance), info
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance), info
@classmethod
def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02):
"""
Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates
the metric and determines if the delta between the metric is within the tolerance level
:return:
"""
# specify validations set to watch performance
watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')]
evals_result_weighted: Dict[str, Dict] = {}
dtest.set_weight(np.ones((dtest.get_group().size,)))
dtrain.set_weight(np.ones((dtrain.get_group().size,)))
watchlist = [(dtest, "eval"), (dtrain, "train")]
bst_w = xgboost.train(
params,
dtrain,
num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist,
evals_result=evals_result_weighted,
)
weighted_metric = evals_result_weighted["train"][metric_name][-1]
num_trees = 100
check_metric_improvement_rounds = 10
tolerance = 1e-5
assert np.allclose(bst_w.best_score, bst.best_score, tolerance, tolerance)
assert np.allclose(weighted_metric, gpu_scores, tolerance, tolerance)
evals_result = {}
cls.params['objective'] = rank_objective
cls.params['eval_metric'] = metric_name
bst = xgboost.train(
cls.params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
gpu_map_metric = evals_result['train'][metric_name][-1]
evals_result = {}
cls.cpu_params['objective'] = rank_objective
cls.cpu_params['eval_metric'] = metric_name
bstc = xgboost.train(
cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result)
cpu_map_metric = evals_result['train'][metric_name][-1]
@pytest.mark.parametrize(
"objective,metric",
[
("rank:pairwise", "auc"),
("rank:pairwise", "ndcg"),
("rank:pairwise", "map"),
("rank:ndcg", "auc"),
("rank:ndcg", "ndcg"),
("rank:ndcg", "map"),
("rank:map", "auc"),
("rank:map", "ndcg"),
("rank:map", "map"),
],
)
def test_with_mq2008(objective, metric) -> None:
(
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
) = tm.get_mq2008(os.path.join(tm.demo_dir(__file__), "rank"))
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance,
tolerance)
assert np.allclose(bst.best_score, bstc.best_score, tolerance,
tolerance)
if metric.find("map") != -1 or objective.find("map") != -1:
y_train[y_train <= 1] = 0.0
y_train[y_train > 1] = 1.0
y_test[y_test <= 1] = 0.0
y_test[y_test > 1] = 1.0
evals_result_weighted = {}
watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')]
bst_w = xgboost.train(
cls.params, cls.dtrain_w, num_boost_round=num_trees,
early_stopping_rounds=check_metric_improvement_rounds,
evals=watchlist, evals_result=evals_result_weighted)
weighted_metric = evals_result_weighted['train'][metric_name][-1]
# GPU Ranking is not deterministic due to `AtomicAddGpair`,
# remove tolerance once the issue is resolved.
# https://github.com/dmlc/xgboost/issues/5561
assert np.allclose(bst_w.best_score, bst.best_score,
tolerance, tolerance)
assert np.allclose(weighted_metric, gpu_map_metric,
tolerance, tolerance)
dtrain = xgboost.DMatrix(x_train, y_train, qid=qid_train)
dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test)
def test_training_rank_pairwise_map_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare map metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'map')
def test_training_rank_pairwise_auc_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare auc metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'auc')
def test_training_rank_pairwise_ndcg_metric(self):
"""
Train an XGBoost ranking model with pairwise objective function and compare ndcg metric
"""
self.__test_training_with_rank_objective('rank:pairwise', 'ndcg')
def test_training_rank_ndcg_map(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare map metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'map')
def test_training_rank_ndcg_auc(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare auc metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'auc')
def test_training_rank_ndcg_ndcg(self):
"""
Train an XGBoost ranking model with ndcg objective function and compare ndcg metric
"""
self.__test_training_with_rank_objective('rank:ndcg', 'ndcg')
def test_training_rank_map_map(self):
"""
Train an XGBoost ranking model with map objective function and compare map metric
"""
self.__test_training_with_rank_objective('rank:map', 'map')
def test_training_rank_map_auc(self):
"""
Train an XGBoost ranking model with map objective function and compare auc metric
"""
self.__test_training_with_rank_objective('rank:map', 'auc')
def test_training_rank_map_ndcg(self):
"""
Train an XGBoost ranking model with map objective function and compare ndcg metric
"""
self.__test_training_with_rank_objective('rank:map', 'ndcg')
comp_training_with_rank_objective(dtrain, dtest, objective, metric)

View File

@@ -128,12 +128,23 @@ def test_ranking():
x_test = np.random.rand(100, 10)
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1,
'max_depth': 6, 'n_estimators': 4}
params = {
"tree_method": "exact",
"learning_rate": 0.1,
"gamma": 1.0,
"min_child_weight": 0.1,
"max_depth": 6,
"eval_metric": "ndcg",
"n_estimators": 4,
}
model = xgb.sklearn.XGBRanker(**params)
model.fit(x_train, y_train, group=train_group,
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
model.fit(
x_train,
y_train,
group=train_group,
eval_set=[(x_valid, y_valid)],
eval_group=[valid_group],
)
assert model.evals_result()
pred = model.predict(x_test)
@@ -145,11 +156,18 @@ def test_ranking():
assert train_data.get_label().shape[0] == x_train.shape[0]
valid_data.set_group(valid_group)
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
'eta': 0.1, 'gamma': 1.0,
'min_child_weight': 0.1, 'max_depth': 6}
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
evals=[(valid_data, 'validation')])
params_orig = {
"tree_method": "exact",
"objective": "rank:pairwise",
"eta": 0.1,
"gamma": 1.0,
"min_child_weight": 0.1,
"max_depth": 6,
"eval_metric": "ndcg",
}
xgb_model_orig = xgb.train(
params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")]
)
pred_orig = xgb_model_orig.predict(test_data)
np.testing.assert_almost_equal(pred, pred_orig)
@@ -165,7 +183,11 @@ def test_ranking_metric() -> None:
# sklearn compares the number of mis-classified docs, while the one in xgboost
# compares the number of mis-classified pairs.
ltr = xgb.XGBRanker(
eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
eval_metric=roc_auc_score,
n_estimators=10,
tree_method="hist",
max_depth=2,
objective="rank:pairwise",
)
ltr.fit(
X,