Reimplement the NDCG metric. (#8906)

- Add support for non-exp gain.
- Cache the DMatrix object to avoid re-calculating the IDCG.
- Make GPU implementation deterministic. (no atomic add)
This commit is contained in:
Jiaming Yuan 2023-03-15 03:26:17 +08:00 committed by GitHub
parent 8685556af2
commit 72e8331eab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 363 additions and 160 deletions

View File

@ -161,6 +161,26 @@ class DMatrixCache {
} }
return container_.at(key).value; return container_.at(key).value;
} }
/**
* \brief Re-initialize the item in cache.
*
* Since the shared_ptr is used to hold the item, any reference that lives outside of
* the cache can no-longer be reached from the cache.
*
* We use reset instead of erase to avoid walking through the whole cache for renewing
* a single item. (the cache is FIFO, needs to maintain the order).
*/
template <typename... Args>
std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
std::lock_guard<std::mutex> guard{lock_};
CheckConsistent();
auto key = Key{m.get(), std::this_thread::get_id()};
auto it = container_.find(key);
CHECK(it != container_.cend());
it->second = {m, std::make_shared<CacheT>(args...)};
CheckConsistent();
return it->second.value;
}
/** /**
* \brief Get a const reference to the underlying hash map. Clear expired caches before * \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning. * returning.

View File

@ -20,23 +20,51 @@
// corresponding headers that brings in those function declaration can't be included with CUDA). // corresponding headers that brings in those function declaration can't be included with CUDA).
// This precludes the CPU and GPU logic to coexist inside a .cu file // This precludes the CPU and GPU logic to coexist inside a .cu file
#include "rank_metric.h"
#include <dmlc/omp.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <xgboost/metric.h>
#include <cmath> #include <algorithm> // for stable_sort, copy, fill_n, min, max
#include <vector> #include <array> // for array
#include <cmath> // for log, sqrt
#include <cstddef> // for size_t, std
#include <cstdint> // for uint32_t
#include <functional> // for less, greater
#include <map> // for operator!=, _Rb_tree_const_iterator
#include <memory> // for allocator, unique_ptr, shared_ptr, __shared_...
#include <numeric> // for accumulate
#include <ostream> // for operator<<, basic_ostream, ostringstream
#include <string> // for char_traits, operator<, basic_string, to_string
#include <utility> // for pair, make_pair
#include <vector> // for vector
#include "../collective/communicator-inl.h" #include "../collective/communicator-inl.h" // for IsDistributed, Allreduce
#include "../common/algorithm.h" // Sort #include "../collective/communicator.h" // for Operation
#include "../common/math.h" #include "../common/algorithm.h" // for ArgSort, Sort
#include "../common/ranking_utils.h" // MakeMetricName #include "../common/linalg_op.h" // for cbegin, cend
#include "../common/threading_utils.h" #include "../common/math.h" // for CmpFirst
#include "metric_common.h" #include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights
#include "xgboost/host_device_vector.h" #include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/transform_iterator.h" // for IndexTransformIter
#include "dmlc/common.h" // for OMPException
#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult
#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args
#include "xgboost/cache.h" // for DMatrixCache
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo, DMatrix
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, FromJson, IsA, ToJson, get, Null, Object
#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT...
#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ
#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric
#include "xgboost/span.h" // for Span, operator!=
#include "xgboost/string_view.h" // for StringView
namespace { namespace {
using PredIndPair = std::pair<xgboost::bst_float, uint32_t>; using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>;
using PredIndPairContainer = std::vector<PredIndPair>; using PredIndPairContainer = std::vector<PredIndPair>;
/* /*
@ -87,8 +115,7 @@ class PerGroupWeightPolicy {
} // anonymous namespace } // anonymous namespace
namespace xgboost { namespace xgboost::metric {
namespace metric {
// tag the this file, used by force static link later. // tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(rank_metric); DMLC_REGISTRY_FILE_TAG(rank_metric);
@ -257,40 +284,6 @@ struct EvalPrecision : public EvalRank {
} }
}; };
/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */
struct EvalNDCG : public EvalRank {
private:
double CalcDCG(const PredIndPairContainer &rec) const {
double sumdcg = 0.0;
for (size_t i = 0; i < rec.size() && i < this->topn; ++i) {
const unsigned rel = rec[i].second;
if (rel != 0) {
sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0);
}
}
return sumdcg;
}
public:
explicit EvalNDCG(const char* name, const char* param) : EvalRank(name, param) {}
double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
double dcg = CalcDCG(rec);
std::stable_sort(rec.begin(), rec.end(), common::CmpSecond);
double idcg = CalcDCG(rec);
if (idcg == 0.0f) {
if (this->minus) {
return 0.0f;
} else {
return 1.0f;
}
}
return dcg/idcg;
}
};
/*! \brief Mean Average Precision at N, for both classification and rank */ /*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAP : public EvalRank { struct EvalMAP : public EvalRank {
public: public:
@ -377,10 +370,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.") .describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); }); .set_body([](const char* param) { return new EvalPrecision("pre", param); });
XGBOOST_REGISTER_METRIC(NDCG, "ndcg")
.describe("ndcg@k for rank.")
.set_body([](const char* param) { return new EvalNDCG("ndcg", param); });
XGBOOST_REGISTER_METRIC(MAP, "map") XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.") .describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP("map", param); }); .set_body([](const char* param) { return new EvalMAP("map", param); });
@ -388,5 +377,148 @@ XGBOOST_REGISTER_METRIC(MAP, "map")
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.") .describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); }); .set_body([](const char*) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost // ranking metrics that requires cache
template <typename Cache>
class EvalRankWithCache : public Metric {
protected:
ltr::LambdaRankParam param_;
bool minus_{false};
std::string name_;
DMatrixCache<Cache> cache_{DMatrixCache<Cache>::DefaultSize()};
public:
EvalRankWithCache(StringView name, const char* param) {
auto constexpr kMax = ltr::LambdaRankParam::NotSet();
std::uint32_t topn{kMax};
this->name_ = ltr::ParseMetricName(name, param, &topn, &minus_);
if (topn != kMax) {
param_.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", std::to_string(topn)},
{"lambdarank_pair_method", "topk"}});
}
param_.UpdateAllowUnknown(Args{});
}
void Configure(Args const&) override {
// do not configure, otherwise the ndcg param will be forced into the same as the one in
// objective.
}
void LoadConfig(Json const& in) override {
if (IsA<Null>(in)) {
return;
}
auto const& obj = get<Object const>(in);
auto it = obj.find("lambdarank_param");
if (it != obj.cend()) {
FromJson(it->second, &param_);
}
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String{this->Name()};
out["lambdarank_param"] = ToJson(param_);
}
double Evaluate(HostDeviceVector<float> const& preds, std::shared_ptr<DMatrix> p_fmat) override {
auto const& info = p_fmat->Info();
auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_);
if (p_cache->Param() != param_) {
p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_);
}
CHECK(p_cache->Param() == param_);
CHECK_EQ(preds.Size(), info.labels.Size());
return this->Eval(preds, info, p_cache);
}
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<Cache> p_cache) = 0;
};
namespace {
double Finalize(double score, double sw) {
std::array<double, 2> dat{score, sw};
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
if (sw > 0.0) {
score = score / sw;
}
CHECK_LE(score, 1.0 + kRtEps)
<< "Invalid output score, might be caused by invalid query group weight.";
score = std::min(1.0, score);
return score;
}
} // namespace
/**
* \brief Implement the NDCG score function for learning to rank.
*
* Ties are ignored, which can lead to different result with other implementations.
*/
class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }
double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<ltr::NDCGCache> p_cache) override {
if (ctx_->IsCUDA()) {
auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache);
return Finalize(ndcg.Residue(), ndcg.Weights());
}
// group local ndcg
auto group_ptr = p_cache->DataGroupPtr(ctx_);
bst_group_t n_groups = group_ptr.size() - 1;
auto ndcg_gloc = p_cache->Dcg(ctx_);
std::fill_n(ndcg_gloc.Values().data(), ndcg_gloc.Size(), 0.0);
auto h_inv_idcg = p_cache->InvIDCG(ctx_);
auto p_discount = p_cache->Discount(ctx_).data();
auto h_label = info.labels.HostView();
auto h_predt = linalg::MakeTensorView(ctx_, &preds, preds.Size());
auto weights = common::MakeOptionalWeights(ctx_, info.weights_);
common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) {
auto g_predt = h_predt.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]));
auto g_labels = h_label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]), 0);
auto sorted_idx = common::ArgSort<std::size_t>(ctx_, linalg::cbegin(g_predt),
linalg::cend(g_predt), std::greater<>{});
double ndcg{.0};
double inv_idcg = h_inv_idcg(g);
if (inv_idcg <= 0.0) {
ndcg_gloc(g) = minus_ ? 0.0 : 1.0;
return;
}
std::size_t n{std::min(sorted_idx.size(), static_cast<std::size_t>(param_.TopK()))};
if (param_.ndcg_exp_gain) {
for (std::size_t i = 0; i < n; ++i) {
ndcg += p_discount[i] * ltr::CalcDCGGain(g_labels(sorted_idx[i])) * inv_idcg;
}
} else {
for (std::size_t i = 0; i < n; ++i) {
ndcg += p_discount[i] * g_labels(sorted_idx[i]) * inv_idcg;
}
}
ndcg_gloc(g) += ndcg * weights[g];
});
double sum_w{0};
if (weights.Empty()) {
sum_w = n_groups;
} else {
sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0);
}
auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0);
return Finalize(ndcg, sum_w);
}
};
XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg")
.describe("ndcg@k for ranking.")
.set_body([](char const* param) {
return new EvalNDCG{"ndcg", param};
});
} // namespace xgboost::metric

