Re-implement PR-AUC. (#7297)

* Support binary/multi-class classification, ranking.
* Add documents.
* Handle missing data.
This commit is contained in:
Jiaming Yuan 2021-10-26 13:07:50 +08:00 committed by GitHub
parent a6bcd54b47
commit d4349426d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1035 additions and 655 deletions

View File

@ -393,9 +393,13 @@ Specify the learning task and the corresponding learning objective. The objectiv
- When used with multi-class classification, objective should be ``multi:softprob`` instead of ``multi:softmax``, as the latter doesn't output probability. Also the AUC is calculated by 1-vs-rest with reference class weighted by class prevalence.
- When used with LTR task, the AUC is computed by comparing pairs of documents to count correctly sorted pairs. This corresponds to pairwise learning to rank. The implementation has some issues with average AUC around groups and distributed workers not being well-defined.
- On a single machine the AUC calculation is exact. In a distributed environment the AUC is a weighted average over the AUC of training rows on each node - therefore, distributed AUC is an approximation sensitive to the distribution of data across workers. Use another metric in distributed environments if precision and reproducibility are important.
- If input dataset contains only negative or positive samples the output is `NaN`.
- When input dataset contains only negative or positive samples, the output is `NaN`. The behavior is implementation defined, for instance, ``scikit-learn`` returns :math:`0.5` instead.
- ``aucpr``: `Area under the PR curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_.
Available for classification and learning-to-rank tasks.
After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation.
- ``aucpr``: `Area under the PR curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_. Available for binary classification and learning-to-rank tasks.
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
- ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.

View File

@ -19,6 +19,7 @@
#include <string>
#include <sstream>
#include <numeric>
#include <utility>
#if defined(__CUDACC__)
#include <thrust/system/cuda/error.h>
@ -86,6 +87,19 @@ XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) {
return static_cast<T1>(std::ceil(static_cast<double>(a) / b));
}
namespace detail {
template <class T, std::size_t N, std::size_t... Idx>
constexpr auto UnpackArr(std::array<T, N> &&arr, std::index_sequence<Idx...>) {
return std::make_tuple(std::forward<std::array<T, N>>(arr)[Idx]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto UnpackArr(std::array<T, N> &&arr) {
return detail::UnpackArr(std::forward<std::array<T, N>>(arr),
std::make_index_sequence<N>{});
}
/*
* Range iterator
*/

View File

@ -14,62 +14,50 @@
#include "rabit/rabit.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/metric.h"
#include "auc.h"
#include "../common/common.h"
#include "../common/math.h"
#include "../common/threading_utils.h"
namespace xgboost {
namespace metric {
namespace detail {
template <class T, std::size_t N, std::size_t... Idx>
constexpr auto UnpackArr(std::array<T, N> &&arr, std::index_sequence<Idx...>) {
return std::make_tuple(std::forward<std::array<T, N>>(arr)[Idx]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto UnpackArr(std::array<T, N> &&arr) {
return detail::UnpackArr(std::forward<std::array<T, N>>(arr),
std::make_index_sequence<N>{});
}
/**
* Calculate AUC for binary classification problem. This function does not normalize the
* AUC by 1 / (num_positive * num_negative), instead it returns a tuple for caller to
* handle the normalization.
*/
std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
std::vector<float> const &labels,
std::vector<float> const &weights) {
template <typename Fn>
std::tuple<float, float, float>
BinaryAUC(common::Span<float const> predts, common::Span<float const> labels,
OptionalWeights weights,
std::vector<size_t> const &sorted_idx, Fn &&area_fn) {
CHECK(!labels.empty());
CHECK_EQ(labels.size(), predts.size());
auto p_predts = predts.data();
auto p_labels = labels.data();
float auc{0};
auto const sorted_idx = common::ArgSort<size_t>(
common::Span<float const>(predts), std::greater<>{});
auto get_weight = [&](size_t i) {
return weights.empty() ? 1.0f : weights[sorted_idx[i]];
};
float label = labels[sorted_idx.front()];
float w = get_weight(0);
float label = p_labels[sorted_idx.front()];
float w = weights[sorted_idx[0]];
float fp = (1.0 - label) * w, tp = label * w;
float tp_prev = 0, fp_prev = 0;
// TODO(jiaming): We can parallize this if we have a parallel scan for CPU.
for (size_t i = 1; i < sorted_idx.size(); ++i) {
if (predts[sorted_idx[i]] != predts[sorted_idx[i-1]]) {
auc += TrapesoidArea(fp_prev, fp, tp_prev, tp);
if (p_predts[sorted_idx[i]] != p_predts[sorted_idx[i - 1]]) {
auc += area_fn(fp_prev, fp, tp_prev, tp);
tp_prev = tp;
fp_prev = fp;
}
label = labels[sorted_idx[i]];
float w = get_weight(i);
label = p_labels[sorted_idx[i]];
float w = weights[sorted_idx[i]];
fp += (1.0f - label) * w;
tp += label * w;
}
auc += TrapesoidArea(fp_prev, fp, tp_prev, tp);
auc += area_fn(fp_prev, fp, tp_prev, tp);
if (fp <= 0.0f || tp <= 0.0f) {
auc = 0;
fp = 0;
@ -87,7 +75,10 @@ std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
* - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class
* Machine Learning Models
*/
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size_t n_classes) {
template <typename BinaryAUC>
float MultiClassOVR(common::Span<float const> predts, MetaInfo const &info,
size_t n_classes, int32_t n_threads,
BinaryAUC &&binary_auc) {
CHECK_NE(n_classes, 0);
auto const &labels = info.labels_.ConstHostVector();
@ -96,12 +87,10 @@ float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size
auto local_area = s_results.subspan(0, n_classes);
auto tp = s_results.subspan(n_classes, n_classes);
auto auc = s_results.subspan(2 * n_classes, n_classes);
auto weights = OptionalWeights{info.weights_.ConstHostSpan()};
if (!info.labels_.Empty()) {
dmlc::OMPException omp_handler;
#pragma omp parallel for
for (omp_ulong c = 0; c < n_classes; ++c) {
omp_handler.Run([&]() {
common::ParallelFor(n_classes, n_threads, [&](auto c) {
std::vector<float> proba(info.labels_.Size());
std::vector<float> response(info.labels_.Size());
for (size_t i = 0; i < proba.size(); ++i) {
@ -109,24 +98,21 @@ float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size
response[i] = labels[i] == c ? 1.0f : 0.0;
}
float fp;
std::tie(fp, tp[c], auc[c]) =
BinaryAUC(proba, response, info.weights_.ConstHostVector());
std::tie(fp, tp[c], auc[c]) = binary_auc(proba, response, weights);
local_area[c] = fp * tp[c];
});
}
omp_handler.Rethrow();
}
// we have 2 averages going in here, first is among workers, second is among classes.
// allreduce sums up fp/tp auc for each class.
// we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class.
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
float auc_sum{0};
float tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
if (local_area[c] != 0) {
// normalize and weight it by prevalence. After allreduce, `local_area` means the
// total covered area (not area under curve, rather it's the accessible area for
// each worker) for each class.
// normalize and weight it by prevalence. After allreduce, `local_area`
// means the total covered area (not area under curve, rather it's the
// accessible area for each worker) for each class.
auc_sum += auc[c] / local_area[c] * tp[c];
tp_sum += tp[c];
} else {
@ -142,10 +128,17 @@ float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size
return auc_sum;
}
std::tuple<float, float, float> BinaryROCAUC(common::Span<float const> predts,
common::Span<float const> labels,
OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
return BinaryAUC(predts, labels, weights, sorted_idx, TrapezoidArea);
}
/**
* Calculate AUC for 1 ranking group;
*/
float GroupRankingAUC(common::Span<float const> predts,
float GroupRankingROC(common::Span<float const> predts,
common::Span<float const> labels, float w) {
// on ranking, we just count all pairs.
float auc{0};
@ -174,11 +167,40 @@ float GroupRankingAUC(common::Span<float const> predts,
return auc;
}
/**
* \brief PR-AUC for binary classification.
*
* https://doi.org/10.1371/journal.pone.0092209
*/
std::tuple<float, float, float> BinaryPRAUC(common::Span<float const> predts,
common::Span<float const> labels,
OptionalWeights weights) {
auto const sorted_idx = common::ArgSort<size_t>(predts, std::greater<>{});
float total_pos{0}, total_neg{0};
for (size_t i = 0; i < labels.size(); ++i) {
auto w = weights[i];
total_pos += w * labels[i];
total_neg += w * (1.0f - labels[i]);
}
if (total_pos <= 0 || total_neg <= 0) {
return {1.0f, 1.0f, std::numeric_limits<float>::quiet_NaN()};
}
auto fn = [total_pos](float fp_prev, float fp, float tp_prev, float tp) {
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
};
float tp{0}, fp{0}, auc{0};
std::tie(fp, tp, auc) = BinaryAUC(predts, labels, weights, sorted_idx, fn);
return std::make_tuple(1.0, 1.0, auc);
}
/**
* Cast LTR problem to binary classification problem by comparing pairs.
*/
template <bool is_roc>
std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info) {
MetaInfo const &info, int32_t n_threads) {
CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1;
float sum_auc = 0;
@ -189,7 +211,7 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
std::atomic<uint32_t> invalid_groups{0};
dmlc::OMPException omp_handler;
#pragma omp parallel for reduction(+:sum_auc)
#pragma omp parallel for reduction(+:sum_auc) num_threads(n_threads)
for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) {
omp_handler.Run([&]() {
size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1];
@ -197,30 +219,32 @@ std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
float auc;
if (g_labels.size() < 3) {
if (is_roc && g_labels.size() < 3) {
// With 2 documents, there's only 1 comparison can be made. So either
// TP or FP will be zero.
invalid_groups++;
auc = 0;
} else {
auc = GroupRankingAUC(g_predts, g_labels, w);
if (is_roc) {
auc = GroupRankingROC(g_predts, g_labels, w);
} else {
auc = std::get<2>(BinaryPRAUC(g_predts, g_labels, OptionalWeights{w}));
}
if (std::isnan(auc)) {
invalid_groups++;
auc = 0;
}
}
sum_auc += auc;
});
}
omp_handler.Rethrow();
if (invalid_groups != 0) {
InvalidGroupAUC();
}
return std::make_pair(sum_auc, n_groups - invalid_groups);
}
template <typename Curve>
class EvalAUC : public Metric {
std::shared_ptr<DeviceAUCCache> d_cache_;
public:
float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
float auc {0};
@ -232,8 +256,10 @@ class EvalAUC : public Metric {
// We use the global size to handle empty dataset.
std::array<size_t, 2> meta{info.labels_.Size(), preds.Size()};
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
if (!info.group_ptr_.empty()) {
if (meta[0] == 0) {
// Empty across all workers, which is not supported.
auc = std::numeric_limits<float>::quiet_NaN();
} else if (!info.group_ptr_.empty()) {
/**
* learning to rank
*/
@ -243,13 +269,11 @@ class EvalAUC : public Metric {
uint32_t valid_groups = 0;
if (!info.labels_.Empty()) {
CHECK_EQ(info.group_ptr_.back(), info.labels_.Size());
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(auc, valid_groups) =
RankingAUC(preds.ConstHostVector(), info);
} else {
std::tie(auc, valid_groups) = GPURankingAUC(
preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
static_cast<Curve *>(this)->EvalRanking(preds, info);
}
if (valid_groups != info.group_ptr_.size() - 1) {
InvalidGroupAUC();
}
std::array<float, 2> results{auc, static_cast<float>(valid_groups)};
@ -270,45 +294,85 @@ class EvalAUC : public Metric {
*/
size_t n_classes = meta[1] / meta[0];
CHECK_NE(n_classes, 0);
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auc = MultiClassOVR(preds.ConstHostVector(), info, n_classes);
} else {
auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id,
&this->d_cache_, n_classes);
}
auc = static_cast<Curve *>(this)->EvalMultiClass(preds, info, n_classes);
} else {
/**
* binary classification
*/
float fp{0}, tp{0};
if (!(preds.Empty() || info.labels_.Empty())) {
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(fp, tp, auc) =
BinaryAUC(preds.ConstHostVector(), info.labels_.ConstHostVector(),
info.weights_.ConstHostVector());
} else {
std::tie(fp, tp, auc) = GPUBinaryAUC(
preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
}
static_cast<Curve *>(this)->EvalBinary(preds, info);
}
float local_area = fp * tp;
std::array<float, 2> result{auc, local_area};
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
std::tie(auc, local_area) = UnpackArr(std::move(result));
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
if (local_area <= 0) {
// the dataset across all workers have only positive or negative sample
auc = std::numeric_limits<float>::quiet_NaN();
} else {
CHECK_LE(auc, local_area);
// normalization
auc = auc / local_area;
}
}
if (std::isnan(auc)) {
LOG(WARNING) << "Dataset contains only positive or negative samples.";
LOG(WARNING) << "Dataset is empty, or contains only positive or negative samples.";
}
return auc;
}
};
class EvalROCAUC : public EvalAUC<EvalROCAUC> {
std::shared_ptr<DeviceAUCCache> d_cache_;
public:
std::pair<float, uint32_t> EvalRanking(HostDeviceVector<float> const &predts,
MetaInfo const &info) {
float auc{0};
uint32_t valid_groups = 0;
auto n_threads = tparam_->Threads();
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(auc, valid_groups) =
RankingAUC<true>(predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) = GPURankingAUC(
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
}
return std::make_pair(auc, valid_groups);
}
float EvalMultiClass(HostDeviceVector<float> const &predts,
MetaInfo const &info, size_t n_classes) {
float auc{0};
auto n_threads = tparam_->Threads();
CHECK_NE(n_classes, 0);
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auc = MultiClassOVR(predts.ConstHostVector(), info, n_classes, n_threads,
BinaryROCAUC);
} else {
auc = GPUMultiClassROCAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
&this->d_cache_, n_classes);
}
return auc;
}
std::tuple<float, float, float>
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
float fp, tp, auc;
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(fp, tp, auc) =
BinaryROCAUC(predts.ConstHostVector(), info.labels_.ConstHostVector(),
OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(fp, tp, auc) = GPUBinaryROCAUC(predts.ConstDeviceSpan(), info,
tparam_->gpu_id, &this->d_cache_);
}
return std::make_tuple(fp, tp, auc);
}
public:
char const* Name() const override {
return "auc";
}
@ -316,18 +380,19 @@ class EvalAUC : public Metric {
XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
.describe("Receiver Operating Characteristic Area Under the Curve.")
.set_body([](const char*) { return new EvalAUC(); });
.set_body([](const char*) { return new EvalROCAUC(); });
#if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport();
return std::make_tuple(0.0f, 0.0f, 0.0f);
}
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
float GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes) {
common::AssertGPUSupport();
return 0;
@ -341,5 +406,85 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
}
struct DeviceAUCCache {};
#endif // !defined(XGBOOST_USE_CUDA)
class EvalAUCPR : public EvalAUC<EvalAUCPR> {
std::shared_ptr<DeviceAUCCache> d_cache_;
public:
std::tuple<float, float, float>
EvalBinary(HostDeviceVector<float> const &predts, MetaInfo const &info) {
float pr, re, auc;
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(pr, re, auc) =
BinaryPRAUC(predts.ConstHostSpan(), info.labels_.ConstHostSpan(),
OptionalWeights{info.weights_.ConstHostSpan()});
} else {
std::tie(pr, re, auc) = GPUBinaryPRAUC(predts.ConstDeviceSpan(), info,
tparam_->gpu_id, &this->d_cache_);
}
return std::make_tuple(pr, re, auc);
}
float EvalMultiClass(HostDeviceVector<float> const &predts,
MetaInfo const &info, size_t n_classes) {
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auto n_threads = this->tparam_->Threads();
return MultiClassOVR(predts.ConstHostSpan(), info, n_classes, n_threads,
BinaryPRAUC);
} else {
return GPUMultiClassPRAUC(predts.ConstDeviceSpan(), info, tparam_->gpu_id,
&d_cache_, n_classes);
}
}
std::pair<float, uint32_t> EvalRanking(HostDeviceVector<float> const &predts,
MetaInfo const &info) {
float auc{0};
uint32_t valid_groups = 0;
auto n_threads = tparam_->Threads();
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auto labels = info.labels_.ConstHostSpan();
if (std::any_of(labels.cbegin(), labels.cend(), PRAUCLabelInvalid{})) {
InvalidLabels();
}
std::tie(auc, valid_groups) =
RankingAUC<false>(predts.ConstHostVector(), info, n_threads);
} else {
std::tie(auc, valid_groups) = GPURankingPRAUC(
predts.ConstDeviceSpan(), info, tparam_->gpu_id, &d_cache_);
}
return std::make_pair(auc, valid_groups);
}
public:
const char *Name() const override { return "aucpr"; }
};
XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](char const *) { return new EvalAUCPR{}; });
#if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport();
return {};
}
float GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes) {
common::AssertGPUSupport();
return {};
}
std::pair<float, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache) {
common::AssertGPUSupport();
return {};
}
#endif
} // namespace metric
} // namespace xgboost

View File

@ -3,6 +3,8 @@
*/
#include <thrust/scan.h>
#include <cub/cub.cuh>
#include <algorithm>
#include <cassert>
#include <limits>
#include <memory>
@ -19,12 +21,13 @@
namespace xgboost {
namespace metric {
namespace {
struct GetWeightOp {
common::Span<float const> weights;
common::Span<size_t const> sorted_idx;
// Pair of FP/TP
using Pair = thrust::pair<float, float>;
__device__ float operator()(size_t i) const {
return weights.empty() ? 1.0f : weights[sorted_idx[i]];
template <typename T, typename U, typename P = thrust::pair<T, U>>
struct PairPlus : public thrust::binary_function<P, P, P> {
XGBOOST_DEVICE P operator()(P const& l, P const& r) const {
return thrust::make_pair(l.first + r.first, l.second + r.second);
}
};
} // namespace
@ -33,8 +36,6 @@ struct GetWeightOp {
* A cache to GPU data to avoid reallocating memory.
*/
struct DeviceAUCCache {
// Pair of FP/TP
using Pair = thrust::pair<float, float>;
// index sorted by prediction value
dh::device_vector<size_t> sorted_idx;
// track FP/TP for computation on trapesoid area
@ -64,6 +65,16 @@ struct DeviceAUCCache {
}
};
template <bool is_multi>
void InitCacheOnce(common::Span<float const> predts, int32_t device,
std::shared_ptr<DeviceAUCCache>* p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, is_multi, device);
}
/**
* The GPU implementation uses same calculation as CPU with a few more steps to distribute
* work across threads:
@ -73,15 +84,11 @@ struct DeviceAUCCache {
* which are left coordinates of trapesoids.
* - Reduce the scan array into 1 AUC value.
*/
template <typename Fn>
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, false, device);
int32_t device, common::Span<size_t const> d_sorted_idx,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels_.ConstDeviceSpan();
auto weights = info.weights_.ConstDeviceSpan();
dh::safe_cuda(cudaSetDevice(device));
@ -89,22 +96,15 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
CHECK(!labels.empty());
CHECK_EQ(labels.size(), predts.size());
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
/**
* Linear scan
*/
auto get_weight = GetWeightOp{weights, d_sorted_idx};
using Pair = thrust::pair<float, float>;
auto get_fp_tp = [=]__device__(size_t i) {
auto get_weight = OptionalWeights{weights};
auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
size_t idx = d_sorted_idx[i];
float label = labels[idx];
float w = get_weight(i);
float w = get_weight[d_sorted_idx[i]];
float fp = (1.0 - label) * w;
float tp = label * w;
@ -113,7 +113,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
}; // NOLINT
auto d_fptp = dh::ToSpan(cache->fptp);
dh::LaunchN(d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });
[=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); });
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
@ -121,24 +121,20 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
[=] __device__(size_t i) { return predts[d_sorted_idx[i]]; });
[=] XGBOOST_DEVICE(size_t i) { return predts[d_sorted_idx[i]]; });
auto end_unique = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
dh::InclusiveScan(
dh::tbegin(d_fptp), dh::tbegin(d_fptp),
[=] __device__(Pair const &l, Pair const &r) {
return thrust::make_pair(l.first + r.first, l.second + r.second);
},
d_fptp.size());
dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp),
PairPlus<float, float>{}, d_fptp.size());
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// scatter unique negaive/positive values
// shift to right by 1 with initial value being 0
dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) {
if (d_unique_idx[i] == 0) { // first unique index is 0
assert(i == 0);
d_neg_pos[0] = {0, 0};
@ -154,7 +150,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
});
auto in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
float fp, tp;
float fp_prev, tp_prev;
if (i == 0) {
@ -165,7 +161,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
}
return TrapesoidArea(fp_prev, fp, tp_prev, tp);
return area_fn(fp_prev, fp, tp_prev, tp);
});
Pair last = cache->fptp.back();
@ -173,11 +169,31 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
return std::make_tuple(last.first, last.second, auc);
}
std::tuple<float, float, float>
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
// Create lambda to avoid pass function pointer.
return GPUBinaryAUC(
predts, info, device, d_sorted_idx,
[] XGBOOST_DEVICE(float x0, float x1, float y0, float y1) {
return TrapezoidArea(x0, x1, y0, y1);
},
cache);
}
void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
size_t n, int32_t device) {
size_t n) {
CHECK_EQ(in.size(), out.size());
CHECK_EQ(in.size(), m * n);
dh::LaunchN(in.size(), [=] __device__(size_t i) {
dh::LaunchN(in.size(), [=] XGBOOST_DEVICE(size_t i) {
size_t col = i / m;
size_t row = i % m;
size_t idx = row * n + col;
@ -204,7 +220,7 @@ float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
if (local_area[i] > 0) {
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
}
@ -213,12 +229,9 @@ float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
float tp_sum;
float auc_sum;
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
thrust::make_pair(0.0f, 0.0f),
[=] __device__(auto const &l, auto const &r) {
return thrust::make_pair(l.first + r.first, l.second + r.second);
});
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0f, 0.0f}, PairPlus<float, float>{});
if (tp_sum != 0 && !std::isnan(auc_sum)) {
auc_sum /= tp_sum;
} else {
@ -227,19 +240,98 @@ float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
return auc_sum;
}
/**
* Calculate FP/TP for multi-class and PR-AUC ranking. `segment_id` is a function for
* getting class id or group id given scan index.
*/
template <typename Fn>
void SegmentedFPTP(common::Span<Pair> d_fptp, Fn segment_id) {
using Triple = thrust::tuple<uint32_t, float, float>;
// expand to tuple to include idx
auto fptp_it_in = dh::MakeTransformIterator<Triple>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second);
});
// shrink down to pair
auto fptp_it_out = thrust::make_transform_output_iterator(
dh::TypedDiscard<Triple>{}, [d_fptp] XGBOOST_DEVICE(Triple const &t) {
d_fptp[thrust::get<0>(t)] =
thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
return t;
});
dh::InclusiveScan(
fptp_it_in, fptp_it_out,
[=] XGBOOST_DEVICE(Triple const &l, Triple const &r) {
uint32_t l_gid = segment_id(thrust::get<0>(l));
uint32_t r_gid = segment_id(thrust::get<0>(r));
if (l_gid != r_gid) {
return r;
}
return Triple(thrust::get<0>(r),
thrust::get<1>(l) + thrust::get<1>(r), // fp
thrust::get<2>(l) + thrust::get<2>(r)); // tp
},
d_fptp.size());
}
/**
* Reduce the values of AUC for each group/class.
*/
template <typename Area, typename Seg>
void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
common::Span<uint32_t const> d_class_ptr,
common::Span<uint32_t const> d_unique_class_ptr,
std::shared_ptr<DeviceAUCCache> cache,
Area area_fn,
Seg segment_id,
common::Span<float> d_auc) {
auto d_fptp = dh::ToSpan(cache->fptp);
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
dh::XGBDeviceAllocator<char> alloc;
auto key_in = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
size_t class_id = segment_id(d_unique_idx[i]);
return class_id;
});
auto val_in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
size_t class_id = segment_id(d_unique_idx[i]);
float fp, tp, fp_prev, tp_prev;
if (i == d_unique_class_ptr[class_id]) {
// first item is ignored, we use this thread to calculate the last item
thrust::tie(fp, tp) = d_fptp[LastOf(class_id, d_class_ptr)];
thrust::tie(fp_prev, tp_prev) =
d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]];
} else {
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
}
float auc = area_fn(fp_prev, fp, tp_prev, tp, class_id);
return auc;
});
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), dh::tbegin(d_auc));
}
/**
* MultiClass implementation is similar to binary classification, except we need to split
* up each class in all kernels.
*/
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache,
size_t n_classes) {
template <bool scale, typename Fn>
float GPUMultiClassAUCOVR(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
common::Span<uint32_t> d_class_ptr, size_t n_classes,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
dh::safe_cuda(cudaSetDevice(device));
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, true, device);
/**
* Sorted idx
*/
auto d_predts_t = dh::ToSpan(cache->predts_t);
// Index is sorted within class.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto weights = info.weights_.ConstDeviceSpan();
@ -250,7 +342,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f);
auto d_results = dh::ToSpan(resutls);
dh::LaunchN(n_classes * 4,
[=] __device__(size_t i) { d_results[i] = 0.0f; });
[=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
auto local_area = d_results.subspan(0, n_classes);
auto fp = d_results.subspan(n_classes, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
@ -258,43 +350,26 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
}
/**
* Create sorted index for each class
*/
auto d_predts_t = dh::ToSpan(cache->predts_t);
Transpose(predts, d_predts_t, n_samples, n_classes, device);
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
dh::LaunchN(n_classes + 1,
[=] __device__(size_t i) { d_class_ptr[i] = i * n_samples; });
// no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't
// use transform iterator in sorting.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
/**
* Linear scan
*/
dh::caching_device_vector<float> d_auc(n_classes, 0);
auto s_d_auc = dh::ToSpan(d_auc);
auto get_weight = GetWeightOp{weights, d_sorted_idx};
using Pair = thrust::pair<float, float>;
auto get_weight = OptionalWeights{weights};
auto d_fptp = dh::ToSpan(cache->fptp);
auto get_fp_tp = [=]__device__(size_t i) {
auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
size_t idx = d_sorted_idx[i];
size_t class_id = i / n_samples;
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;
float w = weights.empty() ? 1.0f : weights[d_sorted_idx[i] % n_samples];
float w = get_weight[d_sorted_idx[i] % n_samples];
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
}; // NOLINT
dh::LaunchN(d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });
[=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); });
/**
* Handle duplicated predictions
@ -303,14 +378,14 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx);
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
uint32_t class_id = i / n_samples;
float predt = d_predts_t[d_sorted_idx[i]];
return thrust::make_pair(class_id, predt);
});
// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size());
dh::TemporaryArray<uint32_t> unique_class_ptr(d_class_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
@ -324,39 +399,14 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
thrust::equal_to<thrust::pair<uint32_t, float>>{});
d_unique_idx = d_unique_idx.subspan(0, n_uniques);
using Triple = thrust::tuple<uint32_t, float, float>;
// expand to tuple to include class id
auto fptp_it_in = dh::MakeTransformIterator<Triple>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
return thrust::make_tuple(i, d_fptp[i].first, d_fptp[i].second);
});
// shrink down to pair
auto fptp_it_out = thrust::make_transform_output_iterator(
dh::TypedDiscard<Triple>{}, [d_fptp] __device__(Triple const &t) {
d_fptp[thrust::get<0>(t)] =
thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
return t;
});
dh::InclusiveScan(
fptp_it_in, fptp_it_out,
[=] __device__(Triple const &l, Triple const &r) {
uint32_t l_cid = thrust::get<0>(l) / n_samples;
uint32_t r_cid = thrust::get<0>(r) / n_samples;
if (l_cid != r_cid) {
return r;
}
return Triple(thrust::get<0>(r),
thrust::get<1>(l) + thrust::get<1>(r), // fp
thrust::get<2>(l) + thrust::get<2>(r)); // tp
},
d_fptp.size());
auto get_class_id = [=] XGBOOST_DEVICE(size_t idx) { return idx / n_samples; };
SegmentedFPTP(d_fptp, get_class_id);
// scatter unique FP_PREV/TP_PREV values
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// When dataset is not empty, each class must have at least 1 (unique) sample
// prediction, so no need to handle special case.
dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) {
if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0
assert(d_unique_idx[i] % n_samples == 0);
d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i
@ -375,32 +425,9 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
/**
* Reduce the result for each class
*/
auto key_in = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
size_t class_id = d_unique_idx[i] / n_samples;
return class_id;
});
auto val_in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
size_t class_id = d_unique_idx[i] / n_samples;
float fp, tp;
float fp_prev, tp_prev;
if (i == d_unique_class_ptr[class_id]) {
// first item is ignored, we use this thread to calculate the last item
thrust::tie(fp, tp) = d_fptp[class_id * n_samples + (n_samples - 1)];
thrust::tie(fp_prev, tp_prev) =
d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]];
} else {
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
}
float auc = TrapesoidArea(fp_prev, fp, tp_prev, tp);
return auc;
});
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), d_auc.begin());
auto s_d_auc = dh::ToSpan(d_auc);
SegmentedReduceAUC(d_unique_idx, d_class_ptr, d_unique_class_ptr, cache,
area_fn, get_class_id, s_d_auc);
/**
* Scale the classes with number of samples for each class.
@ -412,16 +439,58 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);
dh::LaunchN(n_classes, [=] __device__(size_t c) {
dh::LaunchN(n_classes, [=] XGBOOST_DEVICE(size_t c) {
auc[c] = s_d_auc[c];
auto last = d_fptp[n_samples * c + (n_samples - 1)];
fp[c] = last.first;
tp[c] = last.second;
if (scale) {
local_area[c] = last.first * last.second;
tp[c] = last.second;
} else {
local_area[c] = 1.0f;
tp[c] = 1.0f;
}
});
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
}
void MultiClassSortedIdx(common::Span<float const> predts,
common::Span<uint32_t> d_class_ptr,
std::shared_ptr<DeviceAUCCache> cache) {
size_t n_classes = d_class_ptr.size() - 1;
auto d_predts_t = dh::ToSpan(cache->predts_t);
auto n_samples = d_predts_t.size() / n_classes;
if (n_samples == 0) {
return;
}
Transpose(predts, d_predts_t, n_samples, n_classes);
dh::LaunchN(n_classes + 1,
[=] XGBOOST_DEVICE(size_t i) { d_class_ptr[i] = i * n_samples; });
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
}
float GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, device, p_cache);
/**
* Create sorted index for each class
*/
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache);
auto fn = [] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev, float tp,
size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
};
return GPUMultiClassAUCOVR<true>(predts, info, device, dh::ToSpan(class_ptr),
n_classes, cache, fn);
}
namespace {
struct RankScanItem {
size_t idx;
@ -435,10 +504,7 @@ std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, false, device);
InitCacheOnce<false>(predts, device, p_cache);
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
dh::XGBCachingDeviceAllocator<char> alloc;
@ -449,10 +515,10 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
*/
auto check_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0),
[=] __device__(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
[=] XGBOOST_DEVICE(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
size_t n_valid = thrust::count_if(
thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] __device__(size_t len) { return len >= 3; });
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
if (n_valid < info.group_ptr_.size() - 1) {
InvalidGroupAUC();
}
@ -475,8 +541,9 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
// Use max to represent triangle
auto n_threads = common::SegmentedTrapezoidThreads(
d_group_ptr, d_threads_group_ptr, std::numeric_limits<size_t>::max());
CHECK_LT(n_threads, std::numeric_limits<int32_t>::max());
// get the coordinate in nested summation
auto get_i_j = [=]__device__(size_t idx, size_t query_group_idx) {
auto get_i_j = [=]XGBOOST_DEVICE(size_t idx, size_t query_group_idx) {
auto data_group_begin = d_group_ptr[query_group_idx];
size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
auto thread_group_begin = d_threads_group_ptr[query_group_idx];
@ -491,7 +558,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
return thrust::make_pair(i, j);
}; // NOLINT
auto in = dh::MakeTransformIterator<RankScanItem>(
thrust::make_counting_iterator(0), [=] __device__(size_t idx) {
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t idx) {
bst_group_t query_group_idx = dh::SegmentId(d_threads_group_ptr, idx);
auto data_group_begin = d_group_ptr[query_group_idx];
size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
@ -519,7 +586,8 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
dh::TemporaryArray<float> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator(
dh::TypedDiscard<RankScanItem>{}, [=] __device__(RankScanItem const &item) -> RankScanItem {
dh::TypedDiscard<RankScanItem>{},
[=] XGBOOST_DEVICE(RankScanItem const &item) -> RankScanItem {
auto group_id = item.group_id;
assert(group_id < d_group_ptr.size());
auto data_group_begin = d_group_ptr[group_id];
@ -536,7 +604,7 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
});
dh::InclusiveScan(
in, out,
[] __device__(RankScanItem const &l, RankScanItem const &r) {
[] XGBOOST_DEVICE(RankScanItem const &l, RankScanItem const &r) {
if (l.group_id != r.group_id) {
return r;
}
@ -551,5 +619,288 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
dh::tend(s_d_auc), 0.0f);
return std::make_pair(auc, n_valid);
}
std::tuple<float, float, float>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto d_weights = info.weights_.ConstDeviceSpan();
auto get_weight = OptionalWeights{d_weights};
auto it = dh::MakeTransformIterator<thrust::pair<float, float>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto w = get_weight[d_sorted_idx[i]];
return thrust::make_pair(labels[d_sorted_idx[i]] * w,
(1.0f - labels[d_sorted_idx[i]]) * w);
});
dh::XGBCachingDeviceAllocator<char> alloc;
float total_pos, total_neg;
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(),
Pair{0.0f, 0.0f}, PairPlus<float, float>{});
if (total_pos <= 0.0 || total_neg <= 0.0) {
return {0.0f, 0.0f, 0.0f};
}
auto fn = [total_pos] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev,
float tp) {
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
};
float fp, tp, auc;
std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache);
return std::make_tuple(1.0, 1.0, auc);
}
float GPUMultiClassPRAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *p_cache,
size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, device, p_cache);
/**
* Create sorted index for each class
*/
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
MultiClassSortedIdx(predts, d_class_ptr, cache);
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan();
/**
* Get total positive/negative
*/
auto labels = info.labels_.ConstDeviceSpan();
auto n_samples = info.num_row_;
dh::caching_device_vector<thrust::pair<float, float>> totals(n_classes);
auto key_it =
dh::MakeTransformIterator<size_t>(thrust::make_counting_iterator(0ul),
[n_samples] XGBOOST_DEVICE(size_t i) {
return i / n_samples; // class id
});
auto get_weight = OptionalWeights{d_weights};
auto val_it = dh::MakeTransformIterator<thrust::pair<float, float>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto idx = d_sorted_idx[i] % n_samples;
auto w = get_weight[idx];
auto class_id = i / n_samples;
auto y = labels[idx] == class_id;
return thrust::make_pair(y * w, (1.0f - y) * w);
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<float, float>{});
/**
* Calculate AUC
*/
auto d_totals = dh::ToSpan(totals);
auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev,
float tp, size_t class_id) {
auto total_pos = d_totals[class_id].first;
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first);
};
return GPUMultiClassAUCOVR<false>(predts, info, device, d_class_ptr,
n_classes, cache, fn);
}
template <typename Fn>
std::pair<float, uint32_t>
GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
common::Span<uint32_t> d_group_ptr, int32_t device,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
/**
* Sorted idx
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto weights = info.weights_.ConstDeviceSpan();
uint32_t n_groups = static_cast<uint32_t>(info.group_ptr_.size() - 1);
/**
* Linear scan
*/
size_t n_samples = labels.size();
dh::caching_device_vector<float> d_auc(n_groups, 0);
auto get_weight = OptionalWeights{weights};
auto d_fptp = dh::ToSpan(cache->fptp);
auto get_fp_tp = [=] XGBOOST_DEVICE(size_t i) {
size_t idx = d_sorted_idx[i];
size_t group_id = dh::SegmentId(d_group_ptr, idx);
float label = labels[idx];
float w = get_weight[group_id];
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
}; // NOLINT
dh::LaunchN(d_sorted_idx.size(),
[=] XGBOOST_DEVICE(size_t i) { d_fptp[i] = get_fp_tp(i); });
/**
* Handle duplicated predictions
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx);
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
auto idx = d_sorted_idx[i];
bst_group_t group_id = dh::SegmentId(d_group_ptr, idx);
float predt = predts[idx];
return thrust::make_pair(group_id, predt);
});
// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(d_group_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
dh::tbegin(d_group_ptr),
dh::tend(d_group_ptr),
uni_key,
uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx),
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
d_unique_idx = d_unique_idx.subspan(0, n_uniques);
auto get_group_id = [=] XGBOOST_DEVICE(size_t idx) {
return dh::SegmentId(d_group_ptr, idx);
};
SegmentedFPTP(d_fptp, get_group_id);
// scatter unique FP_PREV/TP_PREV values
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
dh::LaunchN(d_unique_idx.size(), [=] XGBOOST_DEVICE(size_t i) {
if (thrust::binary_search(thrust::seq, d_unique_class_ptr.cbegin(),
d_unique_class_ptr.cend(),
i)) { // first unique index is 0
d_neg_pos[d_unique_idx[i]] = {0, 0};
return;
}
auto group_idx = dh::SegmentId(d_group_ptr, d_unique_idx[i]);
d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
if (i == LastOf(group_idx, d_unique_class_ptr)) {
// last one needs to be included.
size_t last = d_unique_idx[LastOf(group_idx, d_unique_class_ptr)];
d_neg_pos[LastOf(group_idx, d_group_ptr)] = d_fptp[last - 1];
return;
}
});
/**
* Reduce the result for each group
*/
auto s_d_auc = dh::ToSpan(d_auc);
SegmentedReduceAUC(d_unique_idx, d_group_ptr, d_unique_class_ptr, cache,
area_fn, get_group_id, s_d_auc);
/**
* Scale the groups with number of samples for each group.
*/
float auc;
uint32_t invalid_groups;
{
auto it = dh::MakeTransformIterator<thrust::pair<float, uint32_t>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t g) {
float fp, tp;
thrust::tie(fp, tp) = d_fptp[LastOf(g, d_group_ptr)];
float area = fp * tp;
auto n_documents = d_group_ptr[g + 1] - d_group_ptr[g];
if (area > 0 && n_documents >= 2) {
return thrust::make_pair(s_d_auc[g], static_cast<uint32_t>(0));
}
return thrust::make_pair(0.0f, static_cast<uint32_t>(1));
});
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::cuda::par(alloc), it, it + n_groups,
thrust::pair<float, uint32_t>(0.0f, 0), PairPlus<float, uint32_t>{});
}
return std::make_pair(auc, n_groups - invalid_groups);
}
std::pair<float, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
dh::safe_cuda(cudaSetDevice(device));
if (predts.empty()) {
return std::make_pair(0.0f, static_cast<uint32_t>(0));
}
auto &cache = *p_cache;
InitCacheOnce<false>(predts, device, p_cache);
dh::device_vector<bst_group_t> group_ptr(info.group_ptr_.size());
thrust::copy(info.group_ptr_.begin(), info.group_ptr_.end(), group_ptr.begin());
auto d_group_ptr = dh::ToSpan(group_ptr);
CHECK_GE(info.group_ptr_.size(), 1) << "Must have at least 1 query group for LTR.";
size_t n_groups = info.group_ptr_.size() - 1;
/**
* Create sorted index for each group
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(predts, d_group_ptr, d_sorted_idx);
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels_.ConstDeviceSpan();
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels),
dh::tend(labels), PRAUCLabelInvalid{})) {
InvalidLabels();
}
/**
* Get total positive/negative for each group.
*/
auto d_weights = info.weights_.ConstDeviceSpan();
dh::caching_device_vector<thrust::pair<float, float>> totals(n_groups);
auto key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(size_t i) { return dh::SegmentId(d_group_ptr, i); });
auto val_it = dh::MakeTransformIterator<thrust::pair<float, float>>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
float w = 1.0f;
if (!d_weights.empty()) {
// Avoid a binary search if the groups are not weighted.
auto g = dh::SegmentId(d_group_ptr, i);
w = d_weights[g];
}
auto y = labels[i];
return thrust::make_pair(y * w, (1.0f - y) * w);
});
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<float, float>{});
/**
* Calculate AUC
*/
auto d_totals = dh::ToSpan(totals);
auto fn = [d_totals] XGBOOST_DEVICE(float fp_prev, float fp, float tp_prev,
float tp, size_t group_id) {
auto total_pos = d_totals[group_id].first;
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
return GPURankingPRAUCImpl(predts, info, d_group_ptr, n_groups, cache, fn);
}
} // namespace metric
} // namespace xgboost

