Improve parameter validation (#6769)

* Add quotes to unused parameters.
* Check for whitespace.
This commit is contained in:
Jiaming Yuan 2021-03-20 01:56:55 +08:00 committed by GitHub
parent 23b4165a6b
commit f6fe15d11f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 7 deletions

View File

@ -66,7 +66,7 @@ test_that("parameter validation works", {
xgb.train(params = params, data = dtrain, nrounds = nrounds))
print(output)
}
expect_output(incorrect(), "bar, foo")
expect_output(incorrect(), '\\\\"bar\\\\", \\\\"foo\\\\"')
})

View File

@ -537,15 +537,18 @@ class LearnerConfiguration : public Learner {
}
}
// FIXME(trivialfis): Make eval_metric a training parameter.
keys.emplace_back(kEvalMetric);
keys.emplace_back("verbosity");
keys.emplace_back("num_output_group");
std::sort(keys.begin(), keys.end());
std::vector<std::string> provided;
for (auto const &kv : cfg_) {
// FIXME(trivialfis): Make eval_metric a training parameter.
if (std::any_of(kv.first.cbegin(), kv.first.cend(),
[](char ch) { return std::isspace(ch); })) {
LOG(FATAL) << "Invalid parameter \"" << kv.first << "\" contains whitespace.";
}
provided.push_back(kv.first);
}
std::sort(provided.begin(), provided.end());
@ -557,9 +560,9 @@ class LearnerConfiguration : public Learner {
std::stringstream ss;
ss << "\nParameters: { ";
for (size_t i = 0; i < diff.size() - 1; ++i) {
ss << diff[i] << ", ";
ss << "\"" << diff[i] << "\", ";
}
ss << diff.back();
ss << "\"" << diff.back() << "\"";
ss << R"W( } might not be used.
This may not be accurate due to some parameters are only used in language bindings but

View File

@ -40,7 +40,7 @@ TEST(Learner, ParameterValidation) {
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
learner->SetParam("validate_parameters", "1");
learner->SetParam("Knock Knock", "Who's there?");
learner->SetParam("Knock-Knock", "Who's-there?");
learner->SetParam("Silence", "....");
learner->SetParam("tree_method", "exact");
@ -48,7 +48,11 @@ TEST(Learner, ParameterValidation) {
learner->Configure();
std::string output = testing::internal::GetCapturedStderr();
ASSERT_TRUE(output.find("Parameters: { Knock Knock, Silence }") != std::string::npos);
ASSERT_TRUE(output.find(R"(Parameters: { "Knock-Knock", "Silence" })") != std::string::npos);
// whitespace
learner->SetParam("tree method", "exact");
EXPECT_THROW(learner->Configure(), dmlc::Error);
}
TEST(Learner, CheckGroup) {