Implement column sampler in CUDA. (#9785)

- CUDA implementation.
- Extract the broadcasting logic; we will need the context parameter after revamping the collective implementation.
- Some changes to the event loop to fix a deadlock in CI.
- Move argsort into algorithms.cuh and add support for CUDA streams.
This commit is contained in:
Jiaming Yuan
2023-11-17 04:29:08 +08:00
committed by GitHub
parent 178cfe70a8
commit fedd9674c8
20 changed files with 447 additions and 232 deletions

View File

@@ -83,13 +83,14 @@ void InitCacheOnce(common::Span<float const> predts, std::shared_ptr<DeviceAUCCa
* - Reduce the scan array into 1 AUC value.
*/
template <typename Fn>
std::tuple<double, double, double>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
DeviceOrd device, common::Span<size_t const> d_sorted_idx,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels.View(device);
std::tuple<double, double, double> GPUBinaryAUC(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
common::Span<size_t const> d_sorted_idx, Fn area_fn,
std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels.View(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
dh::safe_cuda(cudaSetDevice(device.ordinal));
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
CHECK_NE(labels.Size(), 0);
CHECK_EQ(labels.Size(), predts.size());
@@ -115,7 +116,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx);
dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
@@ -167,8 +168,9 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
return std::make_tuple(last.first, last.second, auc);
}
std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> predts,
MetaInfo const &info, DeviceOrd device,
std::tuple<double, double, double> GPUBinaryROCAUC(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto &cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@@ -177,10 +179,10 @@ std::tuple<double, double, double> GPUBinaryROCAUC(common::Span<float const> pre
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
common::ArgSort<false>(ctx, predts, d_sorted_idx);
// Create lambda to avoid pass function pointer.
return GPUBinaryAUC(
predts, info, device, d_sorted_idx,
ctx, predts, info, d_sorted_idx,
[] XGBOOST_DEVICE(double x0, double x1, double y0, double y1) -> double {
return TrapezoidArea(x0, x1, y0, y1);
},
@@ -361,7 +363,7 @@ double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx);
dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
uint32_t class_id = i / n_samples;
@@ -603,8 +605,9 @@ std::pair<double, std::uint32_t> GPURankingAUC(Context const *ctx, common::Span<
return std::make_pair(auc, n_valid);
}
std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> predts,
MetaInfo const &info, DeviceOrd device,
std::tuple<double, double, double> GPUBinaryPRAUC(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
std::shared_ptr<DeviceAUCCache> *p_cache) {
auto& cache = *p_cache;
InitCacheOnce<false>(predts, p_cache);
@@ -613,9 +616,9 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
* Create sorted index for each class
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
common::ArgSort<false>(ctx, predts, d_sorted_idx);
auto labels = info.labels.View(device);
auto labels = info.labels.View(ctx->Device());
auto d_weights = info.weights_.ConstDeviceSpan();
auto get_weight = common::OptionalWeights{d_weights};
auto it = dh::MakeTransformIterator<Pair>(
@@ -639,7 +642,7 @@ std::tuple<double, double, double> GPUBinaryPRAUC(common::Span<float const> pred
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp, total_pos);
};
double fp, tp, auc;
std::tie(fp, tp, auc) = GPUBinaryAUC(predts, info, device, d_sorted_idx, fn, cache);
std::tie(fp, tp, auc) = GPUBinaryAUC(ctx, predts, info, d_sorted_idx, fn, cache);
return std::make_tuple(1.0, 1.0, auc);
}
@@ -699,16 +702,17 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
}
template <typename Fn>
std::pair<double, uint32_t>
GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
common::Span<uint32_t> d_group_ptr, DeviceOrd device,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
std::pair<double, uint32_t> GPURankingPRAUCImpl(Context const *ctx,
common::Span<float const> predts,
MetaInfo const &info,
common::Span<uint32_t> d_group_ptr,
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
/**
* Sorted idx
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels.View(device);
auto labels = info.labels.View(ctx->Device());
auto weights = info.weights_.ConstDeviceSpan();
uint32_t n_groups = static_cast<uint32_t>(info.group_ptr_.size() - 1);
@@ -739,7 +743,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx);
dh::Iota(d_unique_idx, ctx->CUDACtx()->Stream());
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
auto idx = d_sorted_idx[i];
@@ -882,7 +886,7 @@ std::pair<double, std::uint32_t> GPURankingPRAUC(Context const *ctx,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
return GPURankingPRAUCImpl(predts, info, d_group_ptr, ctx->Device(), cache, fn);
return GPURankingPRAUCImpl(ctx, predts, info, d_group_ptr, cache, fn);
}
} // namespace metric
} // namespace xgboost