View File

@ -3,7 +3,9 @@
*/
#ifndef XGBOOST_METRIC_AUC_H_
#define XGBOOST_METRIC_AUC_H_
#include <array>
#include <cmath>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
@ -12,32 +14,115 @@
#include "xgboost/base.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "xgboost/metric.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
namespace xgboost {
namespace metric {
XGBOOST_DEVICE inline float TrapesoidArea(float x0, float x1, float y0, float y1) {
/***********
* ROC AUC *
***********/
XGBOOST_DEVICE inline float TrapezoidArea(float x0, float x1, float y0, float y1) {
return std::abs(x0 - x1) * (y0 + y1) * 0.5f;
}
struct DeviceAUCCache;
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
GPUBinaryROCAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
float GPUMultiClassROCAUC(common::Span<float const> predts,
MetaInfo const &info, int32_t device,
std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes);
std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache);
/**********
* PR AUC *
**********/
std::tuple<float, float, float>
GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
float GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache,
size_t n_classes);
std::pair<float, uint32_t>
GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache);
namespace detail {
XGBOOST_DEVICE inline float CalcH(float fp_a, float fp_b, float tp_a,
float tp_b) {
return (fp_b - fp_a) / (tp_b - tp_a);
}
XGBOOST_DEVICE inline float CalcB(float fp_a, float h, float tp_a, float total_pos) {
return (fp_a - h * tp_a) / total_pos;
}
XGBOOST_DEVICE inline float CalcA(float h) { return h + 1; }
XGBOOST_DEVICE inline float CalcDeltaPRAUC(float fp_prev, float fp,
float tp_prev, float tp,
float total_pos) {
float pr_prev = tp_prev / total_pos;
float pr = tp / total_pos;
float h{0}, a{0}, b{0};
if (tp == tp_prev) {
a = 1.0;
b = 0.0;
} else {
h = detail::CalcH(fp_prev, fp, tp_prev, tp);
a = detail::CalcA(h);
b = detail::CalcB(fp_prev, h, tp_prev, total_pos);
}
float area = 0;
if (b != 0.0) {
area = (pr - pr_prev -
b / a * (std::log(a * pr + b) - std::log(a * pr_prev + b))) /
a;
} else {
area = (pr - pr_prev) / a;
}
return area;
}
} // namespace detail
inline void InvalidGroupAUC() {
LOG(INFO) << "Invalid group with less than 3 samples is found on worker "
<< rabit::GetRank() << ". Calculating AUC value requires at "
<< "least 2 pairs of samples.";
}
struct PRAUCLabelInvalid {
XGBOOST_DEVICE bool operator()(float y) { return y < 0.0f || y > 1.0f; }
};
inline void InvalidLabels() {
LOG(FATAL) << "PR-AUC supports only binary relevance for learning to rank.";
}
struct OptionalWeights {
common::Span<float const> weights;
float dft { 1.0f };
explicit OptionalWeights(common::Span<float const> w) : weights{w} {}
explicit OptionalWeights(float w) : dft{w} {}
XGBOOST_DEVICE float operator[](size_t i) const {
return weights.empty() ? dft : weights[i];
}
};
} // namespace metric
} // namespace xgboost
#endif // XGBOOST_METRIC_AUC_H_

