Rework the precision metric. (#9222)

- Rework the precision metric for both CPU and GPU.
- Mention it in the document.
- Cleanup old support code for GPU ranking metric.
- Deterministic GPU implementation.

* Drop support for classification.

* type.

* use batch shape.

* lint.

* cpu build.

* cpu build.

* lint.

* Tests.

* Fix.

* Cleanup error message.
This commit is contained in:
Jiaming Yuan
2023-06-02 20:49:43 +08:00
committed by GitHub
parent db8288121d
commit 9fbde21e9d
22 changed files with 312 additions and 502 deletions

View File

@@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) {
metric->ctx_ = ctx;
return metric;
}
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
auto metric = CreateMetricImpl<MetricGPUReg>(name);
if (metric == nullptr) {
LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
<< ". Resorting to the CPU builder";
return nullptr;
}
// Narrowing reference only for the compiler to allow assignment to a base class member.
// As such, using this narrowed reference to refer to derived members will be an illegal op.
// This is moot, as this type is stateless.
auto casted = static_cast<GPUMetric*>(metric);
CHECK(casted);
casted->ctx_ = ctx;
return casted;
}
} // namespace xgboost
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg);
}
namespace xgboost {
namespace metric {
namespace xgboost::metric {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(auc);
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
@@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric);
DMLC_REGISTRY_LINK_TAG(auc_gpu);
DMLC_REGISTRY_LINK_TAG(rank_metric_gpu);
#endif
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric

View File

@@ -23,53 +23,14 @@ class MetricNoCache : public Metric {
double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
double result{0.0};
auto const& info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
result = this->Eval(predts, info);
});
auto const &info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double),
[&] { result = this->Eval(predts, info); });
return result;
}
};
// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
// present already. This is created when there is a device ordinal present and if xgboost
// is compiled with CUDA support
struct GPUMetric : public MetricNoCache {
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
};
/*!
* \brief Internal registry entries for GPU Metric factory functions.
* The additional parameter const char* param gives the value after @, can be null.
* For example, metric map@3, then: param == "3".
*/
struct MetricGPUReg
: public dmlc::FunctionRegEntryBase<MetricGPUReg,
std::function<Metric * (const char*)> > {
};
/*!
* \brief Macro to register metric computed on GPU.
*
* \code
* // example of registering a objective ndcg@k
* XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg")
* .describe("NDCG metric computer on GPU.")
* .set_body([](const char* param) {
* int at_k = atoi(param);
* return new NDCG(at_k);
* });
* \endcode
*/
// Note: Metric names registered in the GPU registry should follow this convention:
// - GPU metric types should be registered with the same name as the non GPU metric types
#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \
::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name)
namespace metric {
// Ranking config to be used on device and host
struct EvalRankConfig {
public:
@@ -81,8 +42,8 @@ struct EvalRankConfig {
};
class PackedReduceResult {
double residue_sum_ { 0 };
double weights_sum_ { 0 };
double residue_sum_{0};
double weights_sum_{0};
public:
XGBOOST_DEVICE PackedReduceResult() {} // NOLINT
@@ -91,16 +52,15 @@ class PackedReduceResult {
XGBOOST_DEVICE
PackedReduceResult operator+(PackedReduceResult const &other) const {
return PackedReduceResult{residue_sum_ + other.residue_sum_,
weights_sum_ + other.weights_sum_};
return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_};
}
PackedReduceResult &operator+=(PackedReduceResult const &other) {
this->residue_sum_ += other.residue_sum_;
this->weights_sum_ += other.weights_sum_;
return *this;
}
double Residue() const { return residue_sum_; }
double Weights() const { return weights_sum_; }
[[nodiscard]] double Residue() const { return residue_sum_; }
[[nodiscard]] double Weights() const { return weights_sum_; }
};
} // namespace metric

View File

@@ -1,25 +1,6 @@
/**
* Copyright 2020-2023 by XGBoost contributors
*/
// When device ordinal is present, we would want to build the metrics on the GPU. It is *not*
// possible for a valid device ordinal to be present for non GPU builds. However, it is possible
// for an invalid device ordinal to be specified in GPU builds - to train/predict and/or compute
// the metrics on CPU. To accommodate these scenarios, the following is done for the metrics
// accelerated on the GPU.
// - An internal GPU registry holds all the GPU metric types (defined in the .cu file)
// - An instance of the appropriate GPU metric type is created when a device ordinal is present
// - If the creation is successful, the metric computation is done on the device
// - else, it falls back on the CPU
// - The GPU metric types are *only* registered when xgboost is built for GPUs
//
// This is done for 2 reasons:
// - Clear separation of CPU and GPU logic
// - Sorting datasets containing large number of rows is (much) faster when parallel sort
// semantics is used on the CPU. The __gnu_parallel/concurrency primitives needed to perform
// this cannot be used when the translation unit is compiled using the 'nvcc' compiler (as the
// corresponding headers that brings in those function declaration can't be included with CUDA).
// This precludes the CPU and GPU logic to coexist inside a .cu file
#include "rank_metric.h"
#include <dmlc/omp.h>
@@ -57,55 +38,8 @@
#include "xgboost/string_view.h" // for StringView
namespace {
using PredIndPair = std::pair<xgboost::bst_float, xgboost::ltr::rel_degree_t>;
using PredIndPairContainer = std::vector<PredIndPair>;
/*
* Adapter to access instance weights.
*
* - For ranking task, weights are per-group
* - For binary classification task, weights are per-instance
*
* WeightPolicy::GetWeightOfInstance() :
* get weight associated with an individual instance, using index into
* `info.weights`
* WeightPolicy::GetWeightOfSortedRecord() :
* get weight associated with an individual instance, using index into
* sorted records `rec` (in ascending order of predicted labels). `rec` is
* of type PredIndPairContainer
*/
class PerInstanceWeightPolicy {
public:
inline static xgboost::bst_float
GetWeightOfInstance(const xgboost::MetaInfo& info,
unsigned instance_id, unsigned) {
return info.GetWeight(instance_id);
}
inline static xgboost::bst_float
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
const PredIndPairContainer& rec,
unsigned record_id, unsigned) {
return info.GetWeight(rec[record_id].second);
}
};
class PerGroupWeightPolicy {
public:
inline static xgboost::bst_float
GetWeightOfInstance(const xgboost::MetaInfo& info,
unsigned, unsigned group_id) {
return info.GetWeight(group_id);
}
inline static xgboost::bst_float
GetWeightOfSortedRecord(const xgboost::MetaInfo& info,
const PredIndPairContainer&,
unsigned, unsigned group_id) {
return info.GetWeight(group_id);
}
};
} // anonymous namespace
namespace xgboost::metric {
@@ -177,10 +111,6 @@ struct EvalAMS : public MetricNoCache {
/*! \brief Evaluate rank list */
struct EvalRank : public MetricNoCache, public EvalRankConfig {
private:
// This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU.
std::unique_ptr<MetricNoCache> rank_gpu_;
public:
double Eval(const HostDeviceVector<bst_float>& preds, const MetaInfo& info) override {
CHECK_EQ(preds.Size(), info.labels.Size())
@@ -199,20 +129,10 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
// sum statistics
double sum_metric = 0.0f;
// Check and see if we have the GPU metric registered in the internal registry
if (ctx_->gpu_id >= 0) {
if (!rank_gpu_) {
rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_));
}
if (rank_gpu_) {
sum_metric = rank_gpu_->Eval(preds, info);
}
}
CHECK(ctx_);
std::vector<double> sum_tloc(ctx_->Threads(), 0.0);
if (!rank_gpu_ || ctx_->gpu_id < 0) {
{
const auto& labels = info.labels.View(Context::kCpuId);
const auto &h_preds = preds.ConstHostVector();
@@ -253,23 +173,6 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
virtual double EvalGroup(PredIndPairContainer *recptr) const = 0;
};
/*! \brief Precision at N, for both classification and rank */
struct EvalPrecision : public EvalRank {
public:
explicit EvalPrecision(const char* name, const char* param) : EvalRank(name, param) {}
double EvalGroup(PredIndPairContainer *recptr) const override {
PredIndPairContainer &rec(*recptr);
// calculate Precision
std::stable_sort(rec.begin(), rec.end(), common::CmpFirst);
unsigned nhit = 0;
for (size_t j = 0; j < rec.size() && j < this->topn; ++j) {
nhit += (rec[j].second != 0);
}
return static_cast<double>(nhit) / this->topn;
}
};
/*! \brief Cox: Partial likelihood of the Cox proportional hazards model */
struct EvalCox : public MetricNoCache {
public:
@@ -312,7 +215,7 @@ struct EvalCox : public MetricNoCache {
return out/num_events; // normalize by the number of events
}
const char* Name() const override {
[[nodiscard]] const char* Name() const override {
return "cox-nloglik";
}
};
@@ -321,10 +224,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
.describe("Negative log partial likelihood of Cox proportional hazards model.")
.set_body([](const char*) { return new EvalCox(); });
@@ -387,6 +286,8 @@ class EvalRankWithCache : public Metric {
return result;
}
[[nodiscard]] const char* Name() const override { return name_.c_str(); }
virtual double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<Cache> p_cache) = 0;
};
@@ -408,6 +309,52 @@ double Finalize(MetaInfo const& info, double score, double sw) {
}
} // namespace
class EvalPrecision : public EvalRankWithCache<ltr::PreCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::PreCache> p_cache) final {
auto n_groups = p_cache->Groups();
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight();
}
if (ctx_->IsCUDA()) {
auto pre = cuda_impl::PreScore(ctx_, info, predt, p_cache);
return Finalize(info, pre.Residue(), pre.Weights());
}
auto gptr = p_cache->DataGroupPtr(ctx_);
auto h_label = info.labels.HostView().Slice(linalg::All(), 0);
auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size());
auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan());
auto weight = common::MakeOptionalWeights(ctx_, info.weights_);
auto pre = p_cache->Pre(ctx_);
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
double n_hits{0.0};
for (std::size_t i = 0; i < n; ++i) {
n_hits += g_label(g_rank[i]) * weight[g];
}
pre[g] = n_hits / static_cast<double>(n);
});
auto sw = 0.0;
for (std::size_t i = 0; i < pre.size(); ++i) {
sw += weight[i];
}
auto sum = std::accumulate(pre.cbegin(), pre.cend(), 0.0);
return Finalize(info, sum, sw);
}
};
/**
* \brief Implement the NDCG score function for learning to rank.
*
@@ -416,7 +363,6 @@ double Finalize(MetaInfo const& info, double score, double sw) {
class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }
double Eval(HostDeviceVector<float> const& preds, MetaInfo const& info,
std::shared_ptr<ltr::NDCGCache> p_cache) override {
@@ -475,7 +421,6 @@ class EvalNDCG : public EvalRankWithCache<ltr::NDCGCache> {
class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
public:
using EvalRankWithCache::EvalRankWithCache;
const char* Name() const override { return name_.c_str(); }
double Eval(HostDeviceVector<float> const& predt, MetaInfo const& info,
std::shared_ptr<ltr::MAPCache> p_cache) override {
@@ -494,7 +439,7 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) {
auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1]));
auto g_rank = rank_idx.subspan(gptr[g]);
auto g_rank = rank_idx.subspan(gptr[g], gptr[g + 1] - gptr[g]);
auto n = std::min(static_cast<std::size_t>(param_.TopK()), g_label.Size());
double n_hits{0.0};
@@ -527,6 +472,10 @@ class EvalMAPScore : public EvalRankWithCache<ltr::MAPCache> {
}
};
XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });
XGBOOST_REGISTER_METRIC(EvalMAP, "map")
.describe("map@k for ranking.")
.set_body([](char const* param) {

View File

@@ -28,108 +28,57 @@ namespace xgboost::metric {
// tag the this file, used by force static link later.
DMLC_REGISTRY_FILE_TAG(rank_metric_gpu);
/*! \brief Evaluate rank list on GPU */
template <typename EvalMetricT>
struct EvalRankGpu : public GPUMetric, public EvalRankConfig {
public:
double Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info) override {
// Sanity check is done by the caller
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(preds.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_;
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
auto device = ctx_->gpu_id;
dh::safe_cuda(cudaSetDevice(device));
info.labels.SetDevice(device);
preds.SetDevice(device);
auto dpreds = preds.ConstDevicePointer();
auto dlabels = info.labels.View(device);
// Sort all the predictions
dh::SegmentSorter<float> segment_pred_sorter;
segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
// Compute individual group metric and sum them up
return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this);
}
const char* Name() const override {
return name.c_str();
}
explicit EvalRankGpu(const char* name, const char* param) {
using namespace std; // NOLINT(*)
if (param != nullptr) {
std::ostringstream os;
if (sscanf(param, "%u[-]?", &this->topn) == 1) {
os << name << '@' << param;
this->name = os.str();
} else {
os << name << param;
this->name = os.str();
}
if (param[strlen(param) - 1] == '-') {
this->minus = true;
}
} else {
this->name = name;
}
}
};
/*! \brief Precision at N, for both classification and rank */
struct EvalPrecisionGpu {
public:
static double EvalMetric(const dh::SegmentSorter<float> &pred_sorter,
const float *dlabels,
const EvalRankConfig &ecfg) {
// Group info on device
const auto &dgroups = pred_sorter.GetGroupsSpan();
const auto ngroups = pred_sorter.GetNumGroups();
const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan();
// Original positions of the predictions after they have been sorted
const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan();
// First, determine non zero labels in the dataset individually
auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) {
return (static_cast<unsigned>(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0;
}; // NOLINT
// Find each group's metric sum
dh::caching_device_vector<uint32_t> hits(ngroups, 0);
const auto nitems = pred_sorter.GetNumItems();
auto *dhits = hits.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each group item compute the aggregated precision
dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) {
const auto group_idx = dgroup_idx[idx];
const auto group_begin = dgroups[group_idx];
const auto ridx = idx - group_begin;
if (ridx < ecfg.topn && DetermineNonTrivialLabelLambda(idx)) {
atomicAdd(&dhits[group_idx], 1);
}
});
// Allocator to be used for managing space overhead while performing reductions
dh::XGBCachingDeviceAllocator<char> alloc;
return static_cast<double>(thrust::reduce(thrust::cuda::par(alloc),
hits.begin(), hits.end())) / ecfg.topn;
}
};
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
.describe("precision@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });
namespace cuda_impl {
PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt,
std::shared_ptr<ltr::PreCache> p_cache) {
auto d_gptr = p_cache->DataGroupPtr(ctx);
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);
auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan());
auto topk = p_cache->Param().TopK();
auto d_weight = common::MakeOptionalWeights(ctx, info.weights_);
auto it = dh::MakeTransformIterator<double>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) {
auto g = dh::SegmentId(d_gptr, i);
auto g_begin = d_gptr[g];
auto g_end = d_gptr[g + 1];
i -= g_begin;
auto g_label = d_label.Slice(linalg::Range(g_begin, g_end));
auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin);
double y = g_label(g_rank[i]);
auto n = std::min(static_cast<std::size_t>(topk), g_label.Size());
double w{d_weight[g]};
if (i >= n) {
return 0.0;
}
return y / static_cast<double>(n) * w;
});
auto cuctx = ctx->CUDACtx();
auto pre = p_cache->Pre(ctx);
thrust::fill_n(cuctx->CTP(), pre.data(), pre.size(), 0.0);
std::size_t bytes;
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, it, pre.data(), p_cache->Groups(), d_gptr.data(),
d_gptr.data() + 1, cuctx->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, it, pre.data(), p_cache->Groups(),
d_gptr.data(), d_gptr.data() + 1, cuctx->Stream());
auto w_it =
dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t g) { return d_weight[g]; });
auto n_weights = p_cache->Groups();
auto sw = dh::Reduce(cuctx->CTP(), w_it, w_it + n_weights, 0.0, thrust::plus<double>{});
auto sum =
dh::Reduce(cuctx->CTP(), dh::tcbegin(pre), dh::tcend(pre), 0.0, thrust::plus<double>{});
auto result = PackedReduceResult{sum, sw};
return result;
}
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache) {

View File

@@ -3,7 +3,7 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <memory> // for shared_ptr
#include <memory> // for shared_ptr
#include "../common/common.h" // for AssertGPUSupport
#include "../common/ranking_utils.h" // for NDCGCache, MAPCache
@@ -12,9 +12,7 @@
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
namespace xgboost {
namespace metric {
namespace cuda_impl {
namespace xgboost::metric::cuda_impl {
PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::NDCGCache> p_cache);
@@ -23,6 +21,10 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache);
PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt,
std::shared_ptr<ltr::PreCache> p_cache);
#if !defined(XGBOOST_USE_CUDA)
inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &, bool,
@@ -37,8 +39,13 @@ inline PackedReduceResult MAPScore(Context const *, MetaInfo const &,
common::AssertGPUSupport();
return {};
}
inline PackedReduceResult PreScore(Context const *, MetaInfo const &,
HostDeviceVector<float> const &,
std::shared_ptr<ltr::PreCache>) {
common::AssertGPUSupport();
return {};
}
#endif
} // namespace cuda_impl
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric::cuda_impl
#endif // XGBOOST_METRIC_RANK_METRIC_H_