Implement extend method for meta info. (#5800)
* Implement extend for host device vector.
This commit is contained in:
@@ -136,6 +136,14 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
|
||||
std::copy(other.begin(), other.end(), HostVector().begin());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
|
||||
auto ori_size = this->Size();
|
||||
this->HostVector().resize(ori_size + other.Size());
|
||||
std::copy(other.ConstHostVector().cbegin(), other.ConstHostVector().cend(),
|
||||
this->HostVector().begin() + ori_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool HostDeviceVector<T>::HostCanRead() const {
|
||||
return true;
|
||||
|
||||
@@ -125,6 +125,25 @@ class HostDeviceVectorImpl {
|
||||
}
|
||||
}
|
||||
|
||||
void Extend(HostDeviceVectorImpl* other) {
|
||||
auto ori_size = this->Size();
|
||||
this->Resize(ori_size + other->Size(), T());
|
||||
if (HostCanWrite() && other->HostCanRead()) {
|
||||
auto& h_vec = this->HostVector();
|
||||
auto& other_vec = other->HostVector();
|
||||
CHECK_EQ(h_vec.size(), ori_size + other->Size());
|
||||
std::copy(other_vec.cbegin(), other_vec.cend(), h_vec.begin() + ori_size);
|
||||
} else {
|
||||
auto ptr = other->ConstDevicePointer();
|
||||
SetDevice();
|
||||
CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
|
||||
dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
|
||||
ptr,
|
||||
other->Size() * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<T>& HostVector() {
|
||||
LazySyncHost(GPUAccess::kNone);
|
||||
return data_h_;
|
||||
@@ -326,6 +345,11 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
|
||||
impl_->Copy(other);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
|
||||
impl_->Extend(other.impl_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }
|
||||
|
||||
|
||||
Reference in New Issue
Block a user