View File

@ -2,22 +2,29 @@
* Copyright 2020-2023 by XGBoost Contributors * Copyright 2020-2023 by XGBoost Contributors
*/ */
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <thrust/iterator/counting_iterator.h> // make_counting_iterator #include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/reduce.h> // reduce #include <thrust/reduce.h> // for reduce
#include <xgboost/metric.h>
#include <cstddef> // std::size_t #include <algorithm> // for transform
#include <memory> // std::shared_ptr #include <cstddef> // for size_t
#include <memory> // for shared_ptr
#include <vector> // for vector
#include "../common/cuda_context.cuh" // CUDAContext #include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/device_helpers.cuh" // for MakeTransformIterator
#include "../common/optional_weight.h" // for MakeOptionalWeights
#include "../common/ranking_utils.cuh" // for CalcQueriesDCG, NDCGCache
#include "metric_common.h" #include "metric_common.h"
#include "xgboost/base.h" // XGBOOST_DEVICE #include "rank_metric.h"
#include "xgboost/context.h" // Context #include "xgboost/base.h" // for XGBOOST_DEVICE
#include "xgboost/data.h" // MetaInfo #include "xgboost/context.h" // for Context
#include "xgboost/host_device_vector.h" // HostDeviceVector #include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for MakeTensorView
#include "xgboost/logging.h" // for CHECK
#include "xgboost/metric.h"
namespace xgboost { namespace xgboost::metric {
namespace metric {
// tag the this file, used by force static link later. // tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(rank_metric_gpu); DMLC_REGISTRY_FILE_TAG(rank_metric_gpu);
@ -117,81 +124,6 @@ struct EvalPrecisionGpu {
} }
}; };
/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */
struct EvalNDCGGpu {
public:
static void ComputeDCG(const dh::SegmentSorter<float> &pred_sorter,
const float *dlabels,
const EvalRankConfig &ecfg,
// The order in which labels have to be accessed. The order is determined
// by sorting the predictions or the labels for the entire dataset
const xgboost::common::Span<const uint32_t> &dlabels_sort_order,
dh::caching_device_vector<double> *dcgptr) {
dh::caching_device_vector<double> &dcgs(*dcgptr);
// Group info on device
const auto &dgroups = pred_sorter.GetGroupsSpan();
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
// First, determine non zero labels in the dataset individually
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
return (static_cast<unsigned>(dlabels[dlabels_sort_order[idx]]));
}; // NOLINT
// Find each group's DCG value
const auto nitems = pred_sorter.GetNumItems();
auto *ddcgs = dcgs.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each group item compute the aggregated precision
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
const auto group_idx = dgroup_idx[idx];
const auto group_begin = dgroups[group_idx];
const auto ridx = idx - group_begin;
auto label = DetermineNonTrivialLabelLambda(idx);
if (ridx < ecfg.topn && label) {
atomicAdd(&ddcgs[group_idx], ((1 << label) - 1) / std::log2(ridx + 2.0));
}
});
}
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
const float *dlabels,
const EvalRankConfig &ecfg) {
// Sort the labels and compute IDCG
dh::SegmentSorter<float> segment_label_sorter;
segment_label_sorter.SortItems(dlabels, pred_sorter.GetNumItems(),
pred_sorter.GetGroupSegmentsSpan());
uint32_t ngroups = pred_sorter.GetNumGroups();
dh::caching_device_vector<double> idcg(ngroups, 0);
ComputeDCG(pred_sorter, dlabels, ecfg, segment_label_sorter.GetOriginalPositionsSpan(), &idcg);
// Compute the DCG values next
dh::caching_device_vector<double> dcg(ngroups, 0);
ComputeDCG(pred_sorter, dlabels, ecfg, pred_sorter.GetOriginalPositionsSpan(), &dcg);
double *ddcg = dcg.data().get();
double *didcg = idcg.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// Compute the group's DCG and reduce it across all groups
dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) {
if (didcg[gidx] == 0.0f) {
ddcg[gidx] = (ecfg.minus) ? 0.0f : 1.0f;
} else {
ddcg[gidx] /= didcg[gidx];
}
});
// Allocator to be used for managing space overhead while performing reductions
dh::XGBCachingDeviceAllocator<char> alloc;
return thrust::reduce(thrust::cuda::par(alloc), dcg.begin(), dcg.end());
}
};
/*! \brief Mean Average Precision at N, for both classification and rank */ /*! \brief Mean Average Precision at N, for both classification and rank */
struct EvalMAPGpu { struct EvalMAPGpu {
@ -272,12 +204,46 @@ XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
.describe("precision@k for rank computed on GPU.") .describe("precision@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); }); .set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
XGBOOST_REGISTER_GPU_METRIC(NDCGGpu, "ndcg")
.describe("ndcg@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalNDCGGpu>("ndcg", param); });
XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map") XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map")
.describe("map@k for rank computed on GPU.") .describe("map@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); }); .set_body([](const char* param) { return new EvalRankGpu<EvalMAPGpu>("map", param); });
} // namespace metric
} // namespace xgboost namespace cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache) {
CHECK(p_cache);
auto const &p = p_cache->Param();
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
if (!d_weight.Empty()) {
CHECK_EQ(d_weight.weights.size(), p_cache->Groups());
}
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = info.group_ptr_.size() - 1;
auto d_inv_idcg = p_cache->InvIDCG(ctx);
auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values());
auto d_out_dcg = p_cache->Dcg(ctx);
ltr::cuda_impl::CalcQueriesDCG(ctx, d_label, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(),
d_out_dcg);
auto it = dh::MakeTransformIterator<PackedReduceResult>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
if (d_inv_idcg(i) <= 0.0) {
return PackedReduceResult{minus ? 0.0 : 1.0, static_cast<double>(d_weight[i])};
}
return PackedReduceResult{d_out_dcg(i) * d_inv_idcg(i) * d_weight[i],
static_cast<double>(d_weight[i])};
});
auto pair = thrust::reduce(ctx->CUDACtx()->CTP(), it, it + d_out_dcg.Size(),
PackedReduceResult{0.0, 0.0});
return pair;
}
} // namespace cuda_impl
} // namespace xgboost::metric