View File

@ -392,166 +392,10 @@ struct EvalCox : public Metric {
}
};
/*! \brief Area Under PR Curve, for both classification and rank computed on CPU */
struct EvalAucPR : public Metric {
// implementation of AUC-PR for weighted data
// translated from PRROC R Package
// see https://doi.org/10.1371/journal.pone.0092209
private:
// This is used to compute the AUCPR metrics on the GPU - for ranking tasks and
// for training jobs that run on the GPU.
std::unique_ptr<xgboost::Metric> aucpr_gpu_;
template <typename WeightPolicy>
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed,
const std::vector<unsigned> &gptr) {
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum of all AUC's across all query groups
double sum_auc = 0.0;
int auc_error = 0;
const auto &h_labels = info.labels_.ConstHostVector();
const auto &h_preds = preds.ConstHostVector();
dmlc::OMPException exc;
#pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1)
{
exc.Run([&]() {
// Each thread works on a distinct group and sorts the predictions in that group
PredIndPairContainer rec;
#pragma omp for schedule(static)
for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) {
exc.Run([&]() {
double total_pos = 0.0;
double total_neg = 0.0;
// Same thread can work on multiple groups one after another; hence, resize
// the predictions array based on the current group
rec.resize(gptr[group_id + 1] - gptr[group_id]);
#pragma omp parallel for schedule(static) reduction(+:total_pos, total_neg) \
if (!omp_in_parallel()) // NOLINT
for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
exc.Run([&]() {
const bst_float wt = WeightPolicy::GetWeightOfInstance(info, j, group_id);
total_pos += wt * h_labels[j];
total_neg += wt * (1.0f - h_labels[j]);
rec[j - gptr[group_id]] = {h_preds[j], j};
});
}
// we need pos > 0 && neg > 0
if (total_pos <= 0.0 || total_neg <= 0.0) {
auc_error += 1;
return;
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
// calculate AUC
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
for (size_t j = 0; j < rec.size(); ++j) {
const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id);
tp += wt * h_labels[rec[j].second];
fp += wt * (1.0f - h_labels[rec[j].second]);
if ((j < rec.size() - 1 && rec[j].first != rec[j + 1].first) ||
j == rec.size() - 1) {
if (tp == prevtp) {
a = 1.0;
b = 0.0;
} else {
h = (fp - prevfp) / (tp - prevtp);
a = 1.0 + h;
b = (prevfp - h * prevtp) / total_pos;
}
if (0.0 != b) {
sum_auc += (tp / total_pos - prevtp / total_pos -
b / a * (std::log(a * tp / total_pos + b) -
std::log(a * prevtp / total_pos + b))) / a;
} else {
sum_auc += (tp / total_pos - prevtp / total_pos) / a;
}
prevtp = tp;
prevfp = fp;
}
}
// sanity check
if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
CHECK(!auc_error) << "AUC-PR: error in calculation";
}
});
}
});
}
exc.Rethrow();
// Report average AUC-PR across all groups
// In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC-PR.
bst_float dat[2] = {0.0f, 0.0f};
if (auc_error < static_cast<int>(ngroups)) {
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(static_cast<int>(ngroups) - auc_error);
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC-PR: the dataset only contains pos or neg samples";
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
}
public:
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label size predict size not match";
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
CHECK_EQ(gptr.back(), info.labels_.Size())
<< "EvalAucPR: group structure must match number of prediction";
// For ranking task, weights are per-group
// For binary classification task, weights are per-instance
const bool is_ranking_task =
!info.group_ptr_.empty() && info.weights_.Size() != info.num_row_;
// Check if we have a GPU assignment; else, revert back to CPU
if (tparam_->gpu_id >= 0 && is_ranking_task) {
if (!aucpr_gpu_) {
// Check and see if we have the GPU metric registered in the internal registry
aucpr_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_));
}
if (aucpr_gpu_) {
return aucpr_gpu_->Eval(preds, info, distributed);
}
}
if (is_ranking_task) {
return Eval<PerGroupWeightPolicy>(preds, info, distributed, gptr);
} else {
return Eval<PerInstanceWeightPolicy>(preds, info, distributed, gptr);
}
}
const char *Name() const override { return "aucpr"; }
};
XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](const char*) { return new EvalAucPR(); });
XGBOOST_REGISTER_METRIC(Precision, "pre")
.describe("precision@k for rank.")
.set_body([](const char* param) { return new EvalPrecision("pre", param); });

