Improve parameter validation (#6769)
* Add quotes to unused parameters. * Check for whitespace.
This commit is contained in:
parent
23b4165a6b
commit
f6fe15d11f
@ -66,7 +66,7 @@ test_that("parameter validation works", {
|
|||||||
xgb.train(params = params, data = dtrain, nrounds = nrounds))
|
xgb.train(params = params, data = dtrain, nrounds = nrounds))
|
||||||
print(output)
|
print(output)
|
||||||
}
|
}
|
||||||
expect_output(incorrect(), "bar, foo")
|
expect_output(incorrect(), '\\\\"bar\\\\", \\\\"foo\\\\"')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -537,15 +537,18 @@ class LearnerConfiguration : public Learner {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FIXME(trivialfis): Make eval_metric a training parameter.
|
||||||
keys.emplace_back(kEvalMetric);
|
keys.emplace_back(kEvalMetric);
|
||||||
keys.emplace_back("verbosity");
|
|
||||||
keys.emplace_back("num_output_group");
|
keys.emplace_back("num_output_group");
|
||||||
|
|
||||||
std::sort(keys.begin(), keys.end());
|
std::sort(keys.begin(), keys.end());
|
||||||
|
|
||||||
std::vector<std::string> provided;
|
std::vector<std::string> provided;
|
||||||
for (auto const &kv : cfg_) {
|
for (auto const &kv : cfg_) {
|
||||||
// FIXME(trivialfis): Make eval_metric a training parameter.
|
if (std::any_of(kv.first.cbegin(), kv.first.cend(),
|
||||||
|
[](char ch) { return std::isspace(ch); })) {
|
||||||
|
LOG(FATAL) << "Invalid parameter \"" << kv.first << "\" contains whitespace.";
|
||||||
|
}
|
||||||
provided.push_back(kv.first);
|
provided.push_back(kv.first);
|
||||||
}
|
}
|
||||||
std::sort(provided.begin(), provided.end());
|
std::sort(provided.begin(), provided.end());
|
||||||
@ -557,9 +560,9 @@ class LearnerConfiguration : public Learner {
|
|||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "\nParameters: { ";
|
ss << "\nParameters: { ";
|
||||||
for (size_t i = 0; i < diff.size() - 1; ++i) {
|
for (size_t i = 0; i < diff.size() - 1; ++i) {
|
||||||
ss << diff[i] << ", ";
|
ss << "\"" << diff[i] << "\", ";
|
||||||
}
|
}
|
||||||
ss << diff.back();
|
ss << "\"" << diff.back() << "\"";
|
||||||
ss << R"W( } might not be used.
|
ss << R"W( } might not be used.
|
||||||
|
|
||||||
This may not be accurate due to some parameters are only used in language bindings but
|
This may not be accurate due to some parameters are only used in language bindings but
|
||||||
|
|||||||
@ -40,7 +40,7 @@ TEST(Learner, ParameterValidation) {
|
|||||||
|
|
||||||
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
auto learner = std::unique_ptr<Learner>(Learner::Create({p_mat}));
|
||||||
learner->SetParam("validate_parameters", "1");
|
learner->SetParam("validate_parameters", "1");
|
||||||
learner->SetParam("Knock Knock", "Who's there?");
|
learner->SetParam("Knock-Knock", "Who's-there?");
|
||||||
learner->SetParam("Silence", "....");
|
learner->SetParam("Silence", "....");
|
||||||
learner->SetParam("tree_method", "exact");
|
learner->SetParam("tree_method", "exact");
|
||||||
|
|
||||||
@ -48,7 +48,11 @@ TEST(Learner, ParameterValidation) {
|
|||||||
learner->Configure();
|
learner->Configure();
|
||||||
std::string output = testing::internal::GetCapturedStderr();
|
std::string output = testing::internal::GetCapturedStderr();
|
||||||
|
|
||||||
ASSERT_TRUE(output.find("Parameters: { Knock Knock, Silence }") != std::string::npos);
|
ASSERT_TRUE(output.find(R"(Parameters: { "Knock-Knock", "Silence" })") != std::string::npos);
|
||||||
|
|
||||||
|
// whitespace
|
||||||
|
learner->SetParam("tree method", "exact");
|
||||||
|
EXPECT_THROW(learner->Configure(), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Learner, CheckGroup) {
|
TEST(Learner, CheckGroup) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user