Stricter validation for group. (#7345)
This commit is contained in:
@@ -337,6 +337,17 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
||||
return true;
|
||||
}
|
||||
|
||||
void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_) {
|
||||
bool valid_query_group = true;
|
||||
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||
valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1];
|
||||
if (!valid_query_group) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
CHECK(valid_query_group) << "Invalid group structure.";
|
||||
}
|
||||
|
||||
// macro to dispatch according to specified pointer types
|
||||
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
||||
switch (dtype) { \
|
||||
@@ -387,6 +398,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
||||
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
|
||||
}
|
||||
ValidateQueryGroup(group_ptr_);
|
||||
} else if (!std::strcmp(key, "qid")) {
|
||||
std::vector<uint32_t> query_ids(num, 0);
|
||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||
|
||||
@@ -122,6 +122,8 @@ struct WeightsCheck {
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
|
||||
|
||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||
auto const& j_arr = get<Array>(j_interface);
|
||||
@@ -157,6 +159,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
CopyInfoImpl(array_interface, &base_margin_);
|
||||
} else if (key == "group") {
|
||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||
ValidateQueryGroup(group_ptr_);
|
||||
return;
|
||||
} else if (key == "qid") {
|
||||
CopyQidImpl(array_interface, &group_ptr_);
|
||||
|
||||
Reference in New Issue
Block a user