fix auc.cu

This commit is contained in:
amdsc21 2023-03-09 20:29:38 +01:00
parent 6eba0a56ec
commit a56055225a
7 changed files with 205 additions and 20 deletions

View File

@ -4,7 +4,11 @@
#pragma once
#include <vector>
#if defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h"
#elif defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh"
#endif
namespace xgboost {
namespace collective {

View File

@ -10,14 +10,26 @@
#include <cstddef> // size_t
#include <cstdint> // int32_t
#if defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp>
#elif defined(XGBOOST_USE_CUDA)
#include <cub/cub.cuh> // DispatchSegmentedRadixSort,NullType,DoubleBuffer
#endif
#include <iterator> // distance
#include <limits> // numeric_limits
#include <type_traits> // conditional_t,remove_const_t
#include "common.h" // safe_cuda
#include "cuda_context.cuh" // CUDAContext
#if defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"
#elif defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
#endif
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/context.h" // Context
#include "xgboost/logging.h" // CHECK
@ -39,6 +51,7 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st
using OffsetT = int;
// Null value type
#if defined(XGBOOST_USE_CUDA)
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
cub::DoubleBuffer<cub::NullType> d_values;
@ -47,6 +60,20 @@ static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_st
OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
num_segments, d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, ctx->Stream(), debug_synchronous)));
#elif defined(XGBOOST_USE_HIP)
if (IS_DESCENDING) {
rocprim::segmented_radix_sort_pairs_desc<KeyT, hipcub::NullType, BeginOffsetIteratorT>(d_temp_storage,
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
num_segments, d_begin_offsets, d_end_offsets,
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
}
else {
rocprim::segmented_radix_sort_pairs<KeyT, hipcub::NullType, BeginOffsetIteratorT>(d_temp_storage,
temp_storage_bytes, d_keys_in, d_keys_out, nullptr, nullptr, num_items,
num_segments, d_begin_offsets, d_end_offsets,
begin_bit, end_bit, ctx->Stream(), debug_synchronous);
}
#endif
}
// Wrapper around cub sort for easier `descending` sort.
@ -60,14 +87,18 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
BeginOffsetIteratorT d_begin_offsets,
EndOffsetIteratorT d_end_offsets, dh::CUDAStreamView stream,
int begin_bit = 0, int end_bit = sizeof(KeyT) * 8) {
#if defined(XGBOOST_USE_CUDA)
cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in), d_values_out);
#endif
// In old version of cub, num_items in dispatch is also int32_t, no way to change.
using OffsetT = std::conditional_t<dh::BuildWithCUDACub() && dh::HasThrustMinorVer<13>(),
std::size_t, std::int32_t>;
CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
// For Thrust >= 1.12 or CUDA >= 11.4, we require system cub installation
#if defined(XGBOOST_USE_CUDA)
#if THRUST_MAJOR_VERSION >= 2
dh::safe_cuda((cub::DispatchSegmentedRadixSort<
descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
@ -88,6 +119,18 @@ void DeviceSegmentedRadixSortPair(void *d_temp_storage,
d_begin_offsets, d_end_offsets, begin_bit,
end_bit, false, stream, false)));
#endif
#elif defined(XGBOOST_USE_HIP)
if (descending) {
rocprim::segmented_radix_sort_pairs_desc(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
d_values_in, d_values_out, num_items, num_segments,
d_begin_offsets, d_end_offsets, begin_bit, end_bit, stream, false);
}
else {
rocprim::segmented_radix_sort_pairs(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
d_values_in, d_values_out, num_items, num_segments, d_begin_offsets, d_end_offsets,
begin_bit, end_bit, stream, false);
}
#endif
}
} // namespace detail

View File

