Stricter validation for group. (#7345)
This commit is contained in:
parent
74bab6e504
commit
d1f00fb0b7
@ -337,6 +337,17 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_) {
|
||||||
|
bool valid_query_group = true;
|
||||||
|
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||||
|
valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1];
|
||||||
|
if (!valid_query_group) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK(valid_query_group) << "Invalid group structure.";
|
||||||
|
}
|
||||||
|
|
||||||
// macro to dispatch according to specified pointer types
|
// macro to dispatch according to specified pointer types
|
||||||
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
|
||||||
switch (dtype) { \
|
switch (dtype) { \
|
||||||
@ -387,6 +398,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
|
|||||||
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
for (size_t i = 1; i < group_ptr_.size(); ++i) {
|
||||||
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
|
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
|
||||||
}
|
}
|
||||||
|
ValidateQueryGroup(group_ptr_);
|
||||||
} else if (!std::strcmp(key, "qid")) {
|
} else if (!std::strcmp(key, "qid")) {
|
||||||
std::vector<uint32_t> query_ids(num, 0);
|
std::vector<uint32_t> query_ids(num, 0);
|
||||||
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
|
||||||
|
|||||||
@ -122,6 +122,8 @@ struct WeightsCheck {
|
|||||||
};
|
};
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
|
void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_);
|
||||||
|
|
||||||
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||||
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()});
|
||||||
auto const& j_arr = get<Array>(j_interface);
|
auto const& j_arr = get<Array>(j_interface);
|
||||||
@ -157,6 +159,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
|||||||
CopyInfoImpl(array_interface, &base_margin_);
|
CopyInfoImpl(array_interface, &base_margin_);
|
||||||
} else if (key == "group") {
|
} else if (key == "group") {
|
||||||
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
CopyGroupInfoImpl(array_interface, &group_ptr_);
|
||||||
|
ValidateQueryGroup(group_ptr_);
|
||||||
return;
|
return;
|
||||||
} else if (key == "qid") {
|
} else if (key == "qid") {
|
||||||
CopyQidImpl(array_interface, &group_ptr_);
|
CopyQidImpl(array_interface, &group_ptr_);
|
||||||
|
|||||||
@ -233,12 +233,29 @@ TEST(MetaInfo, Validate) {
|
|||||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
||||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||||
|
|
||||||
|
// Make overflow data, which can happen when users pass group structure as int
|
||||||
|
// or float.
|
||||||
|
groups = {};
|
||||||
|
for (size_t i = 0; i < 63; ++i) {
|
||||||
|
groups.push_back(1562500);
|
||||||
|
}
|
||||||
|
groups.push_back(static_cast<xgboost::bst_group_t>(-1));
|
||||||
|
EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32,
|
||||||
|
groups.size()),
|
||||||
|
dmlc::Error);
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
info.group_ptr_.clear();
|
info.group_ptr_.clear();
|
||||||
labels.resize(info.num_row_);
|
labels.resize(info.num_row_);
|
||||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||||
info.labels_.SetDevice(0);
|
info.labels_.SetDevice(0);
|
||||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||||
|
|
||||||
|
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
|
||||||
|
auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1);
|
||||||
|
std::string arr_interface_str;
|
||||||
|
xgboost::Json::Dump(arr_interface, &arr_interface_str);
|
||||||
|
EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error);
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user