Extract device algorithms. (#8789)

This commit is contained in:
Jiaming Yuan
2023-02-13 20:53:53 +08:00
committed by GitHub
parent 457f704e3d
commit 31d3ec07af
13 changed files with 361 additions and 218 deletions

View File

@@ -12,6 +12,7 @@
#include <utility>
#include "../collective/device_communicator.cuh"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
#include "auc.h"
@@ -20,6 +21,9 @@
namespace xgboost {
namespace metric {
// Tag this file; used to force static linking later.
DMLC_REGISTRY_FILE_TAG(auc_gpu);
namespace {
// Pair of FP/TP
using Pair = thrust::pair<double, double>;
@@ -436,7 +440,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, int32_t device, common::Span<ui
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
}
void MultiClassSortedIdx(common::Span<float const> predts,
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
common::Span<uint32_t> d_class_ptr,
std::shared_ptr<DeviceAUCCache> cache) {
size_t n_classes = d_class_ptr.size() - 1;
@@ -449,11 +453,11 @@ void MultiClassSortedIdx(common::Span<float const> predts,
dh::LaunchN(n_classes + 1,
[=] XGBOOST_DEVICE(size_t i) { d_class_ptr[i] = i * n_samples; });
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_predts_t, d_class_ptr, d_sorted_idx);
common::SegmentedArgSort<false, false>(ctx, d_predts_t, d_class_ptr, d_sorted_idx);
}
double GPUMultiClassROCAUC(common::Span<float const> predts, MetaInfo const &info,
std::int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache,
double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
std::size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, p_cache);
@@ -462,13 +466,13 @@ double GPUMultiClassROCAUC(common::Span<float const> predts, MetaInfo const &inf
* Create sorted index for each class
*/
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
MultiClassSortedIdx(predts, dh::ToSpan(class_ptr), cache);
MultiClassSortedIdx(ctx, predts, dh::ToSpan(class_ptr), cache);
auto fn = [] XGBOOST_DEVICE(double fp_prev, double fp, double tp_prev,
double tp, size_t /*class_id*/) {
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
};
return GPUMultiClassAUCOVR<true>(info, device, dh::ToSpan(class_ptr), n_classes, cache, fn);
return GPUMultiClassAUCOVR<true>(info, ctx->gpu_id, dh::ToSpan(class_ptr), n_classes, cache, fn);
}
namespace {
@@ -480,8 +484,8 @@ struct RankScanItem {
};
} // anonymous namespace
std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<float const> predts,
MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@@ -509,10 +513,10 @@ std::pair<double, std::uint32_t> GPURankingAUC(common::Span<float const> predts,
/**
* Sort the labels
*/
auto d_labels = info.labels.View(device);
auto d_labels = info.labels.View(ctx->gpu_id);
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_labels.Values(), d_group_ptr, d_sorted_idx);
common::SegmentedArgSort<false, false>(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan();
@@ -640,8 +644,8 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
return std::make_tuple(1.0, 1.0, auc);
}
double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info,
std::int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache,
double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
MetaInfo const &info, std::shared_ptr<DeviceAUCCache> *p_cache,
std::size_t n_classes) {
auto& cache = *p_cache;
InitCacheOnce<true>(predts, p_cache);
@@ -651,7 +655,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
*/
dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
MultiClassSortedIdx(predts, d_class_ptr, cache);
MultiClassSortedIdx(ctx, predts, d_class_ptr, cache);
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan();
@@ -659,7 +663,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
/**
* Get total positive/negative
*/
auto labels = info.labels.View(device);
auto labels = info.labels.View(ctx->gpu_id);
auto n_samples = info.num_row_;
dh::caching_device_vector<Pair> totals(n_classes);
auto key_it =
@@ -692,7 +696,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts, MetaInfo const &info
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[class_id].first);
};
return GPUMultiClassAUCOVR<false>(info, device, d_class_ptr, n_classes, cache, fn);
return GPUMultiClassAUCOVR<false>(info, ctx->gpu_id, d_class_ptr, n_classes, cache, fn);
}
template <typename Fn>
@@ -815,10 +819,11 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
return std::make_pair(auc, n_groups - invalid_groups);
}
std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predts,
MetaInfo const &info, std::int32_t device,
std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
dh::safe_cuda(cudaSetDevice(device));
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
if (predts.empty()) {
return std::make_pair(0.0, static_cast<uint32_t>(0));
}
@@ -836,10 +841,10 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predt
* Create sorted index for each group
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(predts, d_group_ptr, d_sorted_idx);
common::SegmentedArgSort<false, false>(ctx, predts, d_group_ptr, d_sorted_idx);
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels.View(device);
auto labels = info.labels.View(ctx->gpu_id);
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
@@ -878,7 +883,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(common::Span<float const> predt
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
return GPURankingPRAUCImpl(predts, info, d_group_ptr, device, cache, fn);
return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->gpu_id, cache, fn);
}
} // namespace metric
} // namespace xgboost