@ -1208,8 +1208,7 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
#endif
#endif
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(nullptr,
bytes, d_in, d_out, num_items, scan_op)));
safe_cuda((rocprim::inclusive_scan(nullptr, bytes, d_in, d_out, (size_t) num_items, scan_op)));
TemporaryArray<char> storage(bytes);
@ -1229,8 +1228,7 @@ void InclusiveScan(InputIteratorT d_in, OutputIteratorT d_out, ScanOpT scan_op,
#endif
#endif
safe_cuda((rocprim::inclusive_scan<InputIteratorT, OutputIteratorT, ScanOpT>(
storage.data().get(), bytes, d_in, d_out, num_items, scan_op)));
safe_cuda((rocprim::inclusive_scan(storage.data().get(), bytes, d_in, d_out, (size_t) num_items, scan_op)));
}
template <typename InIt, typename OutIt, typename Predicate>
@ -1262,11 +1260,7 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
using ValueT = std::remove_const_t<IdxT>;
TemporaryArray<KeyT> out(keys.size());
hipcub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(keys.data()),
out.data().get());
TemporaryArray<IdxT> sorted_idx_out(sorted_idx.size());
hipcub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(sorted_idx.data()),
sorted_idx_out.data().get());
// track https://github.com/NVIDIA/cub/pull/340 for 64bit length support
using OffsetT = std::conditional_t<!BuildWithCUDACub(), std::ptrdiff_t, int32_t>;
@ -1286,8 +1280,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
@ -1305,8 +1299,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
} else {
void *d_temp_storage = nullptr;
@ -1323,8 +1317,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
TemporaryArray<char> storage(bytes);
@ -1341,8 +1335,8 @@ void ArgSort(xgboost::common::Span<U> keys, xgboost::common::Span<IdxT> sorted_i
sizeof(KeyT) * 8, false, nullptr, false)));
#endif
#endif
safe_cuda((rocprim::radix_sort_pairs<KeyT, ValueT, OffsetT>(d_temp_storage,
bytes, d_keys, d_values, sorted_idx.size(), 0,
safe_cuda((rocprim::radix_sort_pairs(d_temp_storage,
bytes, keys.data(), out.data().get(), sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size(), 0,
sizeof(KeyT) * 8)));
}

View File

@ -9,7 +9,13 @@
#include "./math.h" // Sqr
#include "common.h"
#if defined(XGBOOST_USE_HIP)
#include "device_helpers.hip.h"
#elif defined(XGBOOST_USE_CUDA)
#include "device_helpers.cuh" // LaunchN
#endif
#include "xgboost/base.h" // XGBOOST_DEVICE
#include "xgboost/span.h" // Span
@ -67,7 +73,7 @@ SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
dh::safe_cuda(hipMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), hipMemcpyDeviceToHost));
#else
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaMemcpy(
&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
sizeof(total), cudaMemcpyDeviceToHost));

View File

