Implement extend method for meta info. (#5800)
* Implement extend for host device vector.
This commit is contained in:
@@ -343,6 +343,44 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
|
||||
if (accumulate_rows) {
|
||||
this->num_row_ += that.num_row_;
|
||||
}
|
||||
if (this->num_col_ != 0) {
|
||||
CHECK_EQ(this->num_col_, that.num_col_)
|
||||
<< "Number of columns must be consistent across batches.";
|
||||
}
|
||||
this->num_col_ = that.num_col_;
|
||||
|
||||
this->labels_.SetDevice(that.labels_.DeviceIdx());
|
||||
this->labels_.Extend(that.labels_);
|
||||
|
||||
this->weights_.SetDevice(that.weights_.DeviceIdx());
|
||||
this->weights_.Extend(that.weights_);
|
||||
|
||||
this->labels_lower_bound_.SetDevice(that.labels_lower_bound_.DeviceIdx());
|
||||
this->labels_lower_bound_.Extend(that.labels_lower_bound_);
|
||||
|
||||
this->labels_upper_bound_.SetDevice(that.labels_upper_bound_.DeviceIdx());
|
||||
this->labels_upper_bound_.Extend(that.labels_upper_bound_);
|
||||
|
||||
this->base_margin_.SetDevice(that.base_margin_.DeviceIdx());
|
||||
this->base_margin_.Extend(that.base_margin_);
|
||||
|
||||
if (this->group_ptr_.size() == 0) {
|
||||
this->group_ptr_ = that.group_ptr_;
|
||||
} else {
|
||||
CHECK_NE(that.group_ptr_.size(), 0);
|
||||
auto group_ptr = that.group_ptr_;
|
||||
for (size_t i = 1; i < group_ptr.size(); ++i) {
|
||||
group_ptr[i] += this->group_ptr_.back();
|
||||
}
|
||||
this->group_ptr_.insert(this->group_ptr_.end(), group_ptr.begin() + 1,
|
||||
group_ptr.end());
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||
|
||||
Reference in New Issue
Block a user