View File

@ -274,196 +274,6 @@ struct EvalMAPGpu {
}
};
/*! \brief Area Under PR Curve metric computation for ranking datasets */
struct EvalAucPRGpu : public Metric {
public:
// This function object computes the item's positive/negative precision value
class ComputeItemPrecision : public thrust::unary_function<uint32_t, float> {
public:
// The precision type to be computed
enum class PrecisionType {
kPositive,
kNegative
};
XGBOOST_DEVICE ComputeItemPrecision(PrecisionType ptype,
uint32_t ngroups,
const float *dweights,
const xgboost::common::Span<const uint32_t> &dgidxs,
const float *dlabels)
: ptype_(ptype), ngroups_(ngroups), dweights_(dweights), dgidxs_(dgidxs), dlabels_(dlabels) {}
// Compute precision value for the prediction that was originally at 'idx'
__device__ __forceinline__ float operator()(uint32_t idx) const {
// For ranking task, weights are per-group
// For binary classification task, weights are per-instance
const auto wt = dweights_ == nullptr ? 1.0f : dweights_[ngroups_ == 1 ? idx : dgidxs_[idx]];
return wt * (ptype_ == PrecisionType::kPositive ? dlabels_[idx] : (1.0f - dlabels_[idx]));
}
private:
PrecisionType ptype_; // Precision type to be computed
uint32_t ngroups_; // Number of groups in the dataset
const float *dweights_; // Instance/group weights
const xgboost::common::Span<const uint32_t> dgidxs_; // The group a given instance belongs to
const float *dlabels_; // Unsorted labels in the dataset
};
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
// Sanity check is done by the caller
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
auto device = tparam_->gpu_id;
dh::safe_cuda(cudaSetDevice(device));
info.labels_.SetDevice(device);
preds.SetDevice(device);
info.weights_.SetDevice(device);
auto dpreds = preds.ConstDevicePointer();
auto dlabels = info.labels_.ConstDevicePointer();
auto dweights = info.weights_.ConstDevicePointer();
// Sort all the predictions
dh::SegmentSorter<float> segment_pred_sorter;
segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
const auto &dsorted_preds = segment_pred_sorter.GetItemsSpan();
// Original positions of the predictions after they have been sorted
const auto &dpreds_orig_pos = segment_pred_sorter.GetOriginalPositionsSpan();
// Group info on device
const auto &dgroups = segment_pred_sorter.GetGroupsSpan();
uint32_t ngroups = segment_pred_sorter.GetNumGroups();
const auto &dgroup_idx = segment_pred_sorter.GetGroupSegmentsSpan();
// First, aggregate the positive and negative precision for each group
dh::caching_device_vector<double> total_pos(ngroups, 0);
dh::caching_device_vector<double> total_neg(ngroups, 0);
// Allocator to be used for managing space overhead while performing transformed reductions
dh::XGBCachingDeviceAllocator<char> alloc;
// Compute each elements positive precision value and reduce them across groups concurrently.
ComputeItemPrecision pos_prec_functor(ComputeItemPrecision::PrecisionType::kPositive,
ngroups, dweights, dgroup_idx, dlabels);
auto end_range =
thrust::reduce_by_key(thrust::cuda::par(alloc),
dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
thrust::make_transform_iterator(
// The indices need not be sequential within a group, as we care only
// about the sum of positive precision values within a group
dh::tcbegin(segment_pred_sorter.GetOriginalPositionsSpan()),
pos_prec_functor),
thrust::make_discard_iterator(), // We don't care for the group indices
total_pos.begin()); // Sum of positive precision values in the group
CHECK(end_range.second - total_pos.begin() == total_pos.size());
// Compute each elements negative precision value and reduce them across groups concurrently.
ComputeItemPrecision neg_prec_functor(ComputeItemPrecision::PrecisionType::kNegative,
ngroups, dweights, dgroup_idx, dlabels);
end_range =
thrust::reduce_by_key(thrust::cuda::par(alloc),
dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx),
thrust::make_transform_iterator(
// The indices need not be sequential within a group, as we care only
// about the sum of negative precision values within a group
dh::tcbegin(segment_pred_sorter.GetOriginalPositionsSpan()),
neg_prec_functor),
thrust::make_discard_iterator(), // We don't care for the group indices
total_neg.begin()); // Sum of negative precision values in the group
CHECK(end_range.second - total_neg.begin() == total_neg.size());
const auto *dtotal_pos = total_pos.data().get();
const auto *dtotal_neg = total_neg.data().get();
// AUC sum for each group
dh::caching_device_vector<double> sum_auc(ngroups, 0);
// AUC error across all groups
dh::caching_device_vector<int> auc_error(1, 0);
auto *dsum_auc = sum_auc.data().get();
auto *dauc_error = auc_error.data().get();
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// For each group item compute the aggregated precision
dh::LaunchN<1, 32>(ngroups, nullptr, [=] __device__(uint32_t gidx) {
// We need pos > 0 && neg > 0
if (dtotal_pos[gidx] <= 0.0 || dtotal_neg[gidx] <= 0.0) {
atomicAdd(dauc_error, 1);
} else {
auto gbegin = dgroups[gidx];
auto gend = dgroups[gidx + 1];
// Calculate AUC
double tp = 0.0, prevtp = 0.0, fp = 0.0, prevfp = 0.0, h = 0.0, a = 0.0, b = 0.0;
for (auto i = gbegin; i < gend; ++i) {
const auto wt = dweights == nullptr ? 1.0f
: dweights[ngroups == 1 ? dpreds_orig_pos[i] : gidx];
tp += wt * dlabels[dpreds_orig_pos[i]];
fp += wt * (1.0f - dlabels[dpreds_orig_pos[i]]);
if ((i < gend - 1 && dsorted_preds[i] != dsorted_preds[i + 1]) || (i == gend - 1)) {
if (tp == prevtp) {
a = 1.0;
b = 0.0;
} else {
h = (fp - prevfp) / (tp - prevtp);
a = 1.0 + h;
b = (prevfp - h * prevtp) / dtotal_pos[gidx];
}
if (0.0 != b) {
dsum_auc[gidx] += (tp / dtotal_pos[gidx] - prevtp / dtotal_pos[gidx] -
b / a * (std::log(a * tp / dtotal_pos[gidx] + b) -
std::log(a * prevtp / dtotal_pos[gidx] + b))) / a;
} else {
dsum_auc[gidx] += (tp / dtotal_pos[gidx] - prevtp / dtotal_pos[gidx]) / a;
}
prevtp = tp;
prevfp = fp;
}
}
// Sanity check
if (tp < 0 || prevtp < 0 || fp < 0 || prevfp < 0) {
// Check if we have any metric error thus far
auto current_auc_error = atomicAdd(dauc_error, 0);
KERNEL_CHECK(!current_auc_error);
}
}
});
const auto hsum_auc = thrust::reduce(thrust::cuda::par(alloc), sum_auc.begin(), sum_auc.end());
const auto hauc_error = auc_error.back(); // Copy it back to host
// Report average AUC-PR across all groups
// In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC-PR.
bst_float dat[2] = {0.0f, 0.0f};
if (hauc_error < static_cast<int>(ngroups)) {
dat[0] = static_cast<bst_float>(hsum_auc);
dat[1] = static_cast<bst_float>(static_cast<int>(ngroups) - hauc_error);
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC-PR: the dataset only contains pos or neg samples";
CHECK_LE(dat[0], dat[1]) << "AUC-PR: AUC > 1.0";
return dat[0] / dat[1];
}
const char* Name() const override {
return "aucpr";
}
};
XGBOOST_REGISTER_GPU_METRIC(AucPRGpu, "aucpr")
.describe("Area under PR curve for rank computed on GPU.")
.set_body([](const char* param) { return new EvalAucPRGpu(); });
XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre")
.describe("precision@k for rank computed on GPU.")
.set_body([](const char* param) { return new EvalRankGpu<EvalPrecisionGpu>("pre", param); });

