Fix span reverse iterator. (#7387)

* Fix span reverse iterator.

* Disable `rbegin` on device code to avoid calling host function.
* Add `trbegin` and friends.
This commit is contained in:
Jiaming Yuan 2021-11-02 13:35:59 +08:00 committed by GitHub
parent 8211e5f341
commit 6295dc3b67
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 37 additions and 11 deletions

View File

@ -423,10 +423,10 @@ class Span {
using pointer = T*; // NOLINT
using reference = T&; // NOLINT
using iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
using reverse_iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
using const_reverse_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
using iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
using reverse_iterator = std::reverse_iterator<iterator>; // NOLINT
using const_reverse_iterator = const std::reverse_iterator<const_iterator>; // NOLINT
// constructors
constexpr Span() __span_noexcept = default;
@ -504,11 +504,11 @@ class Span {
return {this, size()};
}
XGBOOST_DEVICE constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT
constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT
return reverse_iterator{end()};
}
XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
return reverse_iterator{begin()};
}

View File

@ -957,6 +957,16 @@ thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // NOLINT
return tbegin(span) + span.size();
}
template <typename T>
XGBOOST_DEVICE auto trbegin(xgboost::common::Span<T> &span) { // NOLINT
return thrust::make_reverse_iterator(span.data() + span.size());
}
template <typename T>
XGBOOST_DEVICE auto trend(xgboost::common::Span<T> &span) { // NOLINT
return trbegin(span) + span.size();
}
template <typename T>
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) { // NOLINT
return thrust::device_ptr<T const>(span.data());
@ -967,6 +977,16 @@ thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) { // NO
return tcbegin(span) + span.size();
}
template <typename T>
XGBOOST_DEVICE auto tcrbegin(xgboost::common::Span<T> const &span) { // NOLINT
return thrust::make_reverse_iterator(span.data() + span.size());
}
template <typename T>
XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) { // NOLINT
return tcrbegin(span) + span.size();
}
// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
template <typename T>

View File

@ -98,12 +98,18 @@ struct TestRBeginREnd {
InitializeRange(arr, arr + 16);
Span<float> s (arr);
Span<float>::iterator rbeg { s.rbegin() };
Span<float>::iterator rend { s.rend() };
SPAN_ASSERT_TRUE(rbeg == rend + 16, status_);
SPAN_ASSERT_TRUE(*(rbeg - 1) == arr[15], status_);
SPAN_ASSERT_TRUE(*rend == arr[0], status_);
#if defined(__CUDA_ARCH__)
auto rbeg = dh::trbegin(s);
auto rend = dh::trend(s);
#else
Span<float>::reverse_iterator rbeg{s.rbegin()};
Span<float>::reverse_iterator rend{s.rend()};
#endif
SPAN_ASSERT_TRUE(rbeg + 16 == rend, status_);
SPAN_ASSERT_TRUE(*(rbeg) == arr[15], status_);
SPAN_ASSERT_TRUE(*(rend - 1) == arr[0], status_);
}
};