Re-implement ROC-AUC. (#6747)

* Re-implement ROC-AUC.

* Binary
* MultiClass
* LTR
* Add documents.

This PR resolves a few issues:
  - Define a value when the dataset is invalid, which can happen if the dataset is
  empty or contains only positive or only negative samples.
  - Define ROC-AUC for multi-class classification.
  - Define a weighted average value for the distributed setting.
  - A correct implementation for the learning-to-rank task.  The previous
  implementation was just binary classification with averaging across groups,
  which doesn't measure how well documents are ordered within a group (see the
  sketch below).
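An editor's sketch of the last point, derived from the `GroupRankingAUC` / `RankingAUC` code added in this diff (the notation is illustrative and not part of the original commit message): with the documents of one query group indexed in order of decreasing relevance label and with predicted scores s, the per-group value is the fraction of correctly ordered pairs,

$$
\mathrm{AUC}_{\mathrm{group}}
  = \frac{\sum_{i<j}\bigl[\mathbf{1}(s_i > s_j) + \tfrac{1}{2}\,\mathbf{1}(s_i = s_j)\bigr]}{\binom{n}{2}},
$$

and the reported metric is the plain average of this value over all groups with at least three documents (smaller groups are counted as invalid).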
Jiaming Yuan 2021-03-20 16:52:40 +08:00 committed by GitHub
parent 4ee8340e79
commit bcc0277338
27 changed files with 1622 additions and 461 deletions

@@ -14,6 +14,7 @@
#include "../src/metric/elementwise_metric.cc"
#include "../src/metric/multiclass_metric.cc"
#include "../src/metric/rank_metric.cc"
#include "../src/metric/auc.cc"
#include "../src/metric/survival_metric.cc"
// objectives


@@ -400,7 +400,15 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``error@t``: a different than 0.5 binary classification threshold value could be specified by providing a numerical value through 't'.
- ``merror``: Multiclass classification error rate. It is calculated as ``#(wrong cases)/#(all cases)``.
- ``mlogloss``: `Multiclass logloss <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.log_loss.html>`_.
- ``auc``: `Receiver Operating Characteristic Area under the Curve <http://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_curve>`_.
  Available for classification and learning-to-rank tasks.

  - When used with binary classification, the objective should be ``binary:logistic`` or a similar function that works on probability.
  - When used with multi-class classification, the objective should be ``multi:softprob`` instead of ``multi:softmax``, as the latter doesn't output probabilities. The AUC is calculated by 1-vs-rest, with the reference class weighted by its prevalence (see the sketch after this list).
  - When used with LTR tasks, the AUC is computed by comparing pairs of documents and counting correctly sorted pairs. This corresponds to pairwise learning to rank. The implementation has some issues with the average AUC across groups and distributed workers not being well defined.
  - On a single machine the AUC calculation is exact. In a distributed environment the AUC is a weighted average over the AUC of training rows on each node; distributed AUC is therefore an approximation that is sensitive to the distribution of data across workers (see the sketch after this list). Use another metric in distributed environments if precision and reproducibility are important.
  - If the input dataset contains only negative or only positive samples, the output is `NaN`.

- ``aucpr``: `Area under the PR curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_. Available for binary classification and learning-to-rank tasks.
- ``ndcg``: `Normalized Discounted Cumulative Gain <http://en.wikipedia.org/wiki/NDCG>`_
- ``map``: `Mean Average Precision <http://en.wikipedia.org/wiki/Mean_average_precision#Mean_average_precision>`_
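An editor's sketch of the two averaging rules mentioned above, written out from the ``EvalAUC`` and ``MultiClassOVR`` code added later in this diff (the notation is illustrative and not part of the documentation): each worker :math:`k` contributes its unnormalized area :math:`auc_k` together with the covered area :math:`FP_k \cdot TP_k`, and each class :math:`c` contributes its normalized one-vs-rest AUC weighted by its (weighted) positive count :math:`TP_c`; in the distributed case the per-class quantities are summed across workers first.

.. math::

   \mathrm{AUC}_{\mathrm{binary}} = \frac{\sum_k auc_k}{\sum_k FP_k \cdot TP_k},
   \qquad
   \mathrm{AUC}_{\mathrm{multi}} = \frac{\sum_c \dfrac{auc_c}{FP_c \cdot TP_c}\, TP_c}{\sum_c TP_c}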


@@ -8,6 +8,7 @@
#include <xgboost/base.h>
#include <xgboost/logging.h>
#include <xgboost/span.h>
#include <algorithm>
#include <exception>
@@ -163,13 +164,14 @@ inline void AssertOneAPISupport() {
#endif  // XGBOOST_USE_ONEAPI
}
template <typename Idx, typename Container,
typename V = typename Container::value_type,
typename Comp = std::less<V>>
std::vector<Idx> ArgSort(Container const &array, Comp comp = std::less<V>{}) {
std::vector<Idx> result(array.size());
std::iota(result.begin(), result.end(), 0);
auto op = [&array, comp](Idx const &l, Idx const &r) { return comp(array[l], array[r]); };
XGBOOST_PARALLEL_STABLE_SORT(result.begin(), result.end(), op);
return result;
}
}  // namespace common
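For illustration, a minimal standalone sketch of the ``ArgSort`` contract above, using plain ``std::stable_sort`` instead of ``XGBOOST_PARALLEL_STABLE_SORT``; the names in this sketch are hypothetical and it is not the xgboost helper itself:

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Returns the indices that order `array` according to `comp`, like common::ArgSort.
    template <typename Idx, typename Container,
              typename V = typename Container::value_type, typename Comp = std::less<V>>
    std::vector<Idx> ArgSortSketch(Container const &array, Comp comp = Comp{}) {
      std::vector<Idx> result(array.size());
      std::iota(result.begin(), result.end(), 0);
      std::stable_sort(result.begin(), result.end(),
                       [&](Idx l, Idx r) { return comp(array[l], array[r]); });
      return result;
    }

    int main() {
      std::vector<float> v{0.3f, 0.9f, 0.1f};
      auto idx = ArgSortSketch<std::size_t>(v, std::greater<>{});
      // idx == {1, 0, 2}: positions of v visited in descending order of value.
      return idx[0] == 1 && idx[1] == 0 && idx[2] == 2 ? 0 : 1;
    }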


@@ -1198,6 +1198,62 @@ size_t SegmentedUnique(Inputs &&...inputs) {
return SegmentedUnique(thrust::cuda::par(alloc), std::forward<Inputs&&>(inputs)...);
}
/**
 * \brief Unique by key for many groups of data. Has the same constraint as `SegmentedUnique`.
 *
 * \param exec thrust execution policy
 * \param key_segments_first start iterator to the segment pointer
 * \param key_segments_last end iterator to the segment pointer
 * \param key_first start iterator to keys used for comparison
 * \param key_last end iterator to keys used for comparison
 * \param val_first start iterator to values
 * \param key_segments_out output iterator for the new segment pointer
 * \param val_out output iterator for values
 * \param comp binary comparison operator
 */
template <typename DerivedPolicy, typename SegInIt, typename SegOutIt,
typename KeyInIt, typename ValInIt, typename ValOutIt, typename Comp>
size_t SegmentedUniqueByKey(
const thrust::detail::execution_policy_base<DerivedPolicy> &exec,
SegInIt key_segments_first, SegInIt key_segments_last, KeyInIt key_first,
KeyInIt key_last, ValInIt val_first, SegOutIt key_segments_out,
ValOutIt val_out, Comp comp) {
using Key =
thrust::pair<size_t,
typename thrust::iterator_traits<KeyInIt>::value_type>;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
[=] __device__(size_t i) {
size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i);
return thrust::make_pair(seg, *(key_first + i));
});
size_t segments_len = key_segments_last - key_segments_first;
thrust::fill(thrust::device, key_segments_out,
key_segments_out + segments_len, 0);
size_t n_inputs = std::distance(key_first, key_last);
// Reduce the number of unique elements per segment, avoiding the creation of an
// intermediate array for `reduce_by_key`. It's limited by the types that
// atomicAdd supports. For example, size_t is not supported as of CUDA 10.2.
auto reduce_it = thrust::make_transform_output_iterator(
thrust::make_discard_iterator(),
detail::SegmentedUniqueReduceOp<Key, SegOutIt>{key_segments_out});
auto uniques_ret = thrust::unique_by_key_copy(
exec, unique_key_it, unique_key_it + n_inputs, val_first, reduce_it,
val_out, [=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
// In the same segment.
return comp(thrust::get<1>(l), thrust::get<1>(r));
}
return false;
});
auto n_uniques = uniques_ret.second - val_out;
CHECK_LE(n_uniques, n_inputs);
thrust::exclusive_scan(exec, key_segments_out,
key_segments_out + segments_len, key_segments_out, 0);
return n_uniques;
}
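// Editor's note, a worked example of the contract above: with segment pointer
// {0, 3, 6}, keys {0.9, 0.9, 0.3, 0.9, 0.5, 0.5} and values {0, 1, 2, 3, 4, 5},
// SegmentedUniqueByKey keeps the value of the first occurrence of every distinct key
// within each segment, giving values {0, 2, 3, 4} (n_uniques = 4) and, after the
// exclusive scan, the new CSR-style segment pointer {0, 2, 4}. A key repeated across a
// segment boundary (0.9 here) is kept in both segments.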
template <typename Policy, typename InputIt, typename Init, typename Func>
auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce_op) {
size_t constexpr kLimit = std::numeric_limits<int32_t>::max() / 2;
@@ -1215,36 +1271,73 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
return aggregate;
}
// wrapper to avoid integer `num_items`.
template <typename InputIteratorT, typename OutputIteratorT, typename ScanOpT,
typename OffsetT>
void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
OffsetT num_items) {
size_t bytes = 0;
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(nullptr, bytes, d_in, d_out, scan_op,
cub::NullType(), num_items, nullptr,
false)));
dh::TemporaryArray<char> storage(bytes);
safe_cuda((
cub::DispatchScan<InputIteratorT, OutputIteratorT, ScanOpT, cub::NullType,
OffsetT>::Dispatch(storage.data().get(), bytes, d_in,
d_out, scan_op, cub::NullType(),
num_items, nullptr, false)));
}
template <typename InputIteratorT, typename OutputIteratorT, typename OffsetT>
void InclusiveSum(InputIteratorT d_in, OutputIteratorT d_out, OffsetT num_items) {
InclusiveScan(d_in, d_out, cub::Sum(), num_items);
}
template <bool accending, typename IdxT, typename U>
void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_idx) {
size_t bytes = 0;
Iota(sorted_idx);
using KeyT = typename decltype(keys)::value_type;
using ValueT = std::remove_const_t<IdxT>;
TemporaryArray<KeyT> out(keys.size());
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get());
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx.data());
if (accending) {
void *d_temp_storage = nullptr;
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
cub::DispatchRadixSort<false, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false);
} else {
void *d_temp_storage = nullptr;
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
dh::TemporaryArray<char> storage(bytes);
d_temp_storage = storage.data().get();
safe_cuda((cub::DispatchRadixSort<true, KeyT, ValueT, size_t>::Dispatch(
d_temp_storage, bytes, d_keys, d_values, sorted_idx.size(), 0,
sizeof(KeyT) * 8, false, nullptr, false)));
}
}
namespace detail {
// Wrapper around cub sort for easier `descending` sort and `size_t num_items`.
template <bool descending, typename KeyT, typename ValueT,
typename OffsetIteratorT>
void DeviceSegmentedRadixSortPair(
void *d_temp_storage, size_t &temp_storage_bytes, const KeyT *d_keys_in,  // NOLINT
KeyT *d_keys_out, const ValueT *d_values_in, ValueT *d_values_out,
size_t num_items, size_t num_segments, OffsetIteratorT d_begin_offsets,
OffsetIteratorT d_end_offsets, int begin_bit = 0,
@@ -1253,12 +1346,12 @@ void DeviceSegmentedRadixSortPair(
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in),
d_values_out);
using OffsetT = size_t;
safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, OffsetIteratorT,
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys,
d_values, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, nullptr, false)));
}
}  // namespace detail
@@ -1270,12 +1363,11 @@ void SegmentedArgSort(xgboost::common::Span<U> values,
size_t n_groups = group_ptr.size() - 1;
size_t bytes = 0;
Iota(sorted_idx);
TemporaryArray<std::remove_const_t<U>> values_out(values.size());
detail::DeviceSegmentedRadixSortPair<!accending>(
nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
sorted_idx.data(), sorted_idx.size(), n_groups, group_ptr.data(),
group_ptr.data() + 1);
dh::TemporaryArray<xgboost::common::byte> temp_storage(bytes);
detail::DeviceSegmentedRadixSortPair<!accending>(
temp_storage.data().get(), bytes, values.data(), values_out.data().get(),


@@ -26,6 +26,9 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return 1.0f / (1.0f + expf(-x));
}
template <typename T>
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
/*!
* \brief Equality test for both integer and floating point.
*/


@@ -99,7 +99,7 @@ std::vector<T> WeightedSamplingWithoutReplacement(
auto k = std::log(u) / w;
keys[i] = k;
}
auto ind = ArgSort<size_t>(Span<float>{keys}, std::greater<>{});
ind.resize(n);
std::vector<T> results(ind.size());


@@ -0,0 +1,84 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
#define XGBOOST_COMMON_RANKING_UTILS_H_
#include <cub/cub.cuh>
#include "xgboost/base.h"
#include "device_helpers.cuh"
#include "./math.h"
namespace xgboost {
namespace common {
/**
* \param n Number of items (length of the base)
* \param h height
*/
XGBOOST_DEVICE inline size_t DiscreteTrapezoidArea(size_t n, size_t h) {
n -= 1; // without diagonal entries
h = std::min(n, h); // Specific for ranking.
size_t total = ((n - (h - 1)) + n) * h / 2;
return total;
}
/**
* Used for mapping many groups of trapezoid-shaped computation onto CUDA blocks. The
* trapezoid must be in the upper right corner.
*
* Equivalent to loops like:
*
* \code
* for (size_t i = 0; i < h; ++i) {
* for (size_t j = i + 1; j < n; ++j) {
* do_something();
* }
* }
* \endcode
*/
template <typename U>
inline size_t
SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
xgboost::common::Span<size_t> out_group_threads_ptr,
size_t h) {
CHECK_GE(group_ptr.size(), 1);
CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
dh::LaunchN(
dh::CurrentDevice(), group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) {
if (idx == 0) {
out_group_threads_ptr[0] = 0;
return;
}
size_t cnt = static_cast<size_t>(group_ptr[idx] - group_ptr[idx - 1]);
out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
});
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
size_t total = 0;
dh::safe_cuda(cudaMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), cudaMemcpyDeviceToHost));
return total;
}
/**
* Called inside kernel to obtain coordinate from trapezoid grid.
*/
XGBOOST_DEVICE inline void UnravelTrapeziodIdx(size_t i_idx, size_t n,
size_t *out_i, size_t *out_j) {
auto &i = *out_i;
auto &j = *out_j;
double idx = static_cast<double>(i_idx);
double N = static_cast<double>(n);
i = std::ceil(-(0.5 - N + std::sqrt(common::Sqr(N - 0.5) + 2.0 * (-idx - 1.0)))) - 1.0;
auto I = static_cast<double>(i);
size_t n_elems = -0.5 * common::Sqr(I) + (N - 0.5) * I;
j = idx - n_elems + i + 1;
}
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
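A standalone CPU sketch of the mapping above, restricted to the full-triangle case (``h >= n - 1``); it re-derives ``DiscreteTrapezoidArea`` and ``UnravelTrapeziodIdx`` for illustration and checks them against the nested loop they replace (hypothetical test code, not part of the library):

    #include <cassert>
    #include <cmath>
    #include <cstddef>

    // Number of (i, j) pairs with i < j among n items: the full upper triangle.
    static std::size_t TriangleArea(std::size_t n) { return n * (n - 1) / 2; }

    // Closed-form inverse: recover (i, j) from the flat pair index, as the CUDA kernel does.
    static void Unravel(std::size_t idx, std::size_t n, std::size_t *i, std::size_t *j) {
      double N = static_cast<double>(n);
      double x = static_cast<double>(idx);
      *i = static_cast<std::size_t>(
          std::ceil(-(0.5 - N + std::sqrt((N - 0.5) * (N - 0.5) + 2.0 * (-x - 1.0)))) - 1.0);
      double I = static_cast<double>(*i);
      double before = -0.5 * I * I + (N - 0.5) * I;  // pairs covered by rows 0 .. i-1
      *j = static_cast<std::size_t>(x - before + I + 1.0);
    }

    int main() {
      std::size_t n = 6;
      std::size_t flat = 0;
      for (std::size_t i = 0; i < n; ++i) {    // the nested loop the kernel flattens
        for (std::size_t j = i + 1; j < n; ++j) {
          std::size_t ui, uj;
          Unravel(flat, n, &ui, &uj);
          assert(ui == i && uj == j);
          ++flat;
        }
      }
      assert(flat == TriangleArea(n));
      return 0;
    }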


@@ -400,7 +400,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
group_ptr_.push_back(i);
}
}
if (group_ptr_.back() != query_ids.size()) {
group_ptr_.push_back(query_ids.size());
}
} else if (!std::strcmp(key, "label_lower_bound")) {
auto& labels = labels_lower_bound_.HostVector();
labels.resize(num);
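For context, a small standalone sketch (editor's illustration, not the MetaInfo code itself) of the ``group_ptr_`` layout this branch builds from a ``qid`` column, including the guard added above, which only appends the final offset when the last group is still open:

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    int main() {
      std::vector<uint32_t> query_ids{0, 0, 1, 1, 1, 7};  // one qid per row, already grouped
      std::vector<std::size_t> group_ptr{0};
      for (std::size_t i = 1; i < query_ids.size(); ++i) {
        if (query_ids[i] != query_ids[i - 1]) {
          group_ptr.push_back(i);  // a new query group starts at row i
        }
      }
      // Guard from the diff: close the last group only if it has not been closed already,
      // so the trailing offset is never pushed twice.
      if (group_ptr.back() != query_ids.size()) {
        group_ptr.push_back(query_ids.size());
      }
      // CSR-style pointer: three groups covering rows [0,2), [2,5) and [5,6).
      assert((group_ptr == std::vector<std::size_t>{0, 2, 5, 6}));
      return 0;
    }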

src/metric/auc.cc (new file, 340 lines)

@@ -0,0 +1,340 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <array>
#include <atomic>
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include <tuple>
#include <vector>
#include "rabit/rabit.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/metric.h"
#include "auc.h"
#include "../common/common.h"
#include "../common/math.h"
namespace xgboost {
namespace metric {
namespace detail {
template <class T, std::size_t N, std::size_t... Idx>
constexpr auto UnpackArr(std::array<T, N> &&arr, std::index_sequence<Idx...>) {
return std::make_tuple(std::forward<std::array<T, N>>(arr)[Idx]...);
}
} // namespace detail
template <class T, std::size_t N>
constexpr auto UnpackArr(std::array<T, N> &&arr) {
return detail::UnpackArr(std::forward<std::array<T, N>>(arr),
std::make_index_sequence<N>{});
}
/**
* Calculate AUC for binary classification problem. This function does not normalize the
* AUC by 1 / (num_positive * num_negative); instead it returns a tuple for the caller to
* handle the normalization.
*/
std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
std::vector<float> const &labels,
std::vector<float> const &weights) {
CHECK(!labels.empty());
CHECK_EQ(labels.size(), predts.size());
float auc {0};
auto const sorted_idx = common::ArgSort<size_t>(
common::Span<float const>(predts), std::greater<>{});
auto get_weight = [&](size_t i) {
return weights.empty() ? 1.0f : weights[sorted_idx[i]];
};
float label = labels[sorted_idx.front()];
float w = get_weight(0);
float fp = (1.0 - label) * w, tp = label * w;
float tp_prev = 0, fp_prev = 0;
// TODO(jiaming): We can parallelize this if we have a parallel scan for CPU.
for (size_t i = 1; i < sorted_idx.size(); ++i) {
if (predts[sorted_idx[i]] != predts[sorted_idx[i-1]]) {
auc += TrapesoidArea(fp_prev, fp, tp_prev, tp);
tp_prev = tp;
fp_prev = fp;
}
label = labels[sorted_idx[i]];
float w = get_weight(i);
fp += (1.0f - label) * w;
tp += label * w;
}
auc += TrapesoidArea(fp_prev, fp, tp_prev, tp);
if (fp <= 0.0f || tp <= 0.0f) {
auc = 0;
fp = 0;
tp = 0;
}
return std::make_tuple(fp, tp, auc);
}
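// Worked example (editor's note, not part of the original source): with unit weights,
// predts = {0.9, 0.8, 0.7, 0.3} and labels = {1, 0, 1, 0}, the loop above walks the
// predictions in descending order and accumulates (fp, tp) = (0,1) -> (1,1) -> (1,2)
// -> (2,2), adding a trapezoid whenever the prediction value changes:
// 0 + 1 + 0 + 2 = 3 unnormalized, with fp = 2 and tp = 2. The caller divides by
// fp * tp = 4, giving AUC = 0.75, i.e. 3 of the 4 positive/negative pairs are ordered
// correctly.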
/**
* Calculate AUC for multi-class classification problem using 1-vs-rest approach.
*
* TODO(jiaming): Use better algorithms like:
*
* - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class
* Machine Learning Models
*/
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info) {
auto n_classes = predts.size() / info.labels_.Size();
CHECK_NE(n_classes, 0);
auto const& labels = info.labels_.ConstHostVector();
std::vector<float> results(n_classes * 3, 0);
auto s_results = common::Span<float>(results);
auto local_area = s_results.subspan(0, n_classes);
auto tp = s_results.subspan(n_classes, n_classes);
auto auc = s_results.subspan(2 * n_classes, n_classes);
if (!info.labels_.Empty()) {
dmlc::OMPException omp_handler;
#pragma omp parallel for
for (omp_ulong c = 0; c < n_classes; ++c) {
omp_handler.Run([&]() {
std::vector<float> proba(info.labels_.Size());
std::vector<float> response(info.labels_.Size());
for (size_t i = 0; i < proba.size(); ++i) {
proba[i] = predts[i * n_classes + c];
response[i] = labels[i] == c ? 1.0f : 0.0;
}
float fp;
std::tie(fp, tp[c], auc[c]) =
BinaryAUC(proba, response, info.weights_.ConstHostVector());
local_area[c] = fp * tp[c];
});
}
omp_handler.Rethrow();
}
// we have 2 averages going in here, first is among workers, second is among classes.
// allreduce sums up fp/tp auc for each class.
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
float auc_sum{0};
float tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
if (local_area[c] != 0) {
// normalize and weight it by prevalence. After allreduce, `local_area` means the
// total covered area (not area under curve, rather it's the accessible area for each
// worker) for each class.
auc_sum += auc[c] / local_area[c] * tp[c];
tp_sum += tp[c];
} else {
auc_sum = std::numeric_limits<float>::quiet_NaN();
break;
}
}
if (tp_sum == 0 || std::isnan(auc_sum)) {
auc_sum = std::numeric_limits<float>::quiet_NaN();
} else {
auc_sum /= tp_sum;
}
return auc_sum;
}
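// Worked example (editor's note): with two classes where class 0 ends up with
// auc = 6, local_area = fp * tp = 2 * 4 = 8, tp = 4 and class 1 with auc = 1,
// local_area = 4 * 2 = 8, tp = 2, the loop above computes
// auc_sum = 6/8 * 4 + 1/8 * 2 = 3.25 and tp_sum = 6, so the reported value is
// 3.25 / 6 ~= 0.54: each class contributes its normalized AUC weighted by its
// prevalence.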
/**
* Calculate AUC for 1 ranking group;
*/
float GroupRankingAUC(common::Span<float const> predts,
common::Span<float const> labels, float w) {
// on ranking, we just count all pairs.
float auc{0};
auto const sorted_idx = common::ArgSort<size_t>(labels, std::greater<>{});
w = common::Sqr(w);
float sum_w = 0.0f;
for (size_t i = 0; i < labels.size(); ++i) {
for (size_t j = i + 1; j < labels.size(); ++j) {
auto predt = predts[sorted_idx[i]] - predts[sorted_idx[j]];
if (predt > 0) {
predt = 1.0;
} else if (predt == 0) {
predt = 0.5;
} else {
predt = 0;
}
auc += predt * w;
sum_w += w;
}
}
if (sum_w != 0) {
auc /= sum_w;
}
CHECK_LE(auc, 1.0f);
return auc;
}
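// Worked example (editor's note): for a group with labels = {2, 1, 0} and
// predts = {0.6, 0.9, 0.2}, the label-descending order is (0, 1, 2) and the three
// pairs contribute 0 (0.6 < 0.9), 1 (0.6 > 0.2) and 1 (0.9 > 0.2), so the group AUC
// is 2/3: the fraction of document pairs ordered consistently with their labels,
// with prediction ties counted as 0.5.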
/**
* Cast LTR problem to binary classification problem by comparing pairs.
*/
std::pair<float, uint32_t> RankingAUC(std::vector<float> const &predts,
MetaInfo const &info) {
CHECK_GE(info.group_ptr_.size(), 2);
uint32_t n_groups = info.group_ptr_.size() - 1;
float sum_auc = 0;
auto s_predts = common::Span<float const>{predts};
auto s_labels = info.labels_.ConstHostSpan();
auto s_weights = info.weights_.ConstHostSpan();
std::atomic<uint32_t> invalid_groups{0};
dmlc::OMPException omp_handler;
#pragma omp parallel for reduction(+:sum_auc)
for (omp_ulong g = 1; g < info.group_ptr_.size(); ++g) {
omp_handler.Run([&]() {
size_t cnt = info.group_ptr_[g] - info.group_ptr_[g - 1];
float w = s_weights.empty() ? 1.0f : s_weights[g - 1];
auto g_predts = s_predts.subspan(info.group_ptr_[g - 1], cnt);
auto g_labels = s_labels.subspan(info.group_ptr_[g - 1], cnt);
float auc;
if (g_labels.size() < 3) {
// With fewer than 3 documents, at most 1 comparison can be made, so either
// TP or FP will be zero.
invalid_groups++;
auc = 0;
} else {
auc = GroupRankingAUC(g_predts, g_labels, w);
}
sum_auc += auc;
});
}
omp_handler.Rethrow();
if (invalid_groups != 0) {
InvalidGroupAUC();
}
return std::make_pair(sum_auc, n_groups - invalid_groups);
}
class EvalAUC : public Metric {
std::shared_ptr<DeviceAUCCache> d_cache_;
public:
float Eval(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
bool distributed) override {
float auc {0};
if (tparam_->gpu_id != GenericParameter::kCpuId) {
preds.SetDevice(tparam_->gpu_id);
info.labels_.SetDevice(tparam_->gpu_id);
info.weights_.SetDevice(tparam_->gpu_id);
}
if (!info.group_ptr_.empty()) {
/**
* learning to rank
*/
if (!info.weights_.Empty()) {
CHECK_EQ(info.weights_.Size(), info.group_ptr_.size() - 1);
}
uint32_t valid_groups = 0;
if (!info.labels_.Empty()) {
CHECK_EQ(info.group_ptr_.back(), info.labels_.Size());
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(auc, valid_groups) =
RankingAUC(preds.ConstHostVector(), info);
} else {
std::tie(auc, valid_groups) = GPURankingAUC(
preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
}
}
std::array<float, 2> results{auc, static_cast<float>(valid_groups)};
rabit::Allreduce<rabit::op::Sum>(results.data(), results.size());
auc = results[0];
valid_groups = static_cast<uint32_t>(results[1]);
if (valid_groups <= 0) {
auc = std::numeric_limits<float>::quiet_NaN();
} else {
auc /= valid_groups;
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups;
}
} else if (info.labels_.Size() != preds.Size() &&
preds.Size() % info.labels_.Size() == 0) {
/**
* multi class
*/
if (tparam_->gpu_id == GenericParameter::kCpuId) {
auc = MultiClassOVR(preds.ConstHostVector(), info);
} else {
auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id,
&this->d_cache_);
}
} else {
/**
* binary classification
*/
float fp{0}, tp{0};
if (!(preds.Empty() || info.labels_.Empty())) {
if (tparam_->gpu_id == GenericParameter::kCpuId) {
std::tie(fp, tp, auc) =
BinaryAUC(preds.ConstHostVector(), info.labels_.ConstHostVector(),
info.weights_.ConstHostVector());
} else {
std::tie(fp, tp, auc) = GPUBinaryAUC(
preds.ConstDeviceSpan(), info, tparam_->gpu_id, &this->d_cache_);
}
}
float local_area = fp * tp;
std::array<float, 2> result{auc, local_area};
rabit::Allreduce<rabit::op::Sum>(result.data(), result.size());
std::tie(auc, local_area) = UnpackArr(std::move(result));
if (local_area <= 0) {
// the dataset across all workers has only positive or only negative samples
auc = std::numeric_limits<float>::quiet_NaN();
} else {
// normalization
auc = auc / local_area;
}
}
if (std::isnan(auc)) {
LOG(WARNING) << "Dataset contains only positive or negative samples.";
}
return auc;
}
char const* Name() const override {
return "auc";
}
};
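// Editor's note on the dispatch in Eval above: the task is treated as learning to rank
// when group_ptr_ is non-empty, as multi-class classification when the prediction size
// is a larger multiple of the label size (one probability column per class), and as
// binary classification otherwise; each branch then takes the GPU path when gpu_id is
// set and the CPU path otherwise.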
XGBOOST_REGISTER_METRIC(EvalBinaryAUC, "auc")
.describe("Receiver Operating Characteristic Area Under the Curve.")
.set_body([](const char*) { return new EvalAUC(); });
#if !defined(XGBOOST_USE_CUDA)
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport();
return std::make_tuple(0.0f, 0.0f, 0.0f);
}
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* cache) {
common::AssertGPUSupport();
return 0;
}
std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
common::AssertGPUSupport();
return std::make_pair(0.0f, 0u);
}
struct DeviceAUCCache {};
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace metric
} // namespace xgboost

src/metric/auc.cu (new file, 540 lines)

@@ -0,0 +1,540 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#include <thrust/scan.h>
#include <cub/cub.cuh>
#include <cassert>
#include <limits>
#include <memory>
#include <utility>
#include <tuple>
#include "rabit/rabit.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
#include "auc.h"
#include "../common/device_helpers.cuh"
#include "../common/ranking_utils.cuh"
namespace xgboost {
namespace metric {
namespace {
template <typename T>
class Discard : public thrust::discard_iterator<T> {
public:
using value_type = T; // NOLINT
};
struct GetWeightOp {
common::Span<float const> weights;
common::Span<size_t const> sorted_idx;
__device__ float operator()(size_t i) const {
return weights.empty() ? 1.0f : weights[sorted_idx[i]];
}
};
} // namespace
/**
* A cache to GPU data to avoid reallocating memory.
*/
struct DeviceAUCCache {
// Pair of FP/TP
using Pair = thrust::pair<float, float>;
// index sorted by prediction value
dh::device_vector<size_t> sorted_idx;
// track FP/TP for computation on trapezoid area
dh::device_vector<Pair> fptp;
// track FP_PREV/TP_PREV for computation on trapezoid area
dh::device_vector<Pair> neg_pos;
// index of unique prediction values.
dh::device_vector<size_t> unique_idx;
// p^T: transposed prediction matrix, used by MultiClassAUC
dh::device_vector<float> predts_t;
std::unique_ptr<dh::AllReducer> reducer;
void Init(common::Span<float const> predts, bool is_multi, int32_t device) {
if (sorted_idx.size() != predts.size()) {
sorted_idx.resize(predts.size());
fptp.resize(sorted_idx.size());
unique_idx.resize(sorted_idx.size());
neg_pos.resize(sorted_idx.size());
if (is_multi) {
predts_t.resize(sorted_idx.size());
reducer.reset(new dh::AllReducer);
reducer->Init(rabit::GetRank());
}
}
}
};
/**
* The GPU implementation uses the same calculation as the CPU version, with a few more
* steps to distribute work across threads:
*
* - Run a scan to obtain TP/FP values, which are the right coordinates of each trapezoid.
* - Find distinct prediction values and get the corresponding FP_PREV/TP_PREV values,
* which are the left coordinates of each trapezoid.
* - Reduce the scan array into one AUC value.
*/
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, false, device);
auto labels = info.labels_.ConstDeviceSpan();
auto weights = info.weights_.ConstDeviceSpan();
dh::safe_cuda(cudaSetDevice(device));
CHECK(!labels.empty());
CHECK_EQ(labels.size(), predts.size());
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
/**
* Linear scan
*/
auto get_weight = GetWeightOp{weights, d_sorted_idx};
using Pair = thrust::pair<float, float>;
auto get_fp_tp = [=]__device__(size_t i) {
size_t idx = d_sorted_idx[i];
float label = labels[idx];
float w = get_weight(i);
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
}; // NOLINT
auto d_fptp = dh::ToSpan(cache->fptp);
dh::LaunchN(device, d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx, device);
auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
[=] __device__(size_t i) { return predts[d_sorted_idx[i]]; });
auto end_unique = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
dh::InclusiveScan(
dh::tbegin(d_fptp), dh::tbegin(d_fptp),
[=] __device__(Pair const &l, Pair const &r) {
return thrust::make_pair(l.first + r.first, l.second + r.second);
},
d_fptp.size());
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// scatter unique negative/positive values
// shift to right by 1 with initial value being 0
dh::LaunchN(device, d_unique_idx.size(), [=] __device__(size_t i) {
if (d_unique_idx[i] == 0) { // first unique index is 0
assert(i == 0);
d_neg_pos[0] = {0, 0};
return;
}
d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
if (i == d_unique_idx.size() - 1) {
// last one needs to be included, may override above assignment if the last
// prediction value is distinct from the previous one.
d_neg_pos.back() = d_fptp[d_unique_idx[i] - 1];
return;
}
});
auto in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
float fp, tp;
float fp_prev, tp_prev;
if (i == 0) {
// handle the last element
thrust::tie(fp, tp) = d_fptp.back();
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx.back()];
} else {
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
}
return TrapesoidArea(fp_prev, fp, tp_prev, tp);
});
Pair last = cache->fptp.back();
float auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size());
return std::make_tuple(last.first, last.second, auc);
}
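// Worked example (editor's note): for the same toy input as the CPU version,
// predts = {0.9, 0.8, 0.7, 0.3} and labels = {1, 0, 1, 0}, the per-element (fp, tp)
// pairs {(0,1), (1,0), (0,1), (1,0)} scan to d_fptp = {(0,1), (1,1), (1,2), (2,2)}.
// All four predictions are distinct, so d_unique_idx = {0, 1, 2, 3} and the scattered
// left coordinates are d_neg_pos = {(0,0), (0,1), (1,1), (1,2)}. The final reduction
// sums the trapezoid areas 2 + 0 + 1 + 0 = 3, matching the CPU BinaryAUC value before
// normalization by fp * tp = 4.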
void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
size_t n, int32_t device) {
CHECK_EQ(in.size(), out.size());
CHECK_EQ(in.size(), m * n);
dh::LaunchN(device, in.size(), [=] __device__(size_t i) {
size_t col = i / m;
size_t row = i % m;
size_t idx = row * n + col;
out[i] = in[idx];
});
}
/**
* Last index of a group in a CSR style of index pointer.
*/
template <typename Idx>
XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
return indptr[group + 1] - 1;
}
/**
* MultiClass implementation is similar to binary classification, except we need to split
* up each class in all kernels.
*/
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, true, device);
auto labels = info.labels_.ConstDeviceSpan();
auto weights = info.weights_.ConstDeviceSpan();
size_t n_samples = labels.size();
size_t n_classes = predts.size() / labels.size();
CHECK_NE(n_classes, 0);
/**
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::Iota(d_sorted_idx, device);
auto d_predts_t = dh::ToSpan(cache->predts_t);
Transpose(predts, d_predts_t, n_samples, n_classes, device);
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
dh::LaunchN(device, n_classes + 1, [=]__device__(size_t i) {
d_class_ptr[i] = i * n_samples;
});
// thrust has no out-of-place sort, and cub sort doesn't accept general iterators, so we
// can't use a transform iterator in the sort.
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
/**
* Linear scan
*/
dh::caching_device_vector<float> d_auc(n_classes, 0);
auto s_d_auc = dh::ToSpan(d_auc);
auto get_weight = GetWeightOp{weights, d_sorted_idx};
using Pair = thrust::pair<float, float>;
auto d_fptp = dh::ToSpan(cache->fptp);
auto get_fp_tp = [=]__device__(size_t i) {
size_t idx = d_sorted_idx[i];
size_t class_id = i / n_samples;
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;
float w = get_weight(i % n_samples);
float fp = (1.0 - label) * w;
float tp = label * w;
return thrust::make_pair(fp, tp);
}; // NOLINT
dh::LaunchN(device, d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });
/**
* Handle duplicated predictions
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx, device);
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
uint32_t class_id = i / n_samples;
float predt = d_predts_t[d_sorted_idx[i]];
return thrust::make_pair(class_id, predt);
});
// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(class_ptr.size() + 1);
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
dh::tbegin(d_class_ptr),
dh::tend(d_class_ptr),
uni_key,
uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx),
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
d_unique_idx = d_unique_idx.subspan(0, n_uniques);
using Triple = thrust::tuple<uint32_t, float, float>;
// expand to tuple to include class id
auto fptp_it_in = dh::MakeTransformIterator<Triple>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
uint32_t class_id = i / n_samples;
return thrust::make_tuple(class_id, d_fptp[i].first, d_fptp[i].second);
});
// shrink down to pair
auto fptp_it_out = thrust::make_transform_output_iterator(
dh::tbegin(d_fptp), [=] __device__(Triple const &t) {
return thrust::make_pair(thrust::get<1>(t), thrust::get<2>(t));
});
dh::InclusiveScan(
fptp_it_in, fptp_it_out,
[=] __device__(Triple const &l, Triple const &r) {
uint32_t l_cid = thrust::get<0>(l);
uint32_t r_cid = thrust::get<0>(r);
if (l_cid != r_cid) {
return r;
}
return Triple(r_cid, // class_id
thrust::get<1>(l) + thrust::get<1>(r), // fp
thrust::get<2>(l) + thrust::get<2>(r)); // tp
},
d_fptp.size());
// scatter unique FP_PREV/TP_PREV values
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// When dataset is not empty, each class must have at least 1 (unique) sample
// prediction, so no need to handle special case.
dh::LaunchN(device, d_unique_idx.size(), [=]__device__(size_t i) {
if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0
assert(d_unique_idx[i] % n_samples == 0);
d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i
return;
}
uint32_t class_id = d_unique_idx[i] / n_samples;
d_neg_pos[d_unique_idx[i]] = d_fptp[d_unique_idx[i] - 1];
if (i == LastOf(class_id, d_unique_class_ptr)) {
// last one needs to be included.
size_t last = d_unique_idx[LastOf(class_id, d_unique_class_ptr)];
d_neg_pos[LastOf(class_id, d_class_ptr)] = d_fptp[last - 1];
return;
}
});
/**
* Reduce the result for each class
*/
auto key_in = dh::MakeTransformIterator<uint32_t>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
size_t class_id = d_unique_idx[i] / n_samples;
return class_id;
});
auto val_in = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
size_t class_id = d_unique_idx[i] / n_samples;
float fp, tp;
float fp_prev, tp_prev;
if (i == d_unique_class_ptr[class_id]) {
// first item is ignored, we use this thread to calculate the last item
thrust::tie(fp, tp) = d_fptp[class_id * n_samples + (n_samples - 1)];
thrust::tie(fp_prev, tp_prev) =
d_neg_pos[d_unique_idx[LastOf(class_id, d_unique_class_ptr)]];
} else {
thrust::tie(fp, tp) = d_fptp[d_unique_idx[i] - 1];
thrust::tie(fp_prev, tp_prev) = d_neg_pos[d_unique_idx[i - 1]];
}
float auc = TrapesoidArea(fp_prev, fp, tp_prev, tp);
return auc;
});
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), d_auc.begin());
/**
* Scale the classes with number of samples for each class.
*/
dh::TemporaryArray<float> results(n_classes * 4);
auto d_results = dh::ToSpan(results);
auto local_area = d_results.subspan(0, n_classes);
auto fp = d_results.subspan(n_classes, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);
dh::LaunchN(device, n_classes, [=] __device__(size_t c) {
auc[c] = s_d_auc[c];
auto last = d_fptp[n_samples * c + (n_samples - 1)];
fp[c] = last.first;
tp[c] = last.second;
local_area[c] = last.first * last.second;
});
if (rabit::IsDistributed()) {
cache->reducer->AllReduceSum(results.data().get(), results.data().get(),
results.size());
}
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
if (local_area[i] > 0) {
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
}
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
});
float tp_sum;
float auc_sum;
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
thrust::make_pair(0.0f, 0.0f),
[=] __device__(auto const &l, auto const &r) {
return thrust::make_pair(l.first + r.first, l.second + r.second);
});
if (tp_sum != 0 && !std::isnan(auc_sum)) {
auc_sum /= tp_sum;
} else {
return std::numeric_limits<float>::quiet_NaN();
}
return auc_sum;
}
namespace {
struct RankScanItem {
size_t idx;
float predt;
float w;
bst_group_t group_id;
};
} // anonymous namespace
std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
if (!cache) {
cache.reset(new DeviceAUCCache);
}
cache->Init(predts, false, device);
dh::caching_device_vector<bst_group_t> group_ptr(info.group_ptr_);
dh::XGBCachingDeviceAllocator<char> alloc;
auto d_group_ptr = dh::ToSpan(group_ptr);
/**
* Validate the dataset
*/
auto check_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0),
[=] __device__(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
size_t n_valid = thrust::count_if(
thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] __device__(size_t len) { return len >= 3; });
if (n_valid < info.group_ptr_.size() - 1) {
InvalidGroupAUC();
}
if (n_valid == 0) {
return std::make_pair(0.0f, 0);
}
/**
* Sort the labels
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto d_labels = info.labels_.ConstDeviceSpan();
dh::Iota(d_sorted_idx, device);
dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan();
dh::caching_device_vector<size_t> threads_group_ptr(group_ptr.size(), 0);
auto d_threads_group_ptr = dh::ToSpan(threads_group_ptr);
// Use max to represent triangle
auto n_threads = common::SegmentedTrapezoidThreads(
d_group_ptr, d_threads_group_ptr, std::numeric_limits<size_t>::max());
// get the coordinate in nested summation
auto get_i_j = [=]__device__(size_t idx, size_t query_group_idx) {
auto data_group_begin = d_group_ptr[query_group_idx];
size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
auto thread_group_begin = d_threads_group_ptr[query_group_idx];
auto idx_in_thread_group = idx - thread_group_begin;
size_t i, j;
common::UnravelTrapeziodIdx(idx_in_thread_group, n_samples, &i, &j);
// we use global index among all groups for sorted idx, so i, j should also be global
// index.
i += data_group_begin;
j += data_group_begin;
return thrust::make_pair(i, j);
}; // NOLINT
auto in = dh::MakeTransformIterator<RankScanItem>(
thrust::make_counting_iterator(0), [=] __device__(size_t idx) {
bst_group_t query_group_idx = dh::SegmentId(d_threads_group_ptr, idx);
auto data_group_begin = d_group_ptr[query_group_idx];
size_t n_samples = d_group_ptr[query_group_idx + 1] - data_group_begin;
if (n_samples < 3) {
// at least 3 documents are required.
return RankScanItem{idx, 0, 0, query_group_idx};
}
size_t i, j;
thrust::tie(i, j) = get_i_j(idx, query_group_idx);
float predt = predts[d_sorted_idx[i]] - predts[d_sorted_idx[j]];
float w = common::Sqr(d_weights.empty() ? 1.0f : d_weights[query_group_idx]);
if (predt > 0) {
predt = 1.0;
} else if (predt == 0) {
predt = 0.5;
} else {
predt = 0;
}
predt *= w;
return RankScanItem{idx, predt, w, query_group_idx};
});
dh::TemporaryArray<float> d_auc(group_ptr.size() - 1);
auto s_d_auc = dh::ToSpan(d_auc);
auto out = thrust::make_transform_output_iterator(
Discard<RankScanItem>(), [=] __device__(RankScanItem const &item) -> RankScanItem {
auto group_id = item.group_id;
assert(group_id < d_group_ptr.size());
auto data_group_begin = d_group_ptr[group_id];
size_t n_samples = d_group_ptr[group_id + 1] - data_group_begin;
// last item of current group
if (item.idx == LastOf(group_id, d_threads_group_ptr)) {
if (item.w > 0) {
s_d_auc[group_id] = item.predt / item.w;
} else {
s_d_auc[group_id] = 0;
}
}
return {}; // discard
});
dh::InclusiveScan(
in, out,
[] __device__(RankScanItem const &l, RankScanItem const &r) {
if (l.group_id != r.group_id) {
return r;
}
return RankScanItem{r.idx, l.predt + r.predt, l.w + r.w, l.group_id};
},
n_threads);
/**
* Scale the AUC with number of items in each group.
*/
float auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0f);
return std::make_pair(auc, n_valid);
}
} // namespace metric
} // namespace xgboost

src/metric/auc.h (new file, 42 lines)

@@ -0,0 +1,42 @@
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_METRIC_AUC_H_
#define XGBOOST_METRIC_AUC_H_
#include <cmath>
#include <memory>
#include <tuple>
#include <utility>
#include "rabit/rabit.h"
#include "xgboost/base.h"
#include "xgboost/span.h"
#include "xgboost/data.h"
namespace xgboost {
namespace metric {
XGBOOST_DEVICE inline float TrapesoidArea(float x0, float x1, float y0, float y1) {
return std::abs(x0 - x1) * (y0 + y1) * 0.5f;
}
struct DeviceAUCCache;
std::tuple<float, float, float>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache>* cache);
std::pair<float, uint32_t>
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, std::shared_ptr<DeviceAUCCache> *cache);
inline void InvalidGroupAUC() {
LOG(INFO) << "Invalid group with less than 3 samples is found on worker "
<< rabit::GetRank() << ". Calculating AUC value requires at "
<< "least 2 pairs of samples.";
}
} // namespace metric
} // namespace xgboost
#endif // XGBOOST_METRIC_AUC_H_


@@ -156,134 +156,6 @@ struct EvalAMS : public Metric {
float ratio_;
};
/*! \brief Area Under Curve, for both classification and rank computed on CPU */
struct EvalAuc : public Metric {
private:
// This is used to compute the AUC metrics on the GPU - for ranking tasks and
// for training jobs that run on the GPU.
std::unique_ptr<xgboost::Metric> auc_gpu_;
template <typename WeightPolicy>
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed,
const std::vector<unsigned> &gptr) {
const auto ngroups = static_cast<bst_omp_uint>(gptr.size() - 1);
// sum of all AUC's across all query groups
double sum_auc = 0.0;
int auc_error = 0;
const auto& labels = info.labels_.ConstHostVector();
const auto &h_preds = preds.ConstHostVector();
dmlc::OMPException exc;
#pragma omp parallel reduction(+:sum_auc, auc_error) if (ngroups > 1)
{
exc.Run([&]() {
// Each thread works on a distinct group and sorts the predictions in that group
PredIndPairContainer rec;
#pragma omp for schedule(static)
for (bst_omp_uint group_id = 0; group_id < ngroups; ++group_id) {
exc.Run([&]() {
// Same thread can work on multiple groups one after another; hence, resize
// the predictions array based on the current group
rec.resize(gptr[group_id + 1] - gptr[group_id]);
#pragma omp parallel for schedule(static) if (!omp_in_parallel())
for (bst_omp_uint j = gptr[group_id]; j < gptr[group_id + 1]; ++j) {
exc.Run([&]() {
rec[j - gptr[group_id]] = {h_preds[j], j};
});
}
XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst);
// calculate AUC
double sum_pospair = 0.0;
double sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
for (size_t j = 0; j < rec.size(); ++j) {
const bst_float wt = WeightPolicy::GetWeightOfSortedRecord(info, rec, j, group_id);
const bst_float ctr = labels[rec[j].second];
// keep bucketing predictions in same bucket
if (j != 0 && rec[j].first != rec[j - 1].first) {
sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5);
sum_npos += buf_pos;
sum_nneg += buf_neg;
buf_neg = buf_pos = 0.0f;
}
buf_pos += ctr * wt;
buf_neg += (1.0f - ctr) * wt;
}
sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5);
sum_npos += buf_pos;
sum_nneg += buf_neg;
// check weird conditions
if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
auc_error += 1;
} else {
// this is the AUC
sum_auc += sum_pospair / (sum_npos * sum_nneg);
}
});
}
});
}
exc.Rethrow();
// Report average AUC across all groups
// In distributed mode, workers which only contains pos or neg samples
// will be ignored when aggregate AUC.
bst_float dat[2] = {0.0f, 0.0f};
if (auc_error < static_cast<int>(ngroups)) {
dat[0] = static_cast<bst_float>(sum_auc);
dat[1] = static_cast<bst_float>(static_cast<int>(ngroups) - auc_error);
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC: the dataset only contains pos or neg samples";
return dat[0] / dat[1];
}
public:
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "label size predict size not match";
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const auto &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
CHECK_EQ(gptr.back(), info.labels_.Size())
<< "EvalAuc: group structure must match number of prediction";
// For ranking task, weights are per-group
// For binary classification task, weights are per-instance
const bool is_ranking_task =
!info.group_ptr_.empty() && info.weights_.Size() != info.num_row_;
// Check if we have a GPU assignment; else, revert back to CPU
if (tparam_->gpu_id >= 0) {
if (!auc_gpu_) {
// Check and see if we have the GPU metric registered in the internal registry
auc_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), tparam_));
}
if (auc_gpu_) {
return auc_gpu_->Eval(preds, info, distributed);
}
}
if (is_ranking_task) {
return Eval<PerGroupWeightPolicy>(preds, info, distributed, gptr);
} else {
return Eval<PerInstanceWeightPolicy>(preds, info, distributed, gptr);
}
}
const char *Name() const override { return "auc"; }
};
/*! \brief Evaluate rank list */
struct EvalRank : public Metric, public EvalRankConfig {
private:
@@ -672,10 +544,6 @@ XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
XGBOOST_REGISTER_METRIC(Auc, "auc")
.describe("Area under curve for both classification and rank.")
.set_body([](const char*) { return new EvalAuc(); });
XGBOOST_REGISTER_METRIC(AucPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](const char*) { return new EvalAucPR(); });


@@ -274,237 +274,6 @@ struct EvalMAPGpu {
}
};
/*! \brief Area Under Curve metric computation for ranking datasets */
struct EvalAucGpu : public Metric {
public:
// This function object computes the positive precision pair for each prediction group
class ComputePosPair : public thrust::unary_function<uint32_t, double> {
public:
XGBOOST_DEVICE ComputePosPair(const double *pred_group_pos_precision,
const double *pred_group_neg_precision,
const double *pred_group_incr_precision)
: pred_group_pos_precision_(pred_group_pos_precision),
pred_group_neg_precision_(pred_group_neg_precision),
pred_group_incr_precision_(pred_group_incr_precision) {}
// Compute positive precision pair for the prediction group at 'idx'
__device__ __forceinline__ double operator()(uint32_t idx) const {
return pred_group_neg_precision_[idx] *
(pred_group_incr_precision_[idx] + pred_group_pos_precision_[idx] * 0.5);
}
private:
// Accumulated positive precision for the prediction group
const double *pred_group_pos_precision_{nullptr};
// Accumulated negative precision for the prediction group
const double *pred_group_neg_precision_{nullptr};
// Incremental positive precision for the prediction group
const double *pred_group_incr_precision_{nullptr};
};
template <typename T>
void ReleaseMemory(dh::caching_device_vector<T> &vec) { // NOLINT
dh::caching_device_vector<T>().swap(vec);
}
bst_float Eval(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
bool distributed) override {
// Sanity check is done by the caller
std::vector<unsigned> tgptr(2, 0);
tgptr[1] = static_cast<unsigned>(info.labels_.Size());
const std::vector<unsigned> &gptr = info.group_ptr_.empty() ? tgptr : info.group_ptr_;
auto device = tparam_->gpu_id;
dh::safe_cuda(cudaSetDevice(device));
info.labels_.SetDevice(device);
preds.SetDevice(device);
info.weights_.SetDevice(device);
auto dpreds = preds.ConstDevicePointer();
auto dlabels = info.labels_.ConstDevicePointer();
auto dweights = info.weights_.ConstDevicePointer();
// Sort all the predictions (from one or more groups)
dh::SegmentSorter<float> segment_pred_sorter;
segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr);
const auto &dsorted_preds = segment_pred_sorter.GetItemsSpan();
const auto &dpreds_orig_pos = segment_pred_sorter.GetOriginalPositionsSpan();
// Group info on device
const auto &dgroups = segment_pred_sorter.GetGroupsSpan();
uint32_t ngroups = segment_pred_sorter.GetNumGroups();
// Final values
double hsum_auc = 0.0;
unsigned hauc_error = 0;
int device_id = -1;
dh::safe_cuda(cudaGetDevice(&device_id));
// Allocator to be used for managing space overhead while performing reductions
dh::XGBCachingDeviceAllocator<char> alloc;
if (ngroups == 1) {
const auto nitems = segment_pred_sorter.GetNumItems();
// First, segment all the predictions in the group. This is required so that we can
// aggregate the positive and negative precisions within that prediction group
dh::caching_device_vector<unsigned> dpred_segs(nitems, 0);
auto *pred_seg_arr = dpred_segs.data().get();
// This is for getting the next segment number
dh::caching_device_vector<unsigned> seg_idx(1, 0);
auto *seg_idx_ptr = seg_idx.data().get();
dh::caching_device_vector<double> dbuf_pos(nitems, 0);
dh::caching_device_vector<double> dbuf_neg(nitems, 0);
auto *buf_pos_arr = dbuf_pos.data().get();
auto *buf_neg_arr = dbuf_neg.data().get();
dh::LaunchN(device_id, nitems, nullptr, [=] __device__(int idx) {
auto ctr = dlabels[dpreds_orig_pos[idx]];
// For ranking task, weights are per-group
// For binary classification task, weights are per-instance
const auto wt = dweights == nullptr ? 1.0f : dweights[dpreds_orig_pos[idx]];
buf_pos_arr[idx] = ctr * wt;
buf_neg_arr[idx] = (1.0f - ctr) * wt;
if (idx == nitems - 1 || dsorted_preds[idx] != dsorted_preds[idx + 1]) {
auto new_seg_idx = atomicAdd(seg_idx_ptr, 1);
auto pred_val = dsorted_preds[idx];
do {
pred_seg_arr[idx] = new_seg_idx;
idx--;
} while (idx >= 0 && dsorted_preds[idx] == pred_val);
}
});
std::array<uint32_t, 1> h_nunique_preds;
dh::safe_cuda(cudaMemcpyAsync(h_nunique_preds.data(),
seg_idx.data().get() + seg_idx.size() - 1,
sizeof(uint32_t), cudaMemcpyDeviceToHost));
auto nunique_preds = h_nunique_preds.back();
ReleaseMemory(seg_idx);
// Next, accumulate the positive and negative precisions for every prediction group
dh::caching_device_vector<double> sum_dbuf_pos(nunique_preds, 0);
auto itr = thrust::reduce_by_key(thrust::cuda::par(alloc),
dpred_segs.begin(), dpred_segs.end(), // Segmented by this
dbuf_pos.begin(), // Individual precisions
thrust::make_discard_iterator(), // Ignore unique segments
sum_dbuf_pos.begin()); // Write accumulated results here
ReleaseMemory(dbuf_pos);
CHECK(itr.second - sum_dbuf_pos.begin() == nunique_preds);
dh::caching_device_vector<double> sum_dbuf_neg(nunique_preds, 0);
itr = thrust::reduce_by_key(thrust::cuda::par(alloc),
dpred_segs.begin(), dpred_segs.end(),
dbuf_neg.begin(),
thrust::make_discard_iterator(),
sum_dbuf_neg.begin());
ReleaseMemory(dbuf_neg);
ReleaseMemory(dpred_segs);
CHECK(itr.second - sum_dbuf_neg.begin() == nunique_preds);
dh::caching_device_vector<double> sum_nneg(nunique_preds, 0);
thrust::inclusive_scan(thrust::cuda::par(alloc),
sum_dbuf_neg.begin(), sum_dbuf_neg.end(),
sum_nneg.begin());
double sum_neg_prec_val = sum_nneg.back();
ReleaseMemory(sum_nneg);
// Find incremental sum for the positive precisions that is then used to
// compute incremental positive precision pair
dh::caching_device_vector<double> sum_npos(nunique_preds + 1, 0);
thrust::inclusive_scan(thrust::cuda::par(alloc),
sum_dbuf_pos.begin(), sum_dbuf_pos.end(),
sum_npos.begin() + 1);
double sum_pos_prec_val = sum_npos.back();
if (sum_pos_prec_val <= 0.0 || sum_neg_prec_val <= 0.0) {
hauc_error = 1;
} else {
dh::caching_device_vector<double> sum_pospair(nunique_preds, 0);
// Finally, compute the positive precision pair
thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(static_cast<uint32_t>(nunique_preds)),
sum_pospair.begin(),
ComputePosPair(sum_dbuf_pos.data().get(),
sum_dbuf_neg.data().get(),
sum_npos.data().get()));
ReleaseMemory(sum_dbuf_pos);
ReleaseMemory(sum_dbuf_neg);
ReleaseMemory(sum_npos);
hsum_auc = thrust::reduce(thrust::cuda::par(alloc),
sum_pospair.begin(), sum_pospair.end())
/ (sum_pos_prec_val * sum_neg_prec_val);
}
} else {
// AUC sum for each group
dh::caching_device_vector<double> sum_auc(ngroups, 0);
// AUC error across all groups
dh::caching_device_vector<int> auc_error(1, 0);
auto *dsum_auc = sum_auc.data().get();
auto *dauc_error = auc_error.data().get();
// For each group item compute the aggregated precision
dh::LaunchN<1, 32>(device_id, ngroups, nullptr, [=] __device__(uint32_t gidx) {
double sum_pospair = 0.0, sum_npos = 0.0, sum_nneg = 0.0, buf_pos = 0.0, buf_neg = 0.0;
for (auto i = dgroups[gidx]; i < dgroups[gidx + 1]; ++i) {
const auto ctr = dlabels[dpreds_orig_pos[i]];
// Keep bucketing predictions in same bucket
if (i != dgroups[gidx] && dsorted_preds[i] != dsorted_preds[i - 1]) {
sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5);
sum_npos += buf_pos;
sum_nneg += buf_neg;
buf_neg = buf_pos = 0.0f;
}
// For ranking task, weights are per-group
// For binary classification task, weights are per-instance
const auto wt = dweights == nullptr ? 1.0f : dweights[gidx];
buf_pos += ctr * wt;
buf_neg += (1.0f - ctr) * wt;
}
sum_pospair += buf_neg * (sum_npos + buf_pos * 0.5);
sum_npos += buf_pos;
sum_nneg += buf_neg;
// Degenerate group: contains only positive or only negative samples
if (sum_npos <= 0.0 || sum_nneg <= 0.0) {
atomicAdd(dauc_error, 1);
} else {
// This is the AUC
dsum_auc[gidx] = sum_pospair / (sum_npos * sum_nneg);
}
});
hsum_auc = thrust::reduce(thrust::cuda::par(alloc), sum_auc.begin(), sum_auc.end());
hauc_error = auc_error.back(); // Copy it back to host
}
// Report average AUC across all groups
// In distributed mode, workers whose data contains only positive or only negative
// samples are ignored when aggregating the AUC.
bst_float dat[2] = {0.0f, 0.0f};
if (hauc_error < ngroups) {
dat[0] = static_cast<bst_float>(hsum_auc);
dat[1] = static_cast<bst_float>(ngroups - hauc_error);
}
if (distributed) {
rabit::Allreduce<rabit::op::Sum>(dat, 2);
}
CHECK_GT(dat[1], 0.0f)
<< "AUC: the dataset only contains pos or neg samples";
return dat[0] / dat[1];
}
const char* Name() const override {
return "auc";
}
};
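The removed kernel above evaluates a tie-aware, weight-aware form of the AUC: predictions are sorted, ties are grouped into buckets, and each bucket contributes `buf_neg * (sum_npos + 0.5 * buf_pos)` correctly ordered pairs, normalised by `sum_npos * sum_nneg` at the end. A minimal NumPy sketch of that same accumulation (the `binary_auc` helper is illustrative only, not part of this patch):

```python
import numpy as np

def binary_auc(preds, labels, weights=None):
    """Tie-aware weighted binary AUC via the bucketed pair count."""
    preds = np.asarray(preds, dtype=float)
    labels = np.asarray(labels, dtype=float)
    weights = np.ones_like(preds) if weights is None else np.asarray(weights, dtype=float)
    order = np.argsort(-preds)  # descending, mirroring the segment sorter
    preds, labels, weights = preds[order], labels[order], weights[order]
    sum_pospair = sum_pos = sum_neg = buf_pos = buf_neg = 0.0
    for i in range(preds.size):
        # Close the current tie bucket once the prediction value changes.
        if i != 0 and preds[i] != preds[i - 1]:
            sum_pospair += buf_neg * (sum_pos + 0.5 * buf_pos)
            sum_pos += buf_pos
            sum_neg += buf_neg
            buf_pos = buf_neg = 0.0
        buf_pos += labels[i] * weights[i]
        buf_neg += (1.0 - labels[i]) * weights[i]
    sum_pospair += buf_neg * (sum_pos + 0.5 * buf_pos)
    sum_pos += buf_pos
    sum_neg += buf_neg
    if sum_pos <= 0.0 or sum_neg <= 0.0:
        return float("nan")  # dataset contains only one class
    return sum_pospair / (sum_pos * sum_neg)

# binary_auc([0.9, 0.1, 0.4, 0.3], [0, 0, 1, 1], [1.0, 3.0, 2.0, 4.0]) -> 0.75
```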
/*! \brief Area Under PR Curve metric computation for ranking datasets */
struct EvalAucPRGpu : public Metric {
public:
@ -691,10 +460,6 @@ struct EvalAucPRGpu : public Metric {
}
};
XGBOOST_REGISTER_GPU_METRIC(AucGpu, "auc")
.describe("Area under curve for rank computed on GPU.")
.set_body([](const char* param) { return new EvalAucGpu(); });
XGBOOST_REGISTER_GPU_METRIC(AucPRGpu, "aucpr")
.describe("Area under PR curve for rank computed on GPU.")
.set_body([](const char* param) { return new EvalAucPRGpu(); });

@ -293,7 +293,7 @@ class NDCGLambdaWeightComputer
group_segments)),
thrust::make_discard_iterator(), // We don't care for the group indices
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
CHECK(static_cast<unsigned>(end_range.second - dgroup_dcg_.begin()) == dgroup_dcg_.size());
CHECK_EQ(static_cast<unsigned>(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size());
}
inline const common::Span<const float> GetGroupDcgsSpan() const {

@ -15,6 +15,7 @@
#include "xgboost/parameter.h" #include "xgboost/parameter.h"
#include "xgboost/data.h" #include "xgboost/data.h"
#include "../common/math.h"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -264,14 +265,11 @@ XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 alpha) {
return 0.0;
}
template <typename T>
XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
// calculate the cost of loss function
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p,
T sum_grad, T sum_hess, T w) {
return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
return -(T(2.0) * sum_grad * w + (sum_hess + p.reg_lambda) * common::Sqr(w));
}
// calculate weight given the statistics
@ -296,9 +294,9 @@ XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess
}
if (p.max_delta_step == 0.0f) {
if (p.reg_alpha == 0.0f) {
return Sqr(sum_grad) / (sum_hess + p.reg_lambda);
return common::Sqr(sum_grad) / (sum_hess + p.reg_lambda);
} else {
return Sqr(ThresholdL1(sum_grad, p.reg_alpha)) /
return common::Sqr(ThresholdL1(sum_grad, p.reg_alpha)) /
(sum_hess + p.reg_lambda);
}
} else {

@ -114,7 +114,7 @@ class TreeEvaluator {
}
// Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error.
if (p.max_delta_step == 0.0f && has_constraint == false) {
return Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) /
return common::Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) /
(stats.sum_hess + p.reg_lambda);
}
return tree::CalcGainGivenWeight<ParamT, float>(p, stats.sum_grad,

@ -1,11 +1,12 @@
#include <gtest/gtest.h>
#include <xgboost/span.h>
#include "../../../src/common/common.h"
namespace xgboost {
namespace common {
TEST(ArgSort, Basic) {
std::vector<float> inputs {3.0, 2.0, 1.0};
auto ret = ArgSort<bst_feature_t>(inputs);
auto ret = ArgSort<bst_feature_t>(Span<float>{inputs});
std::vector<bst_feature_t> sol{2, 1, 0};
ASSERT_EQ(ret, sol);
}

@ -0,0 +1,66 @@
#include <gtest/gtest.h>
#include "../../../src/common/ranking_utils.cuh"
#include "../../../src/common/device_helpers.cuh"
namespace xgboost {
namespace common {
TEST(SegmentedTrapezoidThreads, Basic) {
size_t constexpr kElements = 24, kGroups = 3;
dh::device_vector<size_t> offset_ptr(kGroups + 1, 0);
offset_ptr[0] = 0;
offset_ptr[1] = 8;
offset_ptr[2] = 16;
offset_ptr[kGroups] = kElements;
size_t h = 1;
dh::device_vector<size_t> thread_ptr(kGroups + 1, 0);
size_t total = SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
ASSERT_EQ(total, kElements - kGroups);
h = 2;
SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
std::vector<size_t> h_thread_ptr(thread_ptr.size());
thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
for (size_t i = 1; i < h_thread_ptr.size(); ++i) {
ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 13);
}
h = 7;
SegmentedTrapezoidThreads(dh::ToSpan(offset_ptr), dh::ToSpan(thread_ptr), h);
thrust::copy(thread_ptr.cbegin(), thread_ptr.cend(), h_thread_ptr.begin());
for (size_t i = 1; i < h_thread_ptr.size(); ++i) {
ASSERT_EQ(h_thread_ptr[i] - h_thread_ptr[i - 1], 28);
}
}
TEST(SegmentedTrapezoidThreads, Unravel) {
size_t i = 0, j = 0;
size_t constexpr kN = 8;
UnravelTrapeziodIdx(6, kN, &i, &j);
ASSERT_EQ(i, 0);
ASSERT_EQ(j, 7);
UnravelTrapeziodIdx(12, kN, &i, &j);
ASSERT_EQ(i, 1);
ASSERT_EQ(j, 7);
UnravelTrapeziodIdx(15, kN, &i, &j);
ASSERT_EQ(i, 2);
ASSERT_EQ(j, 5);
UnravelTrapeziodIdx(21, kN, &i, &j);
ASSERT_EQ(i, 3);
ASSERT_EQ(j, 7);
UnravelTrapeziodIdx(25, kN, &i, &j);
ASSERT_EQ(i, 5);
ASSERT_EQ(j, 6);
UnravelTrapeziodIdx(27, kN, &i, &j);
ASSERT_EQ(i, 6);
ASSERT_EQ(j, 7);
}
} // namespace common
} // namespace xgboost
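The expected thread counts and unravelled indices above follow from giving each item in a group at most `h` comparisons with the items that follow it. A rough host-side sketch of that arithmetic (hypothetical helpers; the device code uses a closed form rather than this reference loop):

```python
def trapezoid_threads(n, h):
    """Threads needed for one group of n sorted items, at most h comparisons each."""
    return sum(min(h, n - 1 - i) for i in range(n))

def unravel_pair(idx, n):
    """Map a linear index over the full upper triangle (the h == n case) back to (i, j)."""
    i = 0
    while idx >= n - 1 - i:
        idx -= n - 1 - i
        i += 1
    return i, i + 1 + idx

assert trapezoid_threads(8, 1) == 7    # 3 groups of 8 items -> 24 - 3 = 21 threads total
assert trapezoid_threads(8, 2) == 13
assert trapezoid_threads(8, 7) == 28
assert unravel_pair(6, 8) == (0, 7)
assert unravel_pair(15, 8) == (2, 5)
assert unravel_pair(27, 8) == (6, 7)
```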

@ -0,0 +1,133 @@
#include <xgboost/metric.h>
#include "../helpers.h"
namespace xgboost {
namespace metric {
TEST(Metric, DeclareUnifiedTest(BinaryAUC)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> uni_ptr {Metric::Create("auc", &tparam)};
Metric * metric = uni_ptr.get();
ASSERT_STREQ(metric->Name(), "auc");
// Binary
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {1, 0}), 0.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 0}, {0, 1}), 0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {1, 1}, {0, 1}), 0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {0, 0}, {1, 0}), 0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {1, 1}, {1, 0}), 0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric, {1, 0, 0}, {0, 0, 1}), 0.25f, 1e-10);
// Invalid dataset
MetaInfo info;
info.labels_ = {0, 0};
float auc = metric->Eval({1, 1}, info, false);
ASSERT_TRUE(std::isnan(auc));
info.labels_ = HostDeviceVector<float>{};
auc = metric->Eval(HostDeviceVector<float>{}, info, false);
ASSERT_TRUE(std::isnan(auc));
EXPECT_NEAR(GetMetricEval(metric, {0, 1, 0, 1}, {0, 1, 0, 1}), 1.0f, 1e-10);
// AUC with instance weights
EXPECT_NEAR(GetMetricEval(metric,
{0.9f, 0.1f, 0.4f, 0.3f},
{0, 0, 1, 1},
{1.0f, 3.0f, 2.0f, 4.0f}),
0.75f, 0.001f);
// regression test case
ASSERT_NEAR(GetMetricEval(
metric,
{0.79523796, 0.5201713, 0.79523796, 0.24273258, 0.53452194,
0.53452194, 0.24273258, 0.5201713, 0.79523796, 0.53452194,
0.24273258, 0.53452194, 0.79523796, 0.5201713, 0.24273258,
0.5201713, 0.5201713, 0.53452194, 0.5201713, 0.53452194},
{0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0}),
0.5, 1e-10);
}
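These hard-coded expectations line up with scikit-learn, which gives the same half credit to tied predictions and supports per-instance weights; for example:

```python
from sklearn.metrics import roc_auc_score

# Tied predictions: matches the 0.25 expectation above.
print(roc_auc_score([0, 0, 1], [1, 0, 0]))                # 0.25
# Instance weights: matches the 0.75 expectation above.
print(roc_auc_score([0, 0, 1, 1], [0.9, 0.1, 0.4, 0.3],
                    sample_weight=[1.0, 3.0, 2.0, 4.0]))  # 0.75
```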
TEST(Metric, DeclareUnifiedTest(MultiAUC)) {
auto tparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> uni_ptr{
Metric::Create("auc", &tparam)};
auto metric = uni_ptr.get();
// MultiClass
// 3x3
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
0.0f, 1.0f, 0.0f, // p_1
0.0f, 0.0f, 1.0f // p_2
},
{0, 1, 2}),
1.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
0.0f, 1.0f, 0.0f, // p_1
0.0f, 0.0f, 1.0f // p_2
},
{2, 1, 0}),
0.5f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
0.0f, 1.0f, 0.0f, // p_1
0.0f, 0.0f, 1.0f // p_2
},
{2, 0, 1}),
0.25f, 1e-10);
// invalid dataset
float auc = GetMetricEval(metric,
{
1.0f, 0.0f, 0.0f, // p_0
0.0f, 1.0f, 0.0f, // p_1
0.0f, 0.0f, 1.0f // p_2
},
{0, 1, 1}); // no class 2.
EXPECT_TRUE(std::isnan(auc)) << auc;
}
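The multi-class cases follow the prevalence-weighted one-vs-rest definition described in this PR; for these small examples the values should agree with scikit-learn's weighted OvR average:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

proba = np.eye(3)  # each row is already a probability vector
print(roc_auc_score([2, 1, 0], proba, multi_class="ovr", average="weighted"))  # 0.5
print(roc_auc_score([2, 0, 1], proba, multi_class="ovr", average="weighted"))  # 0.25
```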
TEST(Metric, DeclareUnifiedTest(RankingAUC)) {
auto tparam = CreateEmptyGenericParam(GPUIDX);
std::unique_ptr<Metric> metric{Metric::Create("auc", &tparam)};
// single group
EXPECT_NEAR(GetMetricEval(metric.get(), {0.7f, 0.2f, 0.3f, 0.6f},
{1.0f, 0.8f, 0.4f, 0.2f}, /*weights=*/{},
{0, 4}),
0.5f, 1e-10);
// multi group
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1, 2, 0, 1, 2},
{0, 1, 2, 0, 1, 2}, /*weights=*/{}, {0, 3, 6}),
1.0f, 1e-10);
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1, 2, 0, 1, 2},
{0, 1, 2, 0, 1, 2}, /*weights=*/{1.0f, 2.0f},
{0, 3, 6}),
1.0f, 1e-10);
// AUC metric for grouped datasets - exception scenarios
ASSERT_TRUE(std::isnan(
GetMetricEval(metric.get(), {0, 1, 2}, {0, 0, 0}, {}, {0, 2, 3})));
// regression case
HostDeviceVector<float> predt{0.33935383, 0.5149714, 0.32138085, 1.4547751,
1.2010975, 0.42651367, 0.23104341, 0.83610827,
0.8494239, 0.07136688, 0.5623144, 0.8086237,
1.5066161, -4.094787, 0.76887935, -2.4082742};
std::vector<bst_group_t> groups{0, 7, 16};
std::vector<float> labels{1., 0., 0., 1., 2., 1., 0., 0.,
0., 0., 0., 0., 1., 0., 1., 0.};
EXPECT_NEAR(GetMetricEval(metric.get(), std::move(predt), labels,
/*weights=*/{}, groups),
0.769841f, 1e-6);
}
} // namespace metric
} // namespace xgboost
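For learning to rank, the metric counts correctly ordered document pairs inside each query group (pairs with equal labels are skipped, tied predictions get half credit) and then averages over groups. A rough Python sketch of the per-group computation (not the actual implementation) reproduces the single-group expectation above:

```python
def group_auc(preds, labels):
    """Fraction of correctly ordered label-distinct pairs within one group."""
    correct = total = 0.0
    n = len(preds)
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] == labels[j]:
                continue  # incomparable pair
            total += 1.0
            hi, lo = (i, j) if labels[i] > labels[j] else (j, i)
            if preds[hi] > preds[lo]:
                correct += 1.0
            elif preds[hi] == preds[lo]:
                correct += 0.5
    return float("nan") if total == 0 else correct / total

# group_auc([0.7, 0.2, 0.3, 0.6], [1.0, 0.8, 0.4, 0.2]) -> 0.5, as asserted above
```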

@ -0,0 +1,5 @@
/*!
* Copyright 2021 XGBoost contributors
*/
// Dummy file to keep the CUDA conditional compile trick.
#include "test_auc.cc"

@ -24,49 +24,6 @@ TEST(Metric, AMS) {
}
#endif
TEST(Metric, DeclareUnifiedTest(AUC)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric * metric = xgboost::Metric::Create("auc", &tparam);
ASSERT_STREQ(metric->Name(), "auc");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10);
EXPECT_NEAR(GetMetricEval(metric,
{0.1f, 0.9f, 0.1f, 0.9f},
{ 0, 0, 1, 1}),
0.5f, 0.001f);
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}));
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 0}, {0, 0}));
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {1, 1}));
// AUC with instance weights
EXPECT_NEAR(GetMetricEval(metric,
{0.9f, 0.1f, 0.4f, 0.3f},
{0, 0, 1, 1},
{1.0f, 3.0f, 2.0f, 4.0f}),
0.75f, 0.001f);
// AUC for a ranking task without weights
EXPECT_NEAR(GetMetricEval(metric,
{0.9f, 0.1f, 0.4f, 0.3f, 0.7f},
{0, 1, 0, 1, 1},
{},
{0, 2, 5}),
0.25f, 0.001f);
// AUC for a ranking task with weights/group
EXPECT_NEAR(GetMetricEval(metric,
{0.9f, 0.1f, 0.4f, 0.3f, 0.7f},
{1, 0, 1, 0, 0},
{1, 2},
{0, 2, 5}),
0.75f, 0.001f);
// AUC metric for grouped datasets - exception scenarios
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1, 2}, {0, 0, 0}, {}, {0, 2, 3}));
EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1, 2}, {1, 1, 1}, {}, {0, 2, 3}));
delete metric;
}
TEST(Metric, DeclareUnifiedTest(AUCPR)) {
auto tparam = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric *metric = xgboost::Metric::Create("aucpr", &tparam);

@ -42,6 +42,7 @@ def local_cuda_cluster(request, pytestconfig):
def pytest_addoption(parser):
parser.addoption('--use-rmm-pool', action='store_true', default=False, help='Use RMM pool')
def pytest_collection_modifyitems(config, items):
if config.getoption('--use-rmm-pool'):
blocklist = [
@ -53,3 +54,9 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if any(item.nodeid.startswith(x) for x in blocklist):
item.add_marker(skip_mark)
# mark dask tests as `mgpu`.
mgpu_mark = pytest.mark.mgpu
for item in items:
if item.nodeid.startswith("python-gpu/test_gpu_with_dask.py"):
item.add_marker(mgpu_mark)

@ -0,0 +1,47 @@
import sys
import xgboost
import pytest
sys.path.append("tests/python")
import test_eval_metrics as test_em # noqa
class TestGPUEvalMetrics:
cpu_test = test_em.TestEvalMetrics()
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_binary(self, n_samples):
self.cpu_test.run_roc_auc_binary("gpu_hist", n_samples)
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_multi(self, n_samples):
self.cpu_test.run_roc_auc_multi("gpu_hist", n_samples)
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_ltr(self, n_samples):
import numpy as np
rng = np.random.RandomState(1994)
n_samples = n_samples
n_features = 10
X = rng.randn(n_samples, n_features)
y = rng.randint(0, 16, size=n_samples)
group = np.array([n_samples // 2, n_samples // 2])
Xy = xgboost.DMatrix(X, y, group=group)
cpu = xgboost.train(
{"tree_method": "hist", "eval_metric": "auc", "objective": "rank:ndcg"},
Xy,
num_boost_round=10,
)
cpu_auc = float(cpu.eval(Xy).split(":")[1])
gpu = xgboost.train(
{"tree_method": "gpu_hist", "eval_metric": "auc", "objective": "rank:ndcg"},
Xy,
num_boost_round=10,
)
gpu_auc = float(gpu.eval(Xy).split(":")[1])
np.testing.assert_allclose(cpu_auc, gpu_auc)

@ -5,6 +5,10 @@ import itertools
import shutil
import urllib.request
import zipfile
import sys
sys.path.append("tests/python")
import testing as tm # noqa
class TestRanking:
@ -15,9 +19,9 @@ class TestRanking:
""" """
from sklearn.datasets import load_svmlight_files from sklearn.datasets import load_svmlight_files
# download the test data # download the test data
cls.dpath = 'demo/rank/' cls.dpath = os.path.join(tm.PROJECT_ROOT, "demo/rank/")
src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip' src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip'
target = cls.dpath + '/MQ2008.zip' target = os.path.join(cls.dpath, "MQ2008.zip")
if os.path.exists(cls.dpath) and os.path.exists(target): if os.path.exists(cls.dpath) and os.path.exists(target):
print("Skipping dataset download...") print("Skipping dataset download...")
@ -79,8 +83,8 @@ class TestRanking:
Cleanup test artifacts from download and unpacking
:return:
"""
os.remove(cls.dpath + "MQ2008.zip")
os.remove(os.path.join(cls.dpath, "MQ2008.zip"))
shutil.rmtree(cls.dpath + "MQ2008")
shutil.rmtree(os.path.join(cls.dpath, "MQ2008"))
@classmethod
def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02):

@ -17,6 +17,8 @@ if sys.platform.startswith("win"):
sys.path.append("tests/python") sys.path.append("tests/python")
from test_with_dask import run_empty_dmatrix_reg # noqa from test_with_dask import run_empty_dmatrix_reg # noqa
from test_with_dask import run_empty_dmatrix_auc # noqa
from test_with_dask import run_auc # noqa
from test_with_dask import run_boost_from_prediction # noqa from test_with_dask import run_boost_from_prediction # noqa
from test_with_dask import run_dask_classifier # noqa from test_with_dask import run_dask_classifier # noqa
from test_with_dask import run_empty_dmatrix_cls # noqa from test_with_dask import run_empty_dmatrix_cls # noqa
@ -286,6 +288,15 @@ class TestDistributedGPU:
run_empty_dmatrix_reg(client, parameters)
run_empty_dmatrix_cls(client, parameters)
def test_empty_dmatrix_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
n_workers = len(_get_client_workers(client))
run_empty_dmatrix_auc(client, "gpu_hist", n_workers)
def test_auc(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
run_auc(client, "gpu_hist")
def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None:
with Client(local_cuda_cluster) as client:
X, y, _ = generate_array()

@ -123,3 +123,90 @@ class TestEvalMetrics:
gamma_dev = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1].split(":")[0])
skl_gamma_dev = mean_gamma_deviance(y, score)
np.testing.assert_allclose(gamma_dev, skl_gamma_dev, rtol=1e-6)
def run_roc_auc_binary(self, tree_method, n_samples):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
rng = np.random.RandomState(1994)
n_samples = n_samples
n_features = 10
X, y = make_classification(
n_samples,
n_features,
n_informative=n_features,
n_redundant=0,
random_state=rng
)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": tree_method,
"eval_metric": "auc",
"objective": "binary:logistic",
},
Xy,
num_boost_round=8,
)
score = booster.predict(Xy)
skl_auc = roc_auc_score(y, score)
auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X))
skl_auc = roc_auc_score(y, score)
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc(self, n_samples):
self.run_roc_auc_binary("hist", n_samples)
def run_roc_auc_multi(self, tree_method, n_samples):
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
rng = np.random.RandomState(1994)
n_samples = n_samples
n_features = 10
n_classes = 4
X, y = make_classification(
n_samples,
n_features,
n_informative=n_features,
n_redundant=0,
n_classes=n_classes,
random_state=rng
)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": tree_method,
"eval_metric": "auc",
"objective": "multi:softprob",
"num_class": n_classes,
},
Xy,
num_boost_round=8,
)
score = booster.predict(Xy)
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
auc = float(booster.eval(Xy).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
X = rng.randn(*X.shape)
score = booster.predict(xgb.DMatrix(X))
skl_auc = roc_auc_score(y, score, average="weighted", multi_class="ovr")
auc = float(booster.eval(xgb.DMatrix(X, y)).split(":")[1])
np.testing.assert_allclose(skl_auc, auc, rtol=1e-6)
@pytest.mark.parametrize("n_samples", [4, 100, 1000])
def test_roc_auc_multi(self, n_samples):
self.run_roc_auc_multi("hist", n_samples)

@ -9,6 +9,7 @@ import scipy
import json
from typing import List, Tuple, Dict, Optional, Type, Any
import asyncio
from functools import partial
from concurrent.futures import ThreadPoolExecutor
import tempfile
from sklearn.datasets import make_classification
@ -528,9 +529,106 @@ def run_empty_dmatrix_cls(client: "Client", parameters: dict) -> None:
_check_outputs(out, predictions)
def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None:
from sklearn import datasets
n_samples = 100
n_features = 97
rng = np.random.RandomState(1994)
make_classification = partial(
datasets.make_classification,
n_features=n_features,
random_state=rng
)
# binary
X_, y_ = make_classification(n_samples=n_samples, random_state=rng)
X = dd.from_array(X_, chunksize=10)
y = dd.from_array(y_, chunksize=10)
n_samples = n_workers - 1
valid_X_, valid_y_ = make_classification(n_samples=n_samples, random_state=rng)
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
# multiclass
X_, y_ = make_classification(
n_samples=n_samples,
n_classes=10,
n_informative=n_features,
n_redundant=0,
n_repeated=0
)
X = dd.from_array(X_, chunksize=10)
y = dd.from_array(y_, chunksize=10)
n_samples = n_workers - 1
valid_X_, valid_y_ = make_classification(
n_samples=n_samples,
n_classes=10,
n_informative=n_features,
n_redundant=0,
n_repeated=0
)
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
cls = xgb.dask.DaskXGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
cls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
def test_empty_dmatrix_auc() -> None:
with LocalCluster(n_workers=2) as cluster:
with Client(cluster) as client:
run_empty_dmatrix_auc(client, "hist", 2)
def run_auc(client: "Client", tree_method: str) -> None:
from sklearn import datasets
n_samples = 100
n_features = 97
rng = np.random.RandomState(1994)
X_, y_ = datasets.make_classification(
n_samples=n_samples, n_features=n_features, random_state=rng
)
X = dd.from_array(X_, chunksize=10)
y = dd.from_array(y_, chunksize=10)
valid_X_, valid_y_ = datasets.make_classification(
n_samples=n_samples, n_features=n_features, random_state=rng
)
valid_X = dd.from_array(valid_X_, chunksize=10)
valid_y = dd.from_array(valid_y_, chunksize=10)
cls = xgb.XGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
cls.fit(X_, y_, eval_metric="auc", eval_set=[(valid_X_, valid_y_)])
dcls = xgb.dask.DaskXGBClassifier(
tree_method=tree_method, n_estimators=2, use_label_encoder=False
)
dcls.fit(X, y, eval_metric="auc", eval_set=[(valid_X, valid_y)])
approx = dcls.evals_result()["validation_0"]["auc"]
exact = cls.evals_result()["validation_0"]["auc"]
for i in range(2):
# approximated test.
assert np.abs(approx[i] - exact[i]) <= 0.06
def test_auc(client: "Client") -> None:
run_auc(client, "hist")
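The loose 0.06 tolerance above reflects that the distributed value is a weighted average of per-worker results rather than the exact single-machine AUC. A sketch of that aggregation, assuming each worker's weight is the amount of valid data it holds:

```python
def distributed_auc(local_aucs, local_weights):
    """Weighted average of per-worker AUCs; workers with no valid pairs get zero weight."""
    num = sum(a * w for a, w in zip(local_aucs, local_weights))
    den = sum(local_weights)
    return float("nan") if den == 0 else num / den

# Two workers holding different slices rarely reproduce the exact global value,
# e.g. distributed_auc([0.91, 0.88], [60, 40]) == 0.898
```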
# No test for Exact, as empty DMatrix handling is mostly for distributed
# environments and Exact doesn't support it.
def test_empty_dmatrix_hist() -> None:
with LocalCluster(n_workers=kWorkers) as cluster:
with Client(cluster) as client: