diff --git a/include/xgboost/span.h b/include/xgboost/span.h index b8b35e644..e2dff409d 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -423,10 +423,10 @@ class Span { using pointer = T*; // NOLINT using reference = T&; // NOLINT - using iterator = detail::SpanIterator, false>; // NOLINT - using const_iterator = const detail::SpanIterator, true>; // NOLINT - using reverse_iterator = detail::SpanIterator, false>; // NOLINT - using const_reverse_iterator = const detail::SpanIterator, true>; // NOLINT + using iterator = detail::SpanIterator, false>; // NOLINT + using const_iterator = const detail::SpanIterator, true>; // NOLINT + using reverse_iterator = std::reverse_iterator; // NOLINT + using const_reverse_iterator = const std::reverse_iterator; // NOLINT // constructors constexpr Span() __span_noexcept = default; @@ -504,11 +504,11 @@ class Span { return {this, size()}; } - XGBOOST_DEVICE constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT + constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT return reverse_iterator{end()}; } - XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept { // NOLINT + constexpr reverse_iterator rend() const __span_noexcept { // NOLINT return reverse_iterator{begin()}; } diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 61e0fd553..08e3f1f3e 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -957,6 +957,16 @@ thrust::device_ptr tend(xgboost::common::Span& span) { // NOLINT return tbegin(span) + span.size(); } +template +XGBOOST_DEVICE auto trbegin(xgboost::common::Span &span) { // NOLINT + return thrust::make_reverse_iterator(span.data() + span.size()); +} + +template +XGBOOST_DEVICE auto trend(xgboost::common::Span &span) { // NOLINT + return trbegin(span) + span.size(); +} + template thrust::device_ptr tcbegin(xgboost::common::Span const& span) { // NOLINT return thrust::device_ptr(span.data()); @@ -967,6 +977,16 @@ thrust::device_ptr tcend(xgboost::common::Span const& span) { // NO return tcbegin(span) + span.size(); } +template +XGBOOST_DEVICE auto tcrbegin(xgboost::common::Span const &span) { // NOLINT + return thrust::make_reverse_iterator(span.data() + span.size()); +} + +template +XGBOOST_DEVICE auto tcrend(xgboost::common::Span const &span) { // NOLINT + return tcrbegin(span) + span.size(); +} + // This type sorts an array which is divided into multiple groups. The sorting is influenced // by the function object 'Comparator' template diff --git a/tests/cpp/common/test_span.h b/tests/cpp/common/test_span.h index 96f4efa5b..773a09e28 100644 --- a/tests/cpp/common/test_span.h +++ b/tests/cpp/common/test_span.h @@ -98,12 +98,18 @@ struct TestRBeginREnd { InitializeRange(arr, arr + 16); Span s (arr); - Span::iterator rbeg { s.rbegin() }; - Span::iterator rend { s.rend() }; - SPAN_ASSERT_TRUE(rbeg == rend + 16, status_); - SPAN_ASSERT_TRUE(*(rbeg - 1) == arr[15], status_); - SPAN_ASSERT_TRUE(*rend == arr[0], status_); +#if defined(__CUDA_ARCH__) + auto rbeg = dh::trbegin(s); + auto rend = dh::trend(s); +#else + Span::reverse_iterator rbeg{s.rbegin()}; + Span::reverse_iterator rend{s.rend()}; +#endif + + SPAN_ASSERT_TRUE(rbeg + 16 == rend, status_); + SPAN_ASSERT_TRUE(*(rbeg) == arr[15], status_); + SPAN_ASSERT_TRUE(*(rend - 1) == arr[0], status_); } };