33
src/metric/rank_metric.h Normal file
View File

@ -0,0 +1,33 @@
#ifndef XGBOOST_METRIC_RANK_METRIC_H_
#define XGBOOST_METRIC_RANK_METRIC_H_
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <memory> // for shared_ptr
#include "../common/common.h" // for AssertGPUSupport
#include "../common/ranking_utils.h" // for NDCGCache
#include "metric_common.h" // for PackedReduceResult
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
namespace xgboost {
namespace metric {
namespace cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache);
#if !defined(XGBOOST_USE_CUDA)
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
std::shared_ptr<ltr::NDCGCache>) {
common::AssertGPUSupport();
return {};
}
#endif
} // namespace cuda_impl
} // namespace metric
} // namespace xgboost
#endif // XGBOOST_METRIC_RANK_METRIC_H_

View File

@ -1,7 +1,20 @@
// Copyright by Contributors /**
#include <xgboost/metric.h> * Copyright 2016-2023 by XGBoost Contributors
*/
#include <gtest/gtest.h> // for Test, EXPECT_NEAR, ASSERT_STREQ
#include <xgboost/context.h> // for Context
#include <xgboost/data.h> // for MetaInfo, DMatrix
#include <xgboost/linalg.h> // for Matrix
#include <xgboost/metric.h> // for Metric
#include "../helpers.h" #include <algorithm> // for max
#include <memory> // for unique_ptr
#include <vector> // for vector
#include "../helpers.h" // for GetMetricEval, CreateEmptyGe...
#include "xgboost/base.h" // for bst_float, kRtEps
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/json.h" // for Json, String, Object
#if !defined(__CUDACC__) #if !defined(__CUDACC__)
TEST(Metric, AMS) { TEST(Metric, AMS) {
@ -51,15 +64,17 @@ TEST(Metric, DeclareUnifiedTest(Precision)) {
delete metric; delete metric;
} }
namespace xgboost {
namespace metric {
TEST(Metric, DeclareUnifiedTest(NDCG)) { TEST(Metric, DeclareUnifiedTest(NDCG)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); auto ctx = CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("ndcg", &ctx); Metric * metric = xgboost::Metric::Create("ndcg", &ctx);
ASSERT_STREQ(metric->Name(), "ndcg"); ASSERT_STREQ(metric->Name(), "ndcg");
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
EXPECT_NEAR(GetMetricEval(metric, ASSERT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{}, xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 1, 1e-10); {}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f}, {0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}), { 0, 0, 1, 1}),
@ -80,7 +95,7 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) {
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{}, xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10); {}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f}, {0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}), { 0, 0, 1, 1}),
@ -91,7 +106,7 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) {
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
xgboost::HostDeviceVector<xgboost::bst_float>{}, xgboost::HostDeviceVector<xgboost::bst_float>{},
{}), 0, 1e-10); {}), 0, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f}, {0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}), { 0, 0, 1, 1}),
@ -100,20 +115,21 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) {
delete metric; delete metric;
metric = xgboost::Metric::Create("ndcg@2-", &ctx); metric = xgboost::Metric::Create("ndcg@2-", &ctx);
ASSERT_STREQ(metric->Name(), "ndcg@2-"); ASSERT_STREQ(metric->Name(), "ndcg@2-");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f}, {0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}), { 0, 0, 1, 1}),
0.3868f, 0.001f); 1.f - 0.3868f, 1.f - 0.001f);
delete metric; delete metric;
} }
TEST(Metric, DeclareUnifiedTest(MAP)) { TEST(Metric, DeclareUnifiedTest(MAP)) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("map", &ctx); Metric * metric = xgboost::Metric::Create("map", &ctx);
ASSERT_STREQ(metric->Name(), "map"); ASSERT_STREQ(metric->Name(), "map");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, kRtEps);
EXPECT_NEAR(GetMetricEval(metric, EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f}, {0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}), { 0, 0, 1, 1}),
@ -154,3 +170,39 @@ TEST(Metric, DeclareUnifiedTest(MAP)) {
0.25f, 0.001f); 0.25f, 0.001f);
delete metric; delete metric;
} }
TEST(Metric, DeclareUnifiedTest(NDCGExpGain)) {
Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
auto p_fmat = xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix();
MetaInfo& info = p_fmat->Info();
info.labels = linalg::Matrix<float>{{10.0f, 0.0f, 0.0f, 1.0f, 5.0f}, {5}, ctx.gpu_id};
info.num_row_ = info.labels.Shape(0);
info.group_ptr_.resize(2);
info.group_ptr_[0] = 0;
info.group_ptr_[1] = info.num_row_;
HostDeviceVector<float> predt{{0.1f, 0.2f, 0.3f, 4.0f, 70.0f}};
std::unique_ptr<Metric> metric{Metric::Create("ndcg", &ctx)};
Json config{Object{}};
config["name"] = String{"ndcg"};
config["lambdarank_param"] = Object{};
config["lambdarank_param"]["ndcg_exp_gain"] = String{"true"};
config["lambdarank_param"]["lambdarank_num_pair_per_sample"] = String{"32"};
metric->LoadConfig(config);
auto ndcg = metric->Evaluate(predt, p_fmat);
ASSERT_NEAR(ndcg, 0.409738f, kRtEps);
config["lambdarank_param"]["ndcg_exp_gain"] = String{"false"};
metric->LoadConfig(config);
ndcg = metric->Evaluate(predt, p_fmat);
ASSERT_NEAR(ndcg, 0.695694f, kRtEps);
predt.HostVector() = info.labels.Data()->HostVector();
ndcg = metric->Evaluate(predt, p_fmat);
ASSERT_NEAR(ndcg, 1.0, kRtEps);
}
} // namespace metric
} // namespace xgboost