Implement extend method for meta info. (#5800)

* Implement extend for host device vector.
This commit is contained in:
Jiaming Yuan
2020-06-20 03:32:03 +08:00
committed by GitHub
parent a6d9a06b7b
commit c4d721200a
7 changed files with 138 additions and 1 deletions

View File

@@ -136,6 +136,14 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
std::copy(other.begin(), other.end(), HostVector().begin());
}
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
auto ori_size = this->Size();
this->HostVector().resize(ori_size + other.Size());
std::copy(other.ConstHostVector().cbegin(), other.ConstHostVector().cend(),
this->HostVector().begin() + ori_size);
}
template <typename T>
bool HostDeviceVector<T>::HostCanRead() const {
return true;

View File

@@ -125,6 +125,25 @@ class HostDeviceVectorImpl {
}
}
void Extend(HostDeviceVectorImpl* other) {
auto ori_size = this->Size();
this->Resize(ori_size + other->Size(), T());
if (HostCanWrite() && other->HostCanRead()) {
auto& h_vec = this->HostVector();
auto& other_vec = other->HostVector();
CHECK_EQ(h_vec.size(), ori_size + other->Size());
std::copy(other_vec.cbegin(), other_vec.cend(), h_vec.begin() + ori_size);
} else {
auto ptr = other->ConstDevicePointer();
SetDevice();
CHECK_EQ(this->DeviceIdx(), other->DeviceIdx());
dh::safe_cuda(cudaMemcpyAsync(this->DevicePointer() + ori_size,
ptr,
other->Size() * sizeof(T),
cudaMemcpyDeviceToDevice));
}
}
std::vector<T>& HostVector() {
LazySyncHost(GPUAccess::kNone);
return data_h_;
@@ -326,6 +345,11 @@ void HostDeviceVector<T>::Copy(std::initializer_list<T> other) {
impl_->Copy(other);
}
template <typename T>
void HostDeviceVector<T>::Extend(HostDeviceVector const& other) {
impl_->Extend(other.impl_);
}
template <typename T>
std::vector<T>& HostDeviceVector<T>::HostVector() { return impl_->HostVector(); }

View File

@@ -343,6 +343,44 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
}
}
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
if (accumulate_rows) {
this->num_row_ += that.num_row_;
}
if (this->num_col_ != 0) {
CHECK_EQ(this->num_col_, that.num_col_)
<< "Number of columns must be consistent across batches.";
}
this->num_col_ = that.num_col_;
this->labels_.SetDevice(that.labels_.DeviceIdx());
this->labels_.Extend(that.labels_);
this->weights_.SetDevice(that.weights_.DeviceIdx());
this->weights_.Extend(that.weights_);
this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
this->labels_lower_bound_.Extend(that.labels_lower_bound_);
this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
this->labels_upper_bound_.Extend(that.labels_upper_bound_);
this->base_margin_.SetDevice(that.base_margin_.DeviceIdx());
this->base_margin_.Extend(that.base_margin_);
if (this->group_ptr_.size() == 0) {
this->group_ptr_ = that.group_ptr_;
} else {
CHECK_NE(that.group_ptr_.size(), 0);
auto group_ptr = that.group_ptr_;
for (size_t i = 1; i < group_ptr.size(); ++i) {
group_ptr[i] += this->group_ptr_.back();
}
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
group_ptr.end());
}
}
void MetaInfo::Validate(int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)