Validate weights are positive values. (#6115)
This commit is contained in:
@@ -359,6 +359,9 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
weights.resize(num);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
|
||||
auto valid = std::all_of(weights.cbegin(), weights.cend(),
|
||||
[](float w) { return w >= 0; });
|
||||
CHECK(valid) << "Weights must be positive values.";
|
||||
} else if (!std::strcmp(key, "base_margin")) {
|
||||
auto& base_margin = base_margin_.HostVector();
|
||||
base_margin.resize(num);
|
||||
|
||||
@@ -86,6 +86,10 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
CopyInfoImpl(array_interface, &labels_);
|
||||
} else if (key == "weight") {
|
||||
CopyInfoImpl(array_interface, &weights_);
|
||||
auto ptr = weights_.ConstDevicePointer();
|
||||
auto valid =
|
||||
thrust::all_of(thrust::device, ptr, ptr + weights_.Size(), AllOfOp{});
|
||||
CHECK(valid) << "Weights must be positive values.";
|
||||
} else if (key == "base_margin") {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
} else if (key == "group") {
|
||||
@@ -100,10 +104,9 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
} else if (key == "feature_weights") {
|
||||
CopyInfoImpl(array_interface, &feature_weigths);
|
||||
auto d_feature_weights = feature_weigths.ConstDeviceSpan();
|
||||
auto valid =
|
||||
thrust::all_of(thrust::device, d_feature_weights.data(),
|
||||
d_feature_weights.data() + d_feature_weights.size(),
|
||||
AllOfOp{});
|
||||
auto valid = thrust::all_of(
|
||||
thrust::device, d_feature_weights.data(),
|
||||
d_feature_weights.data() + d_feature_weights.size(), AllOfOp{});
|
||||
CHECK(valid) << "Feature weight must be greater than 0.";
|
||||
return;
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user