@ -393,7 +393,7 @@ XGBOOST_REGISTER_METRIC(EvalAUC, "auc")
.describe("Receiver Operating Characteristic Area Under the Curve.")
.set_body([](const char*) { return new EvalROCAUC(); });
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const>, MetaInfo const &,
std::int32_t,
std::shared_ptr<DeviceAUCCache> *) {
@ -414,7 +414,7 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *, common::Span<flo
return {};
}
struct DeviceAUCCache {};
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
class EvalPRAUC : public EvalAUC<EvalPRAUC> {
std::shared_ptr<DeviceAUCCache> d_cache_;
@ -471,7 +471,7 @@ XGBOOST_REGISTER_METRIC(AUCPR, "aucpr")
.describe("Area under PR curve for both classification and rank.")
.set_body([](char const *) { return new EvalPRAUC{}; });
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const>, MetaInfo const &,
std::int32_t, std::shared_ptr<DeviceAUCCache> *) {
common::AssertGPUSupport();

View File

@ -5,7 +5,13 @@
#include <algorithm>
#include <cassert>
#if defined(XGBOOST_USE_HIP)
#include <hipcub/hipcub.hpp> // NOLINT
#elif defined(XGBOOST_USE_CUDA)
#include <cub/cub.cuh> // NOLINT
#endif
#include <limits>
#include <memory>
#include <tuple>
@ -89,7 +95,12 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels.View(device);
auto weights = info.weights_.ConstDeviceSpan();
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device));
#endif
CHECK_NE(labels.Size(), 0);
CHECK_EQ(labels.Size(), predts.size());
@ -120,10 +131,19 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
[=] XGBOOST_DEVICE(size_t i) { return predts[d_sorted_idx[i]]; });
#if defined(XGBOOST_USE_HIP)
auto end_unique = thrust::unique_by_key_copy(
thrust::hip::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
#elif defined(XGBOOST_USE_CUDA)
auto end_unique = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), uni_key, uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx), thrust::make_discard_iterator(),
dh::tbegin(d_unique_idx));
#endif
d_unique_idx = d_unique_idx.subspan(0, end_unique.second - dh::tbegin(d_unique_idx));
dh::InclusiveScan(dh::tbegin(d_fptp), dh::tbegin(d_fptp),
@ -163,7 +183,13 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
});
Pair last = cache->fptp.back();
#if defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), in, in + d_unique_idx.size());
#elif defined(XGBOOST_USE_CUDA)
double auc = thrust::reduce(thrust::cuda::par(alloc), in, in + d_unique_idx.size());
#endif
return std::make_tuple(last.first, last.second, auc);
}
@ -218,9 +244,17 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
double tp_sum;
double auc_sum;
#if defined(XGBOOST_USE_HIP)
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::hip::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::tie(auc_sum, tp_sum) =
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
Pair{0.0, 0.0}, PairPlus<double, double>{});
#endif
if (tp_sum != 0 && !std::isnan(auc_sum)) {
auc_sum /= tp_sum;
} else {
@ -300,9 +334,16 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
double auc = area_fn(fp_prev, fp, tp_prev, tp, class_id);
return auc;
});
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), dh::tbegin(d_auc));
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_in,
key_in + d_unique_idx.size(), val_in,
thrust::make_discard_iterator(), dh::tbegin(d_auc));
#endif
}
/**
@ -312,7 +353,12 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
template <bool scale, typename Fn>
double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<uint32_t> d_class_ptr,
size_t n_classes, std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device));
#endif
/**
* Sorted idx
*/
@ -373,6 +419,19 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(d_class_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
#if defined(XGBOOST_USE_HIP)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::hip::par(alloc),
dh::tbegin(d_class_ptr),
dh::tend(d_class_ptr),
uni_key,
uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx),
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#elif defined(XGBOOST_USE_CUDA)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
dh::tbegin(d_class_ptr),
@ -383,6 +442,8 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#endif
d_unique_idx = d_unique_idx.subspan(0, n_uniques);
auto get_class_id = [=] XGBOOST_DEVICE(size_t idx) { return idx / n_samples; };
@ -500,9 +561,17 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
auto check_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0),
[=] XGBOOST_DEVICE(size_t i) { return d_group_ptr[i + 1] - d_group_ptr[i]; });
#if defined(XGBOOST_USE_HIP)
size_t n_valid = thrust::count_if(
thrust::hip::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
#elif defined(XGBOOST_USE_CUDA)
size_t n_valid = thrust::count_if(
thrust::cuda::par(alloc), check_it, check_it + group_ptr.size() - 1,
[=] XGBOOST_DEVICE(size_t len) { return len >= 3; });
#endif
if (n_valid < info.group_ptr_.size() - 1) {
InvalidGroupAUC();
}
@ -599,8 +668,14 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
/**
* Scale the AUC with number of items in each group.
*/
#if defined(XGBOOST_USE_HIP)
double auc = thrust::reduce(thrust::hip::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0);
#elif defined(XGBOOST_USE_CUDA)
double auc = thrust::reduce(thrust::cuda::par(alloc), dh::tbegin(s_d_auc),
dh::tend(s_d_auc), 0.0);
#endif
return std::make_pair(auc, n_valid);
}
@ -627,9 +702,16 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
});
dh::XGBCachingDeviceAllocator<char> alloc;
double total_pos, total_neg;
#if defined(XGBOOST_USE_HIP)
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::hip::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
#endif
if (total_pos <= 0.0 || total_neg <= 0.0) {
return {0.0f, 0.0f, 0.0f};
@ -681,10 +763,18 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
return thrust::make_pair(y * w, (1.0f - y) * w);
});
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#endif
/**
* Calculate AUC
@ -752,6 +842,19 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
// unique values are sparse, so we need a CSR style indptr
dh::TemporaryArray<uint32_t> unique_class_ptr(d_group_ptr.size());
auto d_unique_class_ptr = dh::ToSpan(unique_class_ptr);
#if defined(XGBOOST_USE_HIP)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::hip::par(alloc),
dh::tbegin(d_group_ptr),
dh::tend(d_group_ptr),
uni_key,
uni_key + d_sorted_idx.size(),
dh::tbegin(d_unique_idx),
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#elif defined(XGBOOST_USE_CUDA)
auto n_uniques = dh::SegmentedUniqueByKey(
thrust::cuda::par(alloc),
dh::tbegin(d_group_ptr),
@ -762,6 +865,8 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
d_unique_class_ptr.data(),
dh::tbegin(d_unique_idx),
thrust::equal_to<thrust::pair<uint32_t, float>>{});
#endif
d_unique_idx = d_unique_idx.subspan(0, n_uniques);
auto get_group_id = [=] XGBOOST_DEVICE(size_t idx) {
@ -812,9 +917,16 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
}
return thrust::make_pair(0.0, static_cast<uint32_t>(1));
});
#if defined(XGBOOST_USE_HIP)
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::hip::par(alloc), it, it + n_groups,
thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::tie(auc, invalid_groups) = thrust::reduce(
thrust::cuda::par(alloc), it, it + n_groups,
thrust::pair<double, uint32_t>(0.0, 0), PairPlus<double, uint32_t>{});
#endif
}
return std::make_pair(auc, n_groups - invalid_groups);
}
@ -823,7 +935,12 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
#if defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(ctx->gpu_id));
#elif defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
#endif
if (predts.empty()) {
return std::make_pair(0.0, static_cast<uint32_t>(0));
}
@ -845,10 +962,19 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels.View(ctx->gpu_id);
#if defined(XGBOOST_USE_HIP)
if (thrust::any_of(thrust::hip::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
}
#elif defined(XGBOOST_USE_CUDA)
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
}
#endif
/**
* Get total positive/negative for each group.
*/
@ -868,10 +994,18 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
auto y = labels(i);
return thrust::make_pair(y * w, (1.0 - y) * w);
});
#if defined(XGBOOST_USE_HIP)
thrust::reduce_by_key(thrust::hip::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#elif defined(XGBOOST_USE_CUDA)
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
key_it + predts.size(), val_it,
thrust::make_discard_iterator(), totals.begin(),
thrust::equal_to<size_t>{}, PairPlus<double, double>{});
#endif
/**
* Calculate AUC

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "auc.cu"
#endif