This commit is contained in:
Jiaming Yuan
2021-11-15 01:28:11 +08:00
committed by GitHub
parent a7057fa64c
commit d4274bc556
6 changed files with 17 additions and 18 deletions

View File

@@ -297,8 +297,8 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths);
out.feature_weights.Resize(this->feature_weights.Size());
out.feature_weights.Copy(this->feature_weights);
out.feature_names = this->feature_names;
out.feature_types.Resize(this->feature_types.Size());
@@ -431,7 +431,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else if (!std::strcmp(key, "feature_weights")) {
auto &h_feature_weights = feature_weigths.HostVector();
auto &h_feature_weights = feature_weights.HostVector();
h_feature_weights.resize(num);
DISPATCH_CONST_PTR(
dtype, dptr, cast_dptr,
@@ -460,7 +460,7 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype,
} else if (!std::strcmp(key, "label_upper_bound")) {
vec = &this->labels_upper_bound_.HostVector();
} else if (!std::strcmp(key, "feature_weights")) {
vec = &this->feature_weigths.HostVector();
vec = &this->feature_weights.HostVector();
} else {
LOG(FATAL) << "Unknown float field name: " << key;
}
@@ -566,10 +566,10 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
}
if (!that.feature_weigths.Empty()) {
this->feature_weigths.Resize(that.feature_weigths.Size());
this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
this->feature_weigths.Copy(that.feature_weigths);
if (!that.feature_weights.Empty()) {
this->feature_weights.Resize(that.feature_weights.Size());
this->feature_weights.SetDevice(that.feature_weights.DeviceIdx());
this->feature_weights.Copy(that.feature_weights);
}
}
@@ -612,10 +612,10 @@ void MetaInfo::Validate(int32_t device) const {
check_device(labels_lower_bound_);
return;
}
if (feature_weigths.Size() != 0) {
CHECK_EQ(feature_weigths.Size(), num_col_)
if (feature_weights.Size() != 0) {
CHECK_EQ(feature_weights.Size(), num_col_)
<< "Size of feature_weights must equal to number of columns.";
check_device(feature_weigths);
check_device(feature_weights);
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)