Feature weights (#5962)

This commit is contained in:
Jiaming Yuan
2020-08-18 19:55:41 +08:00
committed by GitHub
parent a418278064
commit 4d99c58a5f
25 changed files with 509 additions and 104 deletions

View File

@@ -293,6 +293,9 @@ MetaInfo MetaInfo::Slice(common::Span<int32_t const> ridxs) const {
} else {
out.base_margin_.HostVector() = Gather(this->base_margin_.HostVector(), ridxs);
}
out.feature_weigths.Resize(this->feature_weigths.Size());
out.feature_weigths.Copy(this->feature_weigths);
return out;
}
@@ -377,6 +380,16 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
labels.resize(num);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
} else if (!std::strcmp(key, "feature_weights")) {
auto &h_feature_weights = feature_weigths.HostVector();
h_feature_weights.resize(num);
DISPATCH_CONST_PTR(
dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, h_feature_weights.begin()));
bool valid =
std::all_of(h_feature_weights.cbegin(), h_feature_weights.cend(),
[](float w) { return w >= 0; });
CHECK(valid) << "Feature weight must be greater than 0.";
} else {
LOG(FATAL) << "Unknown key for MetaInfo: " << key;
}
@@ -396,6 +409,8 @@ void MetaInfo::GetInfo(char const *key, bst_ulong *out_len, DataType dtype,
vec = &this->labels_lower_bound_.HostVector();
} else if (!std::strcmp(key, "label_upper_bound")) {
vec = &this->labels_upper_bound_.HostVector();
} else if (!std::strcmp(key, "feature_weights")) {
vec = &this->feature_weigths.HostVector();
} else {
LOG(FATAL) << "Unknown float field name: " << key;
}
@@ -497,6 +512,11 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows) {
auto &h_feature_types = feature_types.HostVector();
LoadFeatureType(this->feature_type_names, &h_feature_types);
}
if (!that.feature_weigths.Empty()) {
this->feature_weigths.Resize(that.feature_weigths.Size());
this->feature_weigths.SetDevice(that.feature_weigths.DeviceIdx());
this->feature_weigths.Copy(that.feature_weigths);
}
}
void MetaInfo::Validate(int32_t device) const {
@@ -538,6 +558,11 @@ void MetaInfo::Validate(int32_t device) const {
check_device(labels_lower_bound_);
return;
}
if (feature_weigths.Size() != 0) {
CHECK_EQ(feature_weigths.Size(), num_col_)
<< "Size of feature_weights must equal to number of columns.";
check_device(feature_weigths);
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";

View File

@@ -58,6 +58,15 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
std::partial_sum(out->begin(), out->end(), out->begin());
}
namespace {
// thrust::all_of tries to copy lambda function.
struct AllOfOp {
__device__ bool operator()(float w) {
return w >= 0;
}
};
} // anonymous namespace
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr = get<Array>(j_interface);
@@ -82,6 +91,21 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
} else if (key == "group") {
CopyGroupInfoImpl(array_interface, &group_ptr_);
return;
} else if (key == "label_lower_bound") {
CopyInfoImpl(array_interface, &labels_lower_bound_);
return;
} else if (key == "label_upper_bound") {
CopyInfoImpl(array_interface, &labels_upper_bound_);
return;
} else if (key == "feature_weights") {
CopyInfoImpl(array_interface, &feature_weigths);
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
auto valid =
thrust::all_of(thrust::device, d_feature_weights.data(),
d_feature_weights.data() + d_feature_weights.size(),
AllOfOp{});
CHECK(valid) << "Feature weight must be greater than 0.";
return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
}