/**
 * Copyright 2022-2023 by XGBoost Contributors
 */
#ifndef XGBOOST_COMMON_ALGORITHM_CUH_
#define XGBOOST_COMMON_ALGORITHM_CUH_

#include <thrust/copy.h>   // copy
#include <thrust/sort.h>   // stable_sort_by_key
#include <thrust/tuple.h>  // tuple,get

#include <cstddef>  // size_t
#include <cstdint>  // int32_t

#if defined(XGBOOST_USE_HIP)
#include <rocprim/rocprim.hpp>
#elif defined(XGBOOST_USE_CUDA)
#include <cub/cub.cuh>  // DispatchSegmentedRadixSort,NullType,DoubleBuffer
#endif

#include <iterator>     // distance
#include <limits>       // numeric_limits
#include <type_traits>  // conditional_t,remove_const_t

#include "common.h"            // safe_cuda
#include "cuda_context.cuh"    // CUDAContext
#include "device_helpers.cuh"  // TemporaryArray,SegmentId,LaunchN,Iota,device_vector
#include "xgboost/base.h"      // XGBOOST_DEVICE
#include "xgboost/context.h"   // Context
#include "xgboost/logging.h"   // CHECK
#include "xgboost/span.h"      // Span,byte

namespace xgboost {
namespace common {
namespace detail {
// Wrapper around the cub/rocprim segmented sort that exposes `IS_DESCENDING` as a
// template parameter.
template <bool IS_DESCENDING, typename KeyT, typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
static void DeviceSegmentedRadixSortKeys(CUDAContext const *ctx, void *d_temp_storage,
                                         std::size_t &temp_storage_bytes,  // NOLINT
                                         const KeyT *d_keys_in, KeyT *d_keys_out, int num_items,
                                         int num_segments, BeginOffsetIteratorT d_begin_offsets,
                                         EndOffsetIteratorT d_end_offsets, int begin_bit = 0,
                                         int end_bit = sizeof(KeyT) * 8,
                                         bool debug_synchronous = false) {
#if defined(XGBOOST_USE_CUDA)
  using OffsetT = int;

  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
  // Keys-only sort: the value type is cub::NullType.
  cub::DoubleBuffer<cub::NullType> d_values;

  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
                 IS_DESCENDING, KeyT, cub::NullType, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, ctx->Stream(), debug_synchronous)));
#elif defined(XGBOOST_USE_HIP)
  // rocprim has dedicated keys-only entry points, so no value buffers are required.
  if (IS_DESCENDING) {
    rocprim::segmented_radix_sort_keys_desc(d_temp_storage, temp_storage_bytes, d_keys_in,
                                            d_keys_out, num_items, num_segments, d_begin_offsets,
                                            d_end_offsets, begin_bit, end_bit, ctx->Stream(),
                                            debug_synchronous);
  } else {
    rocprim::segmented_radix_sort_keys(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
                                       num_items, num_segments, d_begin_offsets, d_end_offsets,
                                       begin_bit, end_bit, ctx->Stream(), debug_synchronous);
  }
#endif
}
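/**
 * The wrapper above and the pair wrapper below follow the standard cub two-pass calling
 * convention: a first call with `d_temp_storage == nullptr` only writes the required
 * scratch size into `temp_storage_bytes`, and a second call with an allocated buffer
 * performs the actual sort. A minimal sketch of the idiom (illustrative only; `cuctx`,
 * `keys`, `offsets`, and the segment layout are hypothetical):
 *
 * \code
 *   std::size_t bytes{0};
 *   // Pass 1: query the scratch size; no sorting happens here.
 *   DeviceSegmentedRadixSortKeys<false>(cuctx, nullptr, bytes, keys.data(), keys.data(),
 *                                       keys.size(), n_segments, offsets.data(),
 *                                       offsets.data() + 1);
 *   dh::TemporaryArray<byte> temp(bytes);
 *   // Pass 2: same argument list, now with the scratch buffer attached.
 *   DeviceSegmentedRadixSortKeys<false>(cuctx, temp.data().get(), bytes, keys.data(),
 *                                       keys.data(), keys.size(), n_segments,
 *                                       offsets.data(), offsets.data() + 1);
 * \endcode
 */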
// Wrapper around cub sort for easier `descending` sort.
template <bool descending, typename KeyT, typename ValueT, typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT>
void DeviceSegmentedRadixSortPair(void *d_temp_storage,
                                  std::size_t &temp_storage_bytes,  // NOLINT
                                  const KeyT *d_keys_in, KeyT *d_keys_out,
                                  const ValueT *d_values_in, ValueT *d_values_out,
                                  std::size_t num_items, std::size_t num_segments,
                                  BeginOffsetIteratorT d_begin_offsets,
                                  EndOffsetIteratorT d_end_offsets, dh::CUDAStreamView stream,
                                  int begin_bit = 0, int end_bit = sizeof(KeyT) * 8) {
#if defined(XGBOOST_USE_CUDA)
  cub::DoubleBuffer<KeyT> d_keys(const_cast<KeyT *>(d_keys_in), d_keys_out);
  cub::DoubleBuffer<ValueT> d_values(const_cast<ValueT *>(d_values_in), d_values_out);
#endif
  // In old versions of cub, num_items in Dispatch is also int32_t; there is no way to
  // change it.
  using OffsetT = std::conditional_t<dh::BuildWithCUDACub() && dh::HasThrustMinorVer<13>(),
                                     std::size_t, std::int32_t>;
  CHECK_LE(num_items, std::numeric_limits<OffsetT>::max());
  // For Thrust >= 1.12 or CUDA >= 11.4, we require a system cub installation.

#if defined(XGBOOST_USE_CUDA)
#if THRUST_MAJOR_VERSION >= 2
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, stream)));
#elif (THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION >= 13)
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, stream, false)));
#else
  dh::safe_cuda((cub::DispatchSegmentedRadixSort<
                 descending, KeyT, ValueT, BeginOffsetIteratorT, EndOffsetIteratorT,
                 OffsetT>::Dispatch(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items,
                                    num_segments, d_begin_offsets, d_end_offsets, begin_bit,
                                    end_bit, false, stream, false)));
#endif
#elif defined(XGBOOST_USE_HIP)
  if (descending) {
    rocprim::segmented_radix_sort_pairs_desc(d_temp_storage, temp_storage_bytes, d_keys_in,
                                             d_keys_out, d_values_in, d_values_out, num_items,
                                             num_segments, d_begin_offsets, d_end_offsets,
                                             begin_bit, end_bit, stream, false);
  } else {
    rocprim::segmented_radix_sort_pairs(d_temp_storage, temp_storage_bytes, d_keys_in, d_keys_out,
                                        d_values_in, d_values_out, num_items, num_segments,
                                        d_begin_offsets, d_end_offsets, begin_bit, end_bit, stream,
                                        false);
  }
#endif
}
}  // namespace detail

template <typename U, typename V>
void SegmentedSequence(Context const *ctx, Span<V> d_offset_ptr, Span<U> out_sequence) {
  dh::LaunchN(out_sequence.size(), ctx->CUDACtx()->Stream(),
              [out_sequence, d_offset_ptr] __device__(std::size_t idx) {
                auto group = dh::SegmentId(d_offset_ptr, idx);
                // The sequence restarts from 0 at the beginning of each segment.
                out_sequence[idx] = idx - d_offset_ptr[group];
              });
}

template <bool descending, typename U, typename V>
inline void SegmentedSortKeys(Context const *ctx, Span<V const> group_ptr,
                              Span<U> out_sorted_values) {
  CHECK_GE(group_ptr.size(), 1ul);
  std::size_t n_groups = group_ptr.size() - 1;
  std::size_t bytes = 0;
  auto const *cuctx = ctx->CUDACtx();
  CHECK(cuctx);
  detail::DeviceSegmentedRadixSortKeys<descending>(
      cuctx, nullptr, bytes, out_sorted_values.data(), out_sorted_values.data(),
      out_sorted_values.size(), n_groups, group_ptr.data(), group_ptr.data() + 1);
  dh::TemporaryArray<byte> temp_storage(bytes);
  detail::DeviceSegmentedRadixSortKeys<descending>(
      cuctx, temp_storage.data().get(), bytes, out_sorted_values.data(), out_sorted_values.data(),
      out_sorted_values.size(), n_groups, group_ptr.data(), group_ptr.data() + 1);
}
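/**
 * Example use of SegmentedSortKeys (a sketch for illustration; `ctx` and the data are
 * hypothetical):
 *
 * \code
 *   // Offsets for two groups: [0, 3) and [3, 5), e.g. two queries in an LTR dataset.
 *   dh::device_vector<std::size_t> group_ptr(std::vector<std::size_t>{0, 3, 5});
 *   dh::device_vector<float> scores(5);
 *   Span<std::size_t const> d_group_ptr{group_ptr.data().get(), group_ptr.size()};
 *   // Sort the scores within each group in descending order, in place.
 *   SegmentedSortKeys<true>(ctx, d_group_ptr, dh::ToSpan(scores));
 * \endcode
 */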
/**
 * \brief Create a sorted index for data with multiple segments.
 *
 * \tparam ascending      Sort in non-decreasing order if true.
 * \tparam per_seg_index  Index starts from 0 for each segment if true, otherwise the
 *                        index spans the whole data.
 */
template <bool ascending, bool per_seg_index, typename U, typename V, typename IdxT>
void SegmentedArgSort(Context const *ctx, Span<U> values, Span<V> group_ptr,
                      Span<IdxT> sorted_idx) {
  CHECK_GE(group_ptr.size(), 1ul);
  std::size_t n_groups = group_ptr.size() - 1;
  std::size_t bytes = 0;
  if (per_seg_index) {
    SegmentedSequence(ctx, group_ptr, sorted_idx);
  } else {
    dh::Iota(sorted_idx);
  }
  dh::TemporaryArray<std::remove_const_t<U>> values_out(values.size());
  dh::TemporaryArray<std::remove_const_t<IdxT>> sorted_idx_out(sorted_idx.size());

  detail::DeviceSegmentedRadixSortPair<!ascending>(
      nullptr, bytes, values.data(), values_out.data().get(), sorted_idx.data(),
      sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
      group_ptr.data() + 1, ctx->CUDACtx()->Stream());
  dh::TemporaryArray<byte> temp_storage(bytes);
  detail::DeviceSegmentedRadixSortPair<!ascending>(
      temp_storage.data().get(), bytes, values.data(), values_out.data().get(), sorted_idx.data(),
      sorted_idx_out.data().get(), sorted_idx.size(), n_groups, group_ptr.data(),
      group_ptr.data() + 1, ctx->CUDACtx()->Stream());

  dh::safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(),
                                sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice));
}

/**
 * \brief Different from the radix-sort-based argsort, this one can handle cases where a
 *        segment doesn't start from 0, but as a result it uses a comparison sort.
 */
template <typename SegIt, typename ValIt>
void SegmentedArgMergeSort(Context const *ctx, SegIt seg_begin, SegIt seg_end, ValIt val_begin,
                           ValIt val_end, dh::device_vector<std::size_t> *p_sorted_idx) {
  using Tup = thrust::tuple<std::int32_t, float>;
  auto &sorted_idx = *p_sorted_idx;
  std::size_t n = std::distance(val_begin, val_end);
  sorted_idx.resize(n);
  dh::Iota(dh::ToSpan(sorted_idx));
  dh::device_vector<Tup> keys(sorted_idx.size());
  auto key_it = dh::MakeTransformIterator<Tup>(thrust::make_counting_iterator(0ul),
                                               [=] XGBOOST_DEVICE(std::size_t i) -> Tup {
                                                 std::int32_t seg_idx;
                                                 // Items before the first segment get an
                                                 // invalid index so that they sort first.
                                                 if (i < *seg_begin) {
                                                   seg_idx = -1;
                                                 } else {
                                                   seg_idx = dh::SegmentId(seg_begin, seg_end, i);
                                                 }
                                                 auto residue = val_begin[i];
                                                 return thrust::make_tuple(seg_idx, residue);
                                               });
  thrust::copy(ctx->CUDACtx()->CTP(), key_it, key_it + keys.size(), keys.begin());
  thrust::stable_sort_by_key(ctx->CUDACtx()->TP(), keys.begin(), keys.end(), sorted_idx.begin(),
                             [=] XGBOOST_DEVICE(Tup const &l, Tup const &r) {
                               if (thrust::get<0>(l) != thrust::get<0>(r)) {
                                 return thrust::get<0>(l) < thrust::get<0>(r);  // segment index
                               }
                               return thrust::get<1>(l) < thrust::get<1>(r);  // residue
                             });
}
}  // namespace common
}  // namespace xgboost
#endif  // XGBOOST_COMMON_ALGORITHM_CUH_
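/*
 * Usage sketch for the two argsort helpers (illustrative only; `ctx` and the data are
 * hypothetical, and the calls assume the xgboost::common namespace is in scope):
 *
 *   dh::device_vector<float> values(5);
 *   dh::device_vector<std::size_t> group_ptr(std::vector<std::size_t>{0, 3, 5});
 *   dh::device_vector<std::size_t> sorted_idx(values.size());
 *   // Ascending argsort whose indices restart at 0 in every segment.
 *   SegmentedArgSort<true, true>(ctx, dh::ToSpan(values), dh::ToSpan(group_ptr),
 *                                dh::ToSpan(sorted_idx));
 *
 *   // Comparison-based variant for segments that don't start at 0: items before the
 *   // first offset are gathered into a leading catch-all segment.
 *   dh::device_vector<std::size_t> offsets(std::vector<std::size_t>{2, 5});
 *   dh::device_vector<std::size_t> merge_sorted_idx;
 *   SegmentedArgMergeSort(ctx, offsets.data().get(),
 *                         offsets.data().get() + offsets.size(), values.data().get(),
 *                         values.data().get() + values.size(), &merge_sorted_idx);
 */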