Fix linalg iterator. (#8603)

This commit is contained in:
Jiaming Yuan 2022-12-16 23:05:03 +08:00 committed by GitHub
parent 38887a1876
commit a10e4cba4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -63,7 +63,7 @@ void ElementWiseKernel(Context const* ctx, linalg::TensorView<T, D> t, Fn&& fn)
#endif // !defined(XGBOOST_USE_CUDA)
template <typename T, std::int32_t kDim>
auto cbegin(TensorView<T, kDim> v) { // NOLINT
auto cbegin(TensorView<T, kDim> const& v) { // NOLINT
auto it = common::MakeIndexTransformIter([&](size_t i) -> std::remove_cv_t<T> const& {
return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape()));
});
@ -71,19 +71,19 @@ auto cbegin(TensorView<T, kDim> v) { // NOLINT
}
template <typename T, std::int32_t kDim>
auto cend(TensorView<T, kDim> v) { // NOLINT
auto cend(TensorView<T, kDim> const& v) { // NOLINT
return cbegin(v) + v.Size();
}
template <typename T, std::int32_t kDim>
auto begin(TensorView<T, kDim> v) { // NOLINT
auto begin(TensorView<T, kDim>& v) { // NOLINT
auto it = common::MakeIndexTransformIter(
[&](size_t i) -> T& { return linalg::detail::Apply(v, linalg::UnravelIndex(i, v.Shape())); });
return it;
}
template <typename T, std::int32_t kDim>
auto end(TensorView<T, kDim> v) { // NOLINT
auto end(TensorView<T, kDim>& v) { // NOLINT
return begin(v) + v.Size();
}
} // namespace linalg