Rework the MAP metric. (#8931)
- The new implementation is more strict as only binary labels are accepted. The previous implementation converts values greater than 1 to 1. - Deterministic GPU. (no atomic add). - Fix top-k handling. - Precise definition of MAP. (There are other variants on how to handle top-k). - Refactor GPU ranking tests.
This commit is contained in:
parent
b240f055d3
commit
5891f752c8
@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv
|
||||
|
||||
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
|
||||
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
|
||||
- ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
||||
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
|
||||
|
||||
The `average precision` is defined as:
|
||||
|
||||
.. math::
|
||||
|
||||
AP@l = \frac{1}{min{(l, N)}}\sum^l_{k=1}P@k \cdot I_{(k)}
|
||||
|
||||
where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.
|
||||
|
||||
- ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
|
||||
- ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
|
||||
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
|
||||
- ``gamma-nloglik``: negative log-likelihood for gamma regression
|
||||
- ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression
|
||||
|
||||
@ -14,6 +14,7 @@ import zipfile
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from platform import system
|
||||
from typing import (
|
||||
Any,
|
||||
@ -443,7 +444,7 @@ def get_mq2008(
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
|
||||
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
|
||||
target = dpath + "/MQ2008.zip"
|
||||
target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
|
||||
if not os.path.exists(target):
|
||||
request.urlretrieve(url=src, filename=target)
|
||||
|
||||
@ -462,9 +463,9 @@ def get_mq2008(
|
||||
qid_valid,
|
||||
) = load_svmlight_files(
|
||||
(
|
||||
dpath + "MQ2008/Fold1/train.txt",
|
||||
dpath + "MQ2008/Fold1/test.txt",
|
||||
dpath + "MQ2008/Fold1/vali.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
|
||||
Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
|
||||
),
|
||||
query_id=True,
|
||||
zero_based=False,
|
||||
|
||||
@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
|
||||
def neg_mse(*args: Any, **kwargs: Any) -> float:
|
||||
return -float(mean_squared_error(*args, **kwargs))
|
||||
|
||||
ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
|
||||
ranker = xgb.XGBRanker(
|
||||
n_estimators=3,
|
||||
eval_metric=neg_mse,
|
||||
tree_method=tree_method,
|
||||
disable_default_eval_metric=True,
|
||||
)
|
||||
ranker.fit(df, y, eval_set=[(valid_df, y)])
|
||||
score = ranker.score(valid_df, y)
|
||||
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])
|
||||
|
||||
@ -22,7 +22,7 @@ constexpr StringView LabelScoreSize() {
|
||||
}
|
||||
|
||||
constexpr StringView InfInData() {
|
||||
return "Input data contains `inf` while `missing` is not set to `inf`";
|
||||
return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`";
|
||||
}
|
||||
} // namespace xgboost::error
|
||||
#endif // XGBOOST_COMMON_ERROR_MSG_H_
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
/*!
|
||||
* Copyright 2022, XGBoost contributors.
|
||||
/**
|
||||
* Copyright 2022-2023 by XGBoost contributors.
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_NUMERIC_H_
|
||||
#define XGBOOST_COMMON_NUMERIC_H_
|
||||
|
||||
#include <dmlc/common.h> // OMPException
|
||||
|
||||
#include <algorithm> // std::max
|
||||
#include <iterator> // std::iterator_traits
|
||||
#include <algorithm> // for std::max
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <iterator> // for iterator_traits
|
||||
#include <vector>
|
||||
|
||||
#include "common.h" // AssertGPUSupport
|
||||
@ -15,8 +17,7 @@
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
namespace xgboost::common {
|
||||
|
||||
/**
|
||||
* \brief Run length encode on CPU, input must be sorted.
|
||||
@ -111,11 +112,11 @@ inline double Reduce(Context const*, HostDeviceVector<float> const&) {
|
||||
namespace cpu_impl {
|
||||
template <typename It, typename V = typename It::value_type>
|
||||
V Reduce(Context const* ctx, It first, It second, V const& init) {
|
||||
size_t n = std::distance(first, second);
|
||||
common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(ctx->Threads(), init);
|
||||
common::ParallelFor(n, ctx->Threads(),
|
||||
[&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
|
||||
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init);
|
||||
std::size_t n = std::distance(first, second);
|
||||
auto n_threads = static_cast<std::size_t>(std::min(n, static_cast<std::size_t>(ctx->Threads())));
|
||||
common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(n_threads, init);
|
||||
common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
|
||||
auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init);
|
||||
return result;
|
||||
}
|
||||
} // namespace cpu_impl
|
||||
@ -144,7 +145,6 @@ void Iota(Context const* ctx, It first, It last,
|
||||
});
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
} // namespace xgboost::common
|
||||
|
||||
#endif // XGBOOST_COMMON_NUMERIC_H_
|
||||
|
||||
@ -114,6 +114,15 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS
|
||||
|
||||
DMLC_REGISTER_PARAMETER(LambdaRankParam);
|
||||
|
||||
void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) {
|
||||
auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
|
||||
CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); });
|
||||
}
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
|
||||
#endif // !defined(XGBOOST_USE_CUDA)
|
||||
|
||||
std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
|
||||
std::string out_name;
|
||||
if (!param.empty()) {
|
||||
|
||||
@ -204,4 +204,9 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
|
||||
}
|
||||
|
||||
void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
|
||||
auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||
CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
|
||||
}
|
||||
} // namespace xgboost::ltr
|
||||
|
||||
@ -358,6 +358,71 @@ void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView<float con
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AllOf>
|
||||
bool IsBinaryRel(linalg::VectorView<float const> label, AllOf all_of) {
|
||||
auto s_label = label.Values();
|
||||
return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) {
|
||||
return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps;
|
||||
});
|
||||
}
|
||||
/**
|
||||
* \brief Validate label for MAP
|
||||
*
|
||||
* \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for
|
||||
* both CPU and GPU.
|
||||
*/
|
||||
template <typename AllOf>
|
||||
void CheckMapLabels(linalg::VectorView<float const> label, AllOf all_of) {
|
||||
auto s_label = label.Values();
|
||||
auto is_binary = IsBinaryRel(label, all_of);
|
||||
CHECK(is_binary) << "MAP can only be used with binary labels.";
|
||||
}
|
||||
|
||||
class MAPCache : public RankingCache {
|
||||
// Total number of relevant documents for each group
|
||||
HostDeviceVector<double> n_rel_;
|
||||
// \sum l_k/k
|
||||
HostDeviceVector<double> acc_;
|
||||
HostDeviceVector<double> map_;
|
||||
// Number of samples in this dataset.
|
||||
std::size_t n_samples_{0};
|
||||
|
||||
void InitOnCPU(Context const* ctx, MetaInfo const& info);
|
||||
void InitOnCUDA(Context const* ctx, MetaInfo const& info);
|
||||
|
||||
public:
|
||||
MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
|
||||
: RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
|
||||
if (ctx->IsCPU()) {
|
||||
this->InitOnCPU(ctx, info);
|
||||
} else {
|
||||
this->InitOnCUDA(ctx, info);
|
||||
}
|
||||
}
|
||||
|
||||
common::Span<double> NumRelevant(Context const* ctx) {
|
||||
if (n_rel_.Empty()) {
|
||||
n_rel_.SetDevice(ctx->gpu_id);
|
||||
n_rel_.Resize(n_samples_);
|
||||
}
|
||||
return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
|
||||
}
|
||||
common::Span<double> Acc(Context const* ctx) {
|
||||
if (acc_.Empty()) {
|
||||
acc_.SetDevice(ctx->gpu_id);
|
||||
acc_.Resize(n_samples_);
|
||||
}
|
||||
return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
|
||||
}
|
||||
common::Span<double> Map(Context const* ctx) {
|
||||
if (map_.Empty()) {
|
||||
map_.SetDevice(ctx->gpu_id);
|
||||
map_.Resize(this->Groups());
|
||||
}
|
||||
return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* \brief Parse name for ranking metric given parameters.
|
||||
*
|
||||
|
||||
@ -8,9 +8,11 @@
|
||||
#include <dmlc/omp.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint> // std::int32_t
|
||||
#include <cstdint> // for int32_t
|
||||
#include <cstdlib> // for malloc, free
|
||||
#include <limits>
|
||||
#include <type_traits> // std::is_signed
|
||||
#include <new> // for bad_alloc
|
||||
#include <type_traits> // for is_signed
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/logging.h"
|
||||
@ -266,7 +268,7 @@ class MemStackAllocator {
|
||||
if (MaxStackSize >= required_size_) {
|
||||
ptr_ = stack_mem_;
|
||||
} else {
|
||||
ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
|
||||
ptr_ = reinterpret_cast<T*>(std::malloc(required_size_ * sizeof(T)));
|
||||
}
|
||||
if (!ptr_) {
|
||||
throw std::bad_alloc{};
|
||||
@ -278,7 +280,7 @@ class MemStackAllocator {
|
||||
|
||||
~MemStackAllocator() {
|
||||
if (required_size_ > MaxStackSize) {
|
||||
free(ptr_);
|
||||
std::free(ptr_);
|
||||
}
|
||||
}
|
||||
T& operator[](size_t i) { return ptr_[i]; }
|
||||
|
||||
@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank {
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Mean Average Precision at N, for both classification and rank */
|
||||
struct EvalMAP : public EvalRank {
|
||||
public:
|
||||
explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {}
|
||||
|
||||
double EvalGroup(PredIndPairContainer *recptr) const override {
|
||||
PredIndPairContainer &rec(*recptr);
|
||||
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
|
||||
unsigned nhits = 0;
|
||||
double sumap = 0.0;
|
||||
for (size_t i = 0; i < rec.size(); ++i) {
|
||||
if (rec[i].second != 0) {
|
||||
nhits += 1;
|
||||
if (i < this->topn) {
|
||||
sumap += static_cast<double>(nhits) / (i + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (nhits != 0) {
|
||||
sumap /= nhits;
|
||||
return sumap;
|
||||
} else {
|
||||
if (this->minus) {
|
||||
return 0.0;
|
||||
} else {
|
||||
return 1.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
|
||||
struct EvalCox : public MetricNoCache {
|
||||
public:
|
||||
@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
|
||||
.describe("precision@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(MAP, "map")
|
||||
.describe("map@k for rank.")
|
||||
.set_body([](const char* param) { return new EvalMAP("map", param); });
|
||||
|
||||
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
|
||||
.describe("Negative log partial likelihood of Cox proportional hazards model.")
|
||||
.set_body([](const char*) { return new EvalCox(); });
|
||||
@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
|
||||
}
|
||||
};
|
||||
|
||||
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
|
||||
public:
|
||||
using EvalRankWithCache::EvalRankWithCache;
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
|
||||
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache) override {
|
||||
if (ctx_->IsCUDA()) {
|
||||
auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache);
|
||||
return Finalize(map.Residue(), map.Weights());
|
||||
}
|
||||
|
||||
auto gptr = p_cache->DataGroupPtr(ctx_);
|
||||
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
|
||||
auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
|
||||
|
||||
auto map_gloc = p_cache->Map(ctx_);
|
||||
std::fill_n(map_gloc.data(), map_gloc.size(), 0.0);
|
||||
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
|
||||
|
||||
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
|
||||
auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1]));
|
||||
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
|
||||
auto g_rank = rank_idx.subspan(gptr[g]);
|
||||
|
||||
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
|
||||
double n_hits{0.0};
|
||||
for (std::size_t i = 0; i < n; ++i) {
|
||||
auto p = g_label(g_rank[i]);
|
||||
n_hits += p;
|
||||
map_gloc[g] += n_hits / static_cast<double>((i + 1)) * p;
|
||||
}
|
||||
for (std::size_t i = n; i < g_label.Size(); ++i) {
|
||||
n_hits += g_label(g_rank[i]);
|
||||
}
|
||||
if (n_hits > 0.0) {
|
||||
map_gloc[g] /= std::min(n_hits, static_cast<double>(param_.TopK()));
|
||||
} else {
|
||||
map_gloc[g] = minus_ ? 0.0 : 1.0;
|
||||
}
|
||||
});
|
||||
|
||||
auto sw = 0.0;
|
||||
auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
|
||||
if (!weight.Empty()) {
|
||||
CHECK_EQ(weight.weights.size(), p_cache->Groups());
|
||||
}
|
||||
for (std::size_t i = 0; i < map_gloc.size(); ++i) {
|
||||
map_gloc[i] = map_gloc[i] * weight[i];
|
||||
sw += weight[i];
|
||||
}
|
||||
auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0);
|
||||
return Finalize(sum, sw);
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
|
||||
.describe("map@k for ranking.")
|
||||
.set_body([](char const* param) {
|
||||
return new EvalMAPScore{"map", param};
|
||||
});
|
||||
|
||||
XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
|
||||
.describe("ndcg@k for ranking.")
|
||||
.set_body([](char const* param) {
|
||||
|
||||
@ -125,89 +125,10 @@ struct EvalPrecisionGpu {
|
||||
};
|
||||
|
||||
|
||||
/*! \brief Mean Average Precision at N, for both classification and rank */
|
||||
struct EvalMAPGpu {
|
||||
public:
|
||||
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
|
||||
const float *dlabels,
|
||||
const EvalRankConfig &ecfg) {
|
||||
// Group info on device
|
||||
const auto &dgroups = pred_sorter.GetGroupsSpan();
|
||||
const auto ngroups = pred_sorter.GetNumGroups();
|
||||
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
|
||||
|
||||
// Original positions of the predictions after they have been sorted
|
||||
const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
|
||||
|
||||
// First, determine non zero labels in the dataset individually
|
||||
const auto nitems = pred_sorter.GetNumItems();
|
||||
dh::caching_device_vector<uint32_t> hits(nitems, 0);
|
||||
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
|
||||
return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
|
||||
}; // NOLINT
|
||||
|
||||
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
|
||||
thrust::make_counting_iterator(nitems),
|
||||
hits.begin(),
|
||||
DetermineNonTrivialLabelLambda);
|
||||
|
||||
// Allocator to be used by sort for managing space overhead while performing prefix scans
|
||||
dh::XGBCachingDeviceAllocator<char> alloc;
|
||||
|
||||
// Next, prefix scan the nontrivial labels that are segmented to accumulate them.
|
||||
// This is required for computing the metric sum
|
||||
// Data segmented into different groups...
|
||||
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
|
||||
dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
|
||||
hits.begin(), // Input value
|
||||
hits.begin()); // In-place scan
|
||||
|
||||
// Find each group's metric sum
|
||||
dh::caching_device_vector<double> sumap(ngroups, 0);
|
||||
auto *dsumap = sumap.data().get();
|
||||
const auto *dhits = hits.data().get();
|
||||
|
||||
int device_id = -1;
|
||||
dh::safe_cuda(cudaGetDevice(&device_id));
|
||||
// For each group item compute the aggregated precision
|
||||
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
|
||||
if (DetermineNonTrivialLabelLambda(idx)) {
|
||||
const auto group_idx = dgroup_idx[idx];
|
||||
const auto group_begin = dgroups[group_idx];
|
||||
const auto ridx = idx - group_begin;
|
||||
if (ridx < ecfg.topn) {
|
||||
atomicAdd(&dsumap[group_idx],
|
||||
static_cast<double>(dhits[idx]) / (ridx + 1));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Aggregate the group's item precisions
|
||||
dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) {
|
||||
auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0;
|
||||
if (nhits != 0) {
|
||||
dsumap[gidx] /= nhits;
|
||||
} else {
|
||||
if (ecfg.minus) {
|
||||
dsumap[gidx] = 0;
|
||||
} else {
|
||||
dsumap[gidx] = 1;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end());
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
|
||||
.describe("precision@k for rank computed on GPU.")
|
||||
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
|
||||
|
||||
XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map")
|
||||
.describe("map@k for rank computed on GPU.")
|
||||
.set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); });
|
||||
|
||||
namespace cuda_impl {
|
||||
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
@ -245,5 +166,87 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
||||
PackedReduceResult{0.0, 0.0});
|
||||
return pair;
|
||||
}
|
||||
|
||||
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache) {
|
||||
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
|
||||
auto n_groups = info.group_ptr_.size() - 1;
|
||||
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||
|
||||
predt.SetDevice(ctx->gpu_id);
|
||||
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
|
||||
auto key_it = dh::MakeTransformIterator<std::size_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); });
|
||||
|
||||
auto get_label = [=] XGBOOST_DEVICE(std::size_t i) {
|
||||
auto g = key_it[i];
|
||||
auto g_begin = d_group_ptr[g];
|
||||
auto g_end = d_group_ptr[g + 1];
|
||||
i -= g_begin;
|
||||
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
|
||||
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
|
||||
return g_label(g_rank[i]);
|
||||
};
|
||||
auto it = dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), get_label);
|
||||
|
||||
auto cuctx = ctx->CUDACtx();
|
||||
auto n_rel = p_cache->NumRelevant(ctx);
|
||||
thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + d_label.Size(), it, n_rel.data());
|
||||
|
||||
double topk = p_cache->Param().TopK();
|
||||
auto map = p_cache->Map(ctx);
|
||||
thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0);
|
||||
{
|
||||
auto val_it = dh::MakeTransformIterator<double>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
|
||||
auto g = key_it[i];
|
||||
auto g_begin = d_group_ptr[g];
|
||||
auto g_end = d_group_ptr[g + 1];
|
||||
i -= g_begin;
|
||||
if (i >= topk) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
|
||||
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
|
||||
auto label = g_label(g_rank[i]);
|
||||
|
||||
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
|
||||
auto nhits = g_n_rel[i];
|
||||
return nhits / static_cast<double>(i + 1) * label;
|
||||
});
|
||||
|
||||
std::size_t bytes;
|
||||
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(),
|
||||
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
|
||||
dh::TemporaryArray<char> temp(bytes);
|
||||
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(),
|
||||
d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream());
|
||||
}
|
||||
|
||||
PackedReduceResult result{0.0, 0.0};
|
||||
{
|
||||
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
|
||||
if (!d_weight.Empty()) {
|
||||
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
|
||||
}
|
||||
auto val_it = dh::MakeTransformIterator<PackedReduceResult>(
|
||||
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) {
|
||||
auto g_begin = d_group_ptr[g];
|
||||
auto g_end = d_group_ptr[g + 1];
|
||||
auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin);
|
||||
if (!g_n_rel.empty() && g_n_rel.back() > 0.0) {
|
||||
return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk),
|
||||
static_cast<double>(d_weight[g])};
|
||||
}
|
||||
return PackedReduceResult{minus ? 0.0 : 1.0, static_cast<double>(d_weight[g])};
|
||||
});
|
||||
result =
|
||||
thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0});
|
||||
}
|
||||
return result;
|
||||
}
|
||||
} // namespace cuda_impl
|
||||
} // namespace xgboost::metric
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
#include <memory> // for shared_ptr
|
||||
|
||||
#include "../common/common.h" // for AssertGPUSupport
|
||||
#include "../common/ranking_utils.h" // for NDCGCache
|
||||
#include "../common/ranking_utils.h" // for NDCGCache, MAPCache
|
||||
#include "metric_common.h" // for PackedReduceResult
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for MetaInfo
|
||||
@ -19,6 +19,10 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::NDCGCache> p_cache);
|
||||
|
||||
PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
|
||||
HostDeviceVector<float> const &predt, bool minus,
|
||||
std::shared_ptr<ltr::MAPCache> p_cache);
|
||||
|
||||
#if !defined(XGBOOST_USE_CUDA)
|
||||
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
|
||||
HostDeviceVector<float> const &, bool,
|
||||
@ -26,6 +30,13 @@ inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
|
||||
common::AssertGPUSupport();
|
||||
return {};
|
||||
}
|
||||
|
||||
inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
|
||||
HostDeviceVector<float> const &, bool,
|
||||
std::shared_ptr<ltr::MAPCache>) {
|
||||
common::AssertGPUSupport();
|
||||
return {};
|
||||
}
|
||||
#endif
|
||||
} // namespace cuda_impl
|
||||
} // namespace metric
|
||||
|
||||
@ -177,4 +177,36 @@ TEST(NDCGCache, InitFromCPU) {
|
||||
Context ctx;
|
||||
TestNDCGCache(&ctx);
|
||||
}
|
||||
|
||||
void TestMAPCache(Context const* ctx) {
|
||||
auto p_fmat = EmptyDMatrix();
|
||||
MetaInfo& info = p_fmat->Info();
|
||||
LambdaRankParam param;
|
||||
param.UpdateAllowUnknown(Args{});
|
||||
|
||||
std::vector<float> h_data(32);
|
||||
|
||||
common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f);
|
||||
info.labels.Reshape(h_data.size());
|
||||
info.num_row_ = h_data.size();
|
||||
info.labels.Data()->HostVector() = std::move(h_data);
|
||||
|
||||
auto fail = [&]() { std::make_shared<MAPCache>(ctx, info, param); };
|
||||
// binary label
|
||||
ASSERT_THROW(fail(), dmlc::Error);
|
||||
|
||||
h_data = std::vector<float>(32, 0.0f);
|
||||
h_data[1] = 1.0f;
|
||||
info.labels.Data()->HostVector() = h_data;
|
||||
auto p_cache = std::make_shared<MAPCache>(ctx, info, param);
|
||||
|
||||
ASSERT_EQ(p_cache->Acc(ctx).size(), info.num_row_);
|
||||
ASSERT_EQ(p_cache->NumRelevant(ctx).size(), info.num_row_);
|
||||
}
|
||||
|
||||
TEST(MAPCache, InitFromCPU) {
|
||||
Context ctx;
|
||||
ctx.Init(Args{});
|
||||
TestMAPCache(&ctx);
|
||||
}
|
||||
} // namespace xgboost::ltr
|
||||
|
||||
@ -95,4 +95,10 @@ TEST(NDCGCache, InitFromGPU) {
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
TestNDCGCache(&ctx);
|
||||
}
|
||||
|
||||
TEST(MAPCache, InitFromGPU) {
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
TestMAPCache(&ctx);
|
||||
}
|
||||
} // namespace xgboost::ltr
|
||||
|
||||
@ -6,4 +6,6 @@
|
||||
|
||||
namespace xgboost::ltr {
|
||||
void TestNDCGCache(Context const* ctx);
|
||||
|
||||
void TestMAPCache(Context const* ctx);
|
||||
} // namespace xgboost::ltr
|
||||
|
||||
@ -141,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) {
|
||||
// Rank metric with group info
|
||||
EXPECT_NEAR(GetMetricEval(metric,
|
||||
{0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f},
|
||||
{2, 7, 1, 0, 5, 0}, // Labels
|
||||
{1, 1, 1, 0, 1, 0}, // Labels
|
||||
{}, // Weights
|
||||
{0, 2, 5, 6}), // Group info
|
||||
0.8611f, 0.001f);
|
||||
|
||||
@ -1,194 +1,130 @@
|
||||
import itertools
|
||||
import os
|
||||
import shutil
|
||||
import urllib.request
|
||||
import zipfile
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import xgboost
|
||||
from xgboost import testing as tm
|
||||
|
||||
pytestmark = tm.timeout(10)
|
||||
pytestmark = tm.timeout(30)
|
||||
|
||||
|
||||
class TestRanking:
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
"""
|
||||
Download and setup the test fixtures
|
||||
"""
|
||||
from sklearn.datasets import load_svmlight_files
|
||||
def comp_training_with_rank_objective(
|
||||
dtrain: xgboost.DMatrix,
|
||||
dtest: xgboost.DMatrix,
|
||||
rank_objective: str,
|
||||
metric_name: str,
|
||||
tolerance: float = 1e-02,
|
||||
) -> None:
|
||||
"""Internal method that trains the dataset using the rank objective on GPU and CPU,
|
||||
evaluates the metric and determines if the delta between the metric is within the
|
||||
tolerance level.
|
||||
|
||||
# download the test data
|
||||
cls.dpath = os.path.join(tm.demo_dir(__file__), "rank/")
|
||||
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
|
||||
target = os.path.join(cls.dpath, "MQ2008.zip")
|
||||
"""
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(dtest, "eval"), (dtrain, "train")]
|
||||
|
||||
if os.path.exists(cls.dpath) and os.path.exists(target):
|
||||
print("Skipping dataset download...")
|
||||
else:
|
||||
urllib.request.urlretrieve(url=src, filename=target)
|
||||
with zipfile.ZipFile(target, 'r') as f:
|
||||
f.extractall(path=cls.dpath)
|
||||
params = {
|
||||
"booster": "gbtree",
|
||||
"tree_method": "gpu_hist",
|
||||
"gpu_id": 0,
|
||||
"predictor": "gpu_predictor",
|
||||
}
|
||||
|
||||
(x_train, y_train, qid_train, x_test, y_test, qid_test,
|
||||
x_valid, y_valid, qid_valid) = load_svmlight_files(
|
||||
(cls.dpath + "MQ2008/Fold1/train.txt",
|
||||
cls.dpath + "MQ2008/Fold1/test.txt",
|
||||
cls.dpath + "MQ2008/Fold1/vali.txt"),
|
||||
query_id=True, zero_based=False)
|
||||
# instantiate the matrices
|
||||
cls.dtrain = xgboost.DMatrix(x_train, y_train)
|
||||
cls.dvalid = xgboost.DMatrix(x_valid, y_valid)
|
||||
cls.dtest = xgboost.DMatrix(x_test, y_test)
|
||||
# set the group counts from the query IDs
|
||||
cls.dtrain.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_train)])
|
||||
cls.dtest.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_test)])
|
||||
cls.dvalid.set_group([len(list(items))
|
||||
for _key, items in itertools.groupby(qid_valid)])
|
||||
# save the query IDs for testing
|
||||
cls.qid_train = qid_train
|
||||
cls.qid_test = qid_test
|
||||
cls.qid_valid = qid_valid
|
||||
num_trees = 100
|
||||
check_metric_improvement_rounds = 10
|
||||
|
||||
def setup_weighted(x, y, groups):
|
||||
# Setup weighted data
|
||||
data = xgboost.DMatrix(x, y)
|
||||
groups_segment = [len(list(items))
|
||||
for _key, items in itertools.groupby(groups)]
|
||||
data.set_group(groups_segment)
|
||||
n_groups = len(groups_segment)
|
||||
weights = np.ones((n_groups,))
|
||||
data.set_weight(weights)
|
||||
return data
|
||||
evals_result: Dict[str, Dict] = {}
|
||||
params["objective"] = rank_objective
|
||||
params["eval_metric"] = metric_name
|
||||
bst = xgboost.train(
|
||||
params,
|
||||
dtrain,
|
||||
num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
gpu_scores = evals_result["train"][metric_name][-1]
|
||||
|
||||
cls.dtrain_w = setup_weighted(x_train, y_train, qid_train)
|
||||
cls.dtest_w = setup_weighted(x_test, y_test, qid_test)
|
||||
cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid)
|
||||
evals_result = {}
|
||||
|
||||
# model training parameters
|
||||
cls.params = {'booster': 'gbtree',
|
||||
'tree_method': 'gpu_hist',
|
||||
'gpu_id': 0,
|
||||
'predictor': 'gpu_predictor'}
|
||||
cls.cpu_params = {'booster': 'gbtree',
|
||||
'tree_method': 'hist',
|
||||
'gpu_id': -1,
|
||||
'predictor': 'cpu_predictor'}
|
||||
cpu_params = {
|
||||
"booster": "gbtree",
|
||||
"tree_method": "hist",
|
||||
"gpu_id": -1,
|
||||
"predictor": "cpu_predictor",
|
||||
}
|
||||
cpu_params["objective"] = rank_objective
|
||||
cpu_params["eval_metric"] = metric_name
|
||||
bstc = xgboost.train(
|
||||
cpu_params,
|
||||
dtrain,
|
||||
num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
cpu_scores = evals_result["train"][metric_name][-1]
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
"""
|
||||
Cleanup test artifacts from download and unpacking
|
||||
:return:
|
||||
"""
|
||||
os.remove(os.path.join(cls.dpath, "MQ2008.zip"))
|
||||
shutil.rmtree(os.path.join(cls.dpath, "MQ2008"))
|
||||
info = (rank_objective, metric_name)
|
||||
assert np.allclose(gpu_scores, cpu_scores, tolerance, tolerance), info
|
||||
assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance), info
|
||||
|
||||
@classmethod
|
||||
def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02):
|
||||
"""
|
||||
Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates
|
||||
the metric and determines if the delta between the metric is within the tolerance level
|
||||
:return:
|
||||
"""
|
||||
# specify validations set to watch performance
|
||||
watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')]
|
||||
evals_result_weighted: Dict[str, Dict] = {}
|
||||
dtest.set_weight(np.ones((dtest.get_group().size,)))
|
||||
dtrain.set_weight(np.ones((dtrain.get_group().size,)))
|
||||
watchlist = [(dtest, "eval"), (dtrain, "train")]
|
||||
bst_w = xgboost.train(
|
||||
params,
|
||||
dtrain,
|
||||
num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist,
|
||||
evals_result=evals_result_weighted,
|
||||
)
|
||||
weighted_metric = evals_result_weighted["train"][metric_name][-1]
|
||||
|
||||
num_trees = 100
|
||||
check_metric_improvement_rounds = 10
|
||||
tolerance = 1e-5
|
||||
assert np.allclose(bst_w.best_score, bst.best_score, tolerance, tolerance)
|
||||
assert np.allclose(weighted_metric, gpu_scores, tolerance, tolerance)
|
||||
|
||||
evals_result = {}
|
||||
cls.params['objective'] = rank_objective
|
||||
cls.params['eval_metric'] = metric_name
|
||||
bst = xgboost.train(
|
||||
cls.params, cls.dtrain, num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist, evals_result=evals_result)
|
||||
gpu_map_metric = evals_result['train'][metric_name][-1]
|
||||
|
||||
evals_result = {}
|
||||
cls.cpu_params['objective'] = rank_objective
|
||||
cls.cpu_params['eval_metric'] = metric_name
|
||||
bstc = xgboost.train(
|
||||
cls.cpu_params, cls.dtrain, num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist, evals_result=evals_result)
|
||||
cpu_map_metric = evals_result['train'][metric_name][-1]
|
||||
@pytest.mark.parametrize(
|
||||
"objective,metric",
|
||||
[
|
||||
("rank:pairwise", "auc"),
|
||||
("rank:pairwise", "ndcg"),
|
||||
("rank:pairwise", "map"),
|
||||
("rank:ndcg", "auc"),
|
||||
("rank:ndcg", "ndcg"),
|
||||
("rank:ndcg", "map"),
|
||||
("rank:map", "auc"),
|
||||
("rank:map", "ndcg"),
|
||||
("rank:map", "map"),
|
||||
],
|
||||
)
|
||||
def test_with_mq2008(objective, metric) -> None:
|
||||
(
|
||||
x_train,
|
||||
y_train,
|
||||
qid_train,
|
||||
x_test,
|
||||
y_test,
|
||||
qid_test,
|
||||
x_valid,
|
||||
y_valid,
|
||||
qid_valid,
|
||||
) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
|
||||
|
||||
assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance,
|
||||
tolerance)
|
||||
assert np.allclose(bst.best_score, bstc.best_score, tolerance,
|
||||
tolerance)
|
||||
if metric.find("map") != -1 or objective.find("map") != -1:
|
||||
y_train[y_train <= 1] = 0.0
|
||||
y_train[y_train > 1] = 1.0
|
||||
y_test[y_test <= 1] = 0.0
|
||||
y_test[y_test > 1] = 1.0
|
||||
|
||||
evals_result_weighted = {}
|
||||
watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')]
|
||||
bst_w = xgboost.train(
|
||||
cls.params, cls.dtrain_w, num_boost_round=num_trees,
|
||||
early_stopping_rounds=check_metric_improvement_rounds,
|
||||
evals=watchlist, evals_result=evals_result_weighted)
|
||||
weighted_metric = evals_result_weighted['train'][metric_name][-1]
|
||||
# GPU Ranking is not deterministic due to `AtomicAddGpair`,
|
||||
# remove tolerance once the issue is resolved.
|
||||
# https://github.com/dmlc/xgboost/issues/5561
|
||||
assert np.allclose(bst_w.best_score, bst.best_score,
|
||||
tolerance, tolerance)
|
||||
assert np.allclose(weighted_metric, gpu_map_metric,
|
||||
tolerance, tolerance)
|
||||
dtrain = xgboost.DMatrix(x_train, y_train, qid=qid_train)
|
||||
dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test)
|
||||
|
||||
def test_training_rank_pairwise_map_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare map metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'map')
|
||||
|
||||
def test_training_rank_pairwise_auc_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare auc metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'auc')
|
||||
|
||||
def test_training_rank_pairwise_ndcg_metric(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with pairwise objective function and compare ndcg metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:pairwise', 'ndcg')
|
||||
|
||||
def test_training_rank_ndcg_map(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare map metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'map')
|
||||
|
||||
def test_training_rank_ndcg_auc(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare auc metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'auc')
|
||||
|
||||
def test_training_rank_ndcg_ndcg(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with ndcg objective function and compare ndcg metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:ndcg', 'ndcg')
|
||||
|
||||
def test_training_rank_map_map(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with map objective function and compare map metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:map', 'map')
|
||||
|
||||
def test_training_rank_map_auc(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with map objective function and compare auc metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:map', 'auc')
|
||||
|
||||
def test_training_rank_map_ndcg(self):
|
||||
"""
|
||||
Train an XGBoost ranking model with map objective function and compare ndcg metric
|
||||
"""
|
||||
self.__test_training_with_rank_objective('rank:map', 'ndcg')
|
||||
comp_training_with_rank_objective(dtrain, dtest, objective, metric)
|
||||
|
||||
@ -128,12 +128,23 @@ def test_ranking():
|
||||
|
||||
x_test = np.random.rand(100, 10)
|
||||
|
||||
params = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1,
|
||||
'max_depth': 6, 'n_estimators': 4}
|
||||
params = {
|
||||
"tree_method": "exact",
|
||||
"learning_rate": 0.1,
|
||||
"gamma": 1.0,
|
||||
"min_child_weight": 0.1,
|
||||
"max_depth": 6,
|
||||
"eval_metric": "ndcg",
|
||||
"n_estimators": 4,
|
||||
}
|
||||
model = xgb.sklearn.XGBRanker(**params)
|
||||
model.fit(x_train, y_train, group=train_group,
|
||||
eval_set=[(x_valid, y_valid)], eval_group=[valid_group])
|
||||
model.fit(
|
||||
x_train,
|
||||
y_train,
|
||||
group=train_group,
|
||||
eval_set=[(x_valid, y_valid)],
|
||||
eval_group=[valid_group],
|
||||
)
|
||||
assert model.evals_result()
|
||||
|
||||
pred = model.predict(x_test)
|
||||
@ -145,11 +156,18 @@ def test_ranking():
|
||||
assert train_data.get_label().shape[0] == x_train.shape[0]
|
||||
valid_data.set_group(valid_group)
|
||||
|
||||
params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise',
|
||||
'eta': 0.1, 'gamma': 1.0,
|
||||
'min_child_weight': 0.1, 'max_depth': 6}
|
||||
xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4,
|
||||
evals=[(valid_data, 'validation')])
|
||||
params_orig = {
|
||||
"tree_method": "exact",
|
||||
"objective": "rank:pairwise",
|
||||
"eta": 0.1,
|
||||
"gamma": 1.0,
|
||||
"min_child_weight": 0.1,
|
||||
"max_depth": 6,
|
||||
"eval_metric": "ndcg",
|
||||
}
|
||||
xgb_model_orig = xgb.train(
|
||||
params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")]
|
||||
)
|
||||
pred_orig = xgb_model_orig.predict(test_data)
|
||||
|
||||
np.testing.assert_almost_equal(pred, pred_orig)
|
||||
@ -165,7 +183,11 @@ def test_ranking_metric() -> None:
|
||||
# sklearn compares the number of mis-classified docs, while the one in xgboost
|
||||
# compares the number of mis-classified pairs.
|
||||
ltr = xgb.XGBRanker(
|
||||
eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
|
||||
eval_metric=roc_auc_score,
|
||||
n_estimators=10,
|
||||
tree_method="hist",
|
||||
max_depth=2,
|
||||
objective="rank:pairwise",
|
||||
)
|
||||
ltr.fit(
|
||||
X,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user