Fix span reverse iterator. (#7387)
* Fix span reverse iterator. * Disable `rbegin` on device code to avoid calling host function. * Add `trbegin` and friends.
This commit is contained in:
parent
8211e5f341
commit
6295dc3b67
@ -425,8 +425,8 @@ class Span {
|
|||||||
|
|
||||||
using iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
|
using iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
|
||||||
using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
|
using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
|
||||||
using reverse_iterator = detail::SpanIterator<Span<T, Extent>, false>; // NOLINT
|
using reverse_iterator = std::reverse_iterator<iterator>; // NOLINT
|
||||||
using const_reverse_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
|
using const_reverse_iterator = const std::reverse_iterator<const_iterator>; // NOLINT
|
||||||
|
|
||||||
// constructors
|
// constructors
|
||||||
constexpr Span() __span_noexcept = default;
|
constexpr Span() __span_noexcept = default;
|
||||||
@ -504,11 +504,11 @@ class Span {
|
|||||||
return {this, size()};
|
return {this, size()};
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT
|
constexpr reverse_iterator rbegin() const __span_noexcept { // NOLINT
|
||||||
return reverse_iterator{end()};
|
return reverse_iterator{end()};
|
||||||
}
|
}
|
||||||
|
|
||||||
XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
|
constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
|
||||||
return reverse_iterator{begin()};
|
return reverse_iterator{begin()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -957,6 +957,16 @@ thrust::device_ptr<T> tend(xgboost::common::Span<T>& span) { // NOLINT
|
|||||||
return tbegin(span) + span.size();
|
return tbegin(span) + span.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
XGBOOST_DEVICE auto trbegin(xgboost::common::Span<T> &span) { // NOLINT
|
||||||
|
return thrust::make_reverse_iterator(span.data() + span.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
XGBOOST_DEVICE auto trend(xgboost::common::Span<T> &span) { // NOLINT
|
||||||
|
return trbegin(span) + span.size();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) { // NOLINT
|
thrust::device_ptr<T const> tcbegin(xgboost::common::Span<T> const& span) { // NOLINT
|
||||||
return thrust::device_ptr<T const>(span.data());
|
return thrust::device_ptr<T const>(span.data());
|
||||||
@ -967,6 +977,16 @@ thrust::device_ptr<T const> tcend(xgboost::common::Span<T> const& span) { // NO
|
|||||||
return tcbegin(span) + span.size();
|
return tcbegin(span) + span.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
XGBOOST_DEVICE auto tcrbegin(xgboost::common::Span<T> const &span) { // NOLINT
|
||||||
|
return thrust::make_reverse_iterator(span.data() + span.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) { // NOLINT
|
||||||
|
return tcrbegin(span) + span.size();
|
||||||
|
}
|
||||||
|
|
||||||
// This type sorts an array which is divided into multiple groups. The sorting is influenced
|
// This type sorts an array which is divided into multiple groups. The sorting is influenced
|
||||||
// by the function object 'Comparator'
|
// by the function object 'Comparator'
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
@ -98,12 +98,18 @@ struct TestRBeginREnd {
|
|||||||
InitializeRange(arr, arr + 16);
|
InitializeRange(arr, arr + 16);
|
||||||
|
|
||||||
Span<float> s (arr);
|
Span<float> s (arr);
|
||||||
Span<float>::iterator rbeg { s.rbegin() };
|
|
||||||
Span<float>::iterator rend { s.rend() };
|
|
||||||
|
|
||||||
SPAN_ASSERT_TRUE(rbeg == rend + 16, status_);
|
#if defined(__CUDA_ARCH__)
|
||||||
SPAN_ASSERT_TRUE(*(rbeg - 1) == arr[15], status_);
|
auto rbeg = dh::trbegin(s);
|
||||||
SPAN_ASSERT_TRUE(*rend == arr[0], status_);
|
auto rend = dh::trend(s);
|
||||||
|
#else
|
||||||
|
Span<float>::reverse_iterator rbeg{s.rbegin()};
|
||||||
|
Span<float>::reverse_iterator rend{s.rend()};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
SPAN_ASSERT_TRUE(rbeg + 16 == rend, status_);
|
||||||
|
SPAN_ASSERT_TRUE(*(rbeg) == arr[15], status_);
|
||||||
|
SPAN_ASSERT_TRUE(*(rend - 1) == arr[0], status_);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user