View File

@ -48,7 +48,7 @@ TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
0.5, 1e-10);
}
TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
TEST(Metric, DeclareUnifiedTest(MultiClassAUC)) {
auto tparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> uni_ptr{
Metric::Create("auc", &tparam)};
@ -64,6 +64,17 @@ TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
},
{0, 1, 2}),
1.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
0.0f, 1.0f, 0.0f, // p_1
0.0f, 0.0f, 1.0f // p_2
},
{0, 1, 2},
{1.0f, 1.0f, 1.0f}),
1.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
@ -72,6 +83,7 @@ TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
},
{2, 1, 0}),
0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
@ -139,5 +151,110 @@ TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
/*weights=*/{}, groups),
0.769841f, 1e-6);
}
TEST(Metric, DeclareUnifiedTest(PRAUC)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);
ASSERT_STREQ(metric->Name(), "aucpr");
EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
0.5f, 0.001f);
EXPECT_NEAR(GetMetricEval(
metric,
{0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
{0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
0.2908445f, 0.001f);
EXPECT_NEAR(GetMetricEval(
metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
{0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
0.2769199f, 0.001f);
auto auc = GetMetricEval(metric, {0, 1}, {});
ASSERT_TRUE(std::isnan(auc));
// AUCPR with instance weights
EXPECT_NEAR(GetMetricEval(metric,
{0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f},
{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0},
{1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f,
4.5f}), // weights
0.694435f, 0.001f);
// Both groups contain only pos or neg samples.
auc = GetMetricEval(metric,
{0, 0.1f, 0.3f, 0.5f, 0.7f},
{1, 1, 0, 0, 0},
{},
{0, 2, 5});
ASSERT_TRUE(std::isnan(auc));
delete metric;
}
TEST(Metric, DeclareUnifiedTest(MultiClassPRAUC)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> metric{Metric::Create("aucpr", &tparam)};
float auc = 0;
std::vector<float> labels {1.0f, 0.0f, 2.0f};
HostDeviceVector<float> predts{
0.0f, 1.0f, 0.0f,
1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f,
};
auc = GetMetricEval(metric.get(), predts, labels, {});
EXPECT_EQ(auc, 1.0f);
auc = GetMetricEval(metric.get(), predts, labels, {1.0f, 1.0f, 1.0f});
EXPECT_EQ(auc, 1.0f);
predts.HostVector() = {
0.0f, 1.0f, 0.0f,
1.0f, 0.0f, 0.0f,
0.0f, 0.0f, 1.0f,
0.0f, 0.0f, 1.0f,
};
labels = {1.0f, 0.0f, 2.0f, 1.0f};
auc = GetMetricEval(metric.get(), predts, labels, {1.0f, 2.0f, 3.0f, 4.0f});
ASSERT_GT(auc, 0.699);
}
TEST(Metric, DeclareUnifiedTest(RankingPRAUC)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> metric{Metric::Create("aucpr", &tparam)};
std::vector<float> labels {1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f};
std::vector<uint32_t> groups {0, 2, 6};
float auc = 0;
auc = GetMetricEval(metric.get(), {1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f}, labels, {}, groups);
EXPECT_EQ(auc, 1.0f);
auc = GetMetricEval(metric.get(), {1.0f, 0.5f, 0.8f, 0.3f, 0.2f, 1.0f}, labels, {}, groups);
EXPECT_EQ(auc, 1.0f);
auc = GetMetricEval(metric.get(), {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {}, groups);
ASSERT_TRUE(std::isnan(auc));
// Incorrect label
ASSERT_THROW(GetMetricEval(metric.get(), {1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f},
{1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 3.0f}, {}, groups),
dmlc::Error);
// AUCPR with groups and no weights
EXPECT_NEAR(GetMetricEval(
metric.get(), {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
{0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1},
{}, // weights
{0, 2, 5, 9, 14, 20}), // group info
0.556021f, 0.001f);
}
} // namespace metric
} // namespace xgboost

View File

@ -24,66 +24,6 @@ TEST(Metric, AMS) {
}
#endif
TEST(Metric, DeclareUnifiedTest(AUCPR)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);
ASSERT_STREQ(metric->Name(), "aucpr");
EXPECT_NEAR(GetMetricEval(metric, {0, 0, 1, 1}, {0, 0, 1, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}),
0.5f, 0.001f);
EXPECT_NEAR(
GetMetricEval(metric,
{0.4f, 0.2f, 0.9f, 0.1f, 0.2f, 0.4f, 0.1f, 0.1f, 0.2f, 0.1f},
{0, 0, 0, 0, 0, 1, 0, 0, 1, 1}),
0.2908445f, 0.001f);
EXPECT_NEAR(GetMetricEval(
metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
{0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1}),
0.2769199f, 0.001f);
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {1, 1}));
// AUCPR with instance weights
EXPECT_NEAR(GetMetricEval(
metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f},
{0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0},
{1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f}), // weights
0.694435f, 0.001f);
// AUCPR with groups and no weights
EXPECT_NEAR(GetMetricEval(
metric, {0.87f, 0.31f, 0.40f, 0.42f, 0.25f, 0.66f, 0.95f,
0.09f, 0.10f, 0.97f, 0.76f, 0.69f, 0.15f, 0.20f,
0.30f, 0.14f, 0.07f, 0.58f, 0.61f, 0.08f},
{0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1},
{}, // weights
{0, 2, 5, 9, 14, 20}), // group info
0.556021f, 0.001f);
// AUCPR with groups and weights
EXPECT_NEAR(GetMetricEval(
metric, {0.29f, 0.52f, 0.11f, 0.21f, 0.219f, 0.93f, 0.493f,
0.17f, 0.47f, 0.13f, 0.43f, 0.59f, 0.87f, 0.007f}, // predictions
{0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0},
{1, 2, 7, 4, 5, 2.2f, 3.2f, 5, 6, 1, 2, 1.1f, 3.2f, 4.5f}, // weights
{0, 2, 5, 9, 14}), // group info
0.8150615f, 0.001f);
// Exception scenarios for grouped datasets
EXPECT_ANY_THROW(GetMetricEval(metric,
{0, 0.1f, 0.3f, 0.5f, 0.7f},
{1, 1, 0, 0, 0},
{},
{0, 2, 5}));
delete metric;
}
TEST(Metric, DeclareUnifiedTest(Precision)) {
// When the limit for precision is not given, it takes the limit at
// std::numeric_limits<unsigned>::max(); hence all values are very small

View File

@ -47,3 +47,12 @@ class TestGPUEvalMetrics:
gpu_auc = float(gpu.eval(Xy).split(":")[1])
np.testing.assert_allclose(cpu_auc, gpu_auc)
def test_pr_auc_binary(self):
self.cpu_test.run_pr_auc_binary("gpu_hist")
def test_pr_auc_multi(self):
self.cpu_test.run_pr_auc_multi("gpu_hist")
def test_pr_auc_ltr(self):
self.cpu_test.run_pr_auc_ltr("gpu_hist")

View File

@ -239,6 +239,7 @@ class TestEvalMetrics:
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X, weight=weights))
skl_auc = roc_auc_score(
y, score, average="weighted", sample_weight=weights, multi_class="ovr"
@ -251,3 +252,63 @@ class TestEvalMetrics:
)
def test_roc_auc_multi(self, n_samples, weighted):
self.run_roc_auc_multi("hist", n_samples, weighted)
def run_pr_auc_binary(self, tree_method):
from sklearn.metrics import precision_recall_curve, auc
from sklearn.datasets import make_classification
X, y = make_classification(128, 4, n_classes=2, random_state=1994)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
y_score = clf.predict_proba(X)[:, 1] # get the positive column
precision, recall, _ = precision_recall_curve(y, y_score)
prauc = auc(recall, precision)
# Interpolation results are slightly different from sklearn, but overall should be
# similar.
np.testing.assert_allclose(prauc, evals_result, rtol=1e-2)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(0.99, evals_result, rtol=1e-2)
def test_pr_auc_binary(self):
self.run_pr_auc_binary("hist")
def run_pr_auc_multi(self, tree_method):
from sklearn.datasets import make_classification
X, y = make_classification(
64, 16, n_informative=8, n_classes=3, random_state=1994
)
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=1)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
# No available implementation for comparison, just check that XGBoost converges to
# 1.0
clf = xgb.XGBClassifier(tree_method=tree_method, n_estimators=10)
clf.fit(X, y, eval_metric="aucpr", eval_set=[(X, y)])
evals_result = clf.evals_result()["validation_0"]["aucpr"][-1]
np.testing.assert_allclose(1.0, evals_result, rtol=1e-2)
def test_pr_auc_multi(self):
self.run_pr_auc_multi("hist")
def run_pr_auc_ltr(self, tree_method):
from sklearn.datasets import make_classification
X, y = make_classification(128, 4, n_classes=2, random_state=1994)
ltr = xgb.XGBRanker(tree_method=tree_method, n_estimators=16)
groups = np.array([32, 32, 64])
ltr.fit(
X,
y,
group=groups,
eval_set=[(X, y)],
eval_group=[groups],
eval_metric="aucpr"
)
results = ltr.evals_result()["validation_0"]["aucpr"]
assert results[-1] >= 0.99
def test_pr_auc_ltr(self):
self.run_pr_auc_ltr("hist")

View File

@ -587,7 +587,7 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
cls = xgb.dask.DaskXGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
# multiclass
X_, y_ = make_classification(
@ -618,7 +618,7 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
cls = xgb.dask.DaskXGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
cls.fit(X, y, eval_metric=["auc", "aucpr"], eval_set=[(valid_X, valid_y)])
def test_empty_dmatrix_auc() -> None: