Add CUDA iterator to tensor view. (#10074)
This commit is contained in:
@@ -295,6 +295,9 @@ class TensorView {
|
||||
using ShapeT = std::size_t[kDim];
|
||||
using StrideT = ShapeT;
|
||||
|
||||
using element_type = T; // NOLINT
|
||||
using value_type = std::remove_cv_t<T>; // NOLINT
|
||||
|
||||
private:
|
||||
StrideT stride_{1};
|
||||
ShapeT shape_{0};
|
||||
@@ -314,7 +317,7 @@ class TensorView {
|
||||
}
|
||||
|
||||
template <size_t old_dim, size_t new_dim, int32_t D, typename I>
|
||||
LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D],
|
||||
LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D],
|
||||
detail::RangeTag<I> &&range) const {
|
||||
static_assert(new_dim < D);
|
||||
static_assert(old_dim < kDim);
|
||||
@@ -528,9 +531,10 @@ class TensorView {
|
||||
LINALG_HD auto Stride(size_t i) const { return stride_[i]; }
|
||||
|
||||
/**
|
||||
* \brief Number of items in the tensor.
|
||||
* @brief Number of items in the tensor.
|
||||
*/
|
||||
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
|
||||
[[nodiscard]] bool Empty() const { return Size() == 0; }
|
||||
/**
|
||||
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
|
||||
*/
|
||||
@@ -865,7 +869,9 @@ class Tensor {
|
||||
auto HostView() { return this->View(DeviceOrd::CPU()); }
|
||||
auto HostView() const { return this->View(DeviceOrd::CPU()); }
|
||||
|
||||
[[nodiscard]] size_t Size() const { return data_.Size(); }
|
||||
[[nodiscard]] std::size_t Size() const { return data_.Size(); }
|
||||
[[nodiscard]] bool Empty() const { return Size() == 0; }
|
||||
|
||||
auto Shape() const { return common::Span<size_t const, kDim>{shape_}; }
|
||||
auto Shape(size_t i) const { return shape_[i]; }
|
||||
|
||||
|
||||
@@ -701,10 +701,10 @@ class IterSpan {
|
||||
return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count};
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT
|
||||
return {this, 0};
|
||||
return it_;
|
||||
}
|
||||
[[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT
|
||||
return {this, size()};
|
||||
return it_ + size();
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
|
||||
Reference in New Issue
Block a user