Stricter validation for group. (#7345)
This commit is contained in:
@@ -233,12 +233,29 @@ TEST(MetaInfo, Validate) {
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1);
|
||||
EXPECT_THROW(info.Validate(0), dmlc::Error);
|
||||
|
||||
// Make overflow data, which can happen when users pass group structure as int
|
||||
// or float.
|
||||
groups = {};
|
||||
for (size_t i = 0; i < 63; ++i) {
|
||||
groups.push_back(1562500);
|
||||
}
|
||||
groups.push_back(static_cast<xgboost::bst_group_t>(-1));
|
||||
EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32,
|
||||
groups.size()),
|
||||
dmlc::Error);
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
info.group_ptr_.clear();
|
||||
labels.resize(info.num_row_);
|
||||
info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_);
|
||||
info.labels_.SetDevice(0);
|
||||
EXPECT_THROW(info.Validate(1), dmlc::Error);
|
||||
|
||||
xgboost::HostDeviceVector<xgboost::bst_group_t> d_groups{groups};
|
||||
auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1);
|
||||
std::string arr_interface_str;
|
||||
xgboost::Json::Dump(arr_interface, &arr_interface_str);
|
||||
EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user