diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 2d7c602d3..c10c6a4eb 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -369,6 +369,16 @@ class dvec { } thrust::copy(begin, end, this->tbegin()); } + + void copy(thrust::device_ptr begin, thrust::device_ptr end) { + safe_cuda(cudaSetDevice(this->device_idx())); + if (end - begin != size()) { + throw std::runtime_error( + "Cannot copy assign vector to dvec, sizes are different"); + } + safe_cuda(cudaMemcpy(this->data(), begin.get(), + size() * sizeof(T), cudaMemcpyDefault)); + } }; /**