From d1f00fb0b77d02f57c10fbde2b90af4a7b636bf7 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 21 Oct 2021 12:13:33 +0800 Subject: [PATCH] Stricter validation for group. (#7345) --- src/data/data.cc | 12 ++++++++++++ src/data/data.cu | 3 +++ tests/cpp/data/test_metainfo.cc | 17 +++++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/src/data/data.cc b/src/data/data.cc index 2ef5e2a1d..7179fe9b1 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -337,6 +337,17 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname, return true; } +/*! \brief Check that the group pointer passed in is non-decreasing. Walks adjacent pairs and stops at the first index i where group_ptr_[i] is smaller than group_ptr_[i - 1]; the trailing CHECK then fails with the message Invalid group structure. A vector with fewer than two entries passes because the loop body never runs. NOTE(review): the element type of the vector seems to have been lost when this patch was extracted (`std::vector const &` has no template argument) - confirm against the upstream declaration. */void ValidateQueryGroup(std::vector const &group_ptr_) { + bool valid_query_group = true; + for (size_t i = 1; i < group_ptr_.size(); ++i) { + valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1]; + if (!valid_query_group) { + break; + } + } + CHECK(valid_query_group) << "Invalid group structure."; +} + // macro to dispatch according to specified pointer types #define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \ switch (dtype) { \ @@ -387,6 +398,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t for (size_t i = 1; i < group_ptr_.size(); ++i) { group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; } + /* Runs on group_ptr_ after the loop above has turned it into a running sum; a decrease at this point presumably means the sum wrapped around - see the overflow case added to the test in this patch. */ValidateQueryGroup(group_ptr_); } else if (!std::strcmp(key, "qid")) { std::vector query_ids(num, 0); DISPATCH_CONST_PTR(dtype, dptr, cast_dptr, diff --git a/src/data/data.cu b/src/data/data.cu index 2c421938c..2f298f330 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -122,6 +122,8 @@ struct WeightsCheck { }; } // anonymous namespace +/* Declaration only; the matching definition is the one this same patch adds to src/data/data.cc. */void ValidateQueryGroup(std::vector const &group_ptr_); + void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { Json j_interface = Json::Load({interface_str.c_str(), interface_str.size()}); auto const& j_arr = get(j_interface); @@ -157,6 +159,7 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { CopyInfoImpl(array_interface, &base_margin_); } else if (key == 
"group") { CopyGroupInfoImpl(array_interface, &group_ptr_); + /* Validate group_ptr_ immediately after the CopyGroupInfoImpl call that receives it, before the early return. */ValidateQueryGroup(group_ptr_); return; } else if (key == "qid") { CopyQidImpl(array_interface, &group_ptr_); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index ad59ba4f5..a58bda90d 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -233,12 +233,29 @@ TEST(MetaInfo, Validate) { info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); EXPECT_THROW(info.Validate(0), dmlc::Error); + // Make overflow data, which can happen when users pass group structure as int + // or float. + /* 63 entries of 1562500 followed by a cast of -1; presumably the last entry is the largest unsigned value so the running sum built in SetInfo wraps around - the cast target type was lost in extraction, TODO confirm. */groups = {}; + for (size_t i = 0; i < 63; ++i) { + groups.push_back(1562500); + } + groups.push_back(static_cast(-1)); + EXPECT_THROW(info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, + groups.size()), + dmlc::Error); + #if defined(XGBOOST_USE_CUDA) info.group_ptr_.clear(); labels.resize(info.num_row_); info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_); info.labels_.SetDevice(0); EXPECT_THROW(info.Validate(1), dmlc::Error); + + /* Same invalid group data, this time passed through the string (array interface) overload of SetInfo, which is expected to throw as well. */xgboost::HostDeviceVector d_groups{groups}; + auto arr_interface = xgboost::GetArrayInterface(&d_groups, 64, 1); + std::string arr_interface_str; + xgboost::Json::Dump(arr_interface, &arr_interface_str); + EXPECT_THROW(info.SetInfo("group", arr_interface_str), dmlc::Error); #endif // defined(XGBOOST_USE_CUDA) }