Batch UpdatePosition using cudaMemcpy (#7964)
This commit is contained in:
@@ -1939,4 +1939,25 @@ class CUDAStream {
|
||||
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
|
||||
void Sync() { this->View().Sync(); }
|
||||
};
|
||||
|
||||
// Force nvcc to load data through the read-only data cache (via `__ldg`).
|
||||
// Read-only, index-addressable view over an array of `T` whose element loads
// are issued through the `__ldg` intrinsic.
//
// An element is not read as a single `T`.  It is fetched as `kNumWords`
// consecutive words of type `cub::UnitWord<T>::DeviceWord` and the words are
// then reinterpreted as one `T`.
//
// NOTE(review): `__ldg` is a read-only data cache load, so the pointed-to data
// is presumably expected not to be written while a kernel reads it through
// this iterator -- confirm against the callers.
template <typename T>
class LDGIterator {
  // Word type used for each individual `__ldg` load of a `T`.
  using DeviceWordT = typename cub::UnitWord<T>::DeviceWord;
  // Number of `DeviceWordT` loads needed to cover one `T`.
  static constexpr std::size_t kNumWords = sizeof(T) / sizeof(DeviceWordT);

  // Base of the array being viewed; not owned by the iterator.
  const T *ptr_;

 public:
  // Wraps `ptr` without copying or taking ownership.
  //
  // Marked `__host__ __device__` so the iterator can be built on the host and
  // passed to a kernel by value, or built directly in device code.
  __host__ __device__ explicit LDGIterator(const T *ptr) : ptr_(ptr) {}

  // Returns a copy of element `idx`, loaded word by word with `__ldg`.
  //
  // Device-only.  No bounds checking is performed on `idx`.
  __device__ T operator[](std::size_t idx) const {
    DeviceWordT tmp[kNumWords];
    // Guards against a `T` whose size is not an exact multiple of the word
    // size, in which case the word-wise copy below would miss trailing bytes.
    static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
    // Address of the first word of element `idx`; invariant across the loop.
    const DeviceWordT *words = reinterpret_cast<const DeviceWordT *>(ptr_ + idx);
#pragma unroll
    // The index uses the same unsigned type as `kNumWords` to avoid a
    // signed/unsigned comparison.
    for (std::size_t i = 0; i < kNumWords; ++i) {
      tmp[i] = __ldg(words + i);
    }
    return *reinterpret_cast<const T *>(tmp);
  }
};
|
||||
} // namespace dh
|
||||
|
||||
Reference in New Issue
Block a user