Fix ignoring dart in updater configuration. (#4024)

* Fix ignoring dart in updater configuration.
This commit is contained in:
Jiaming Yuan 2018-12-26 18:24:45 +08:00 committed by GitHub
parent 9897b5042f
commit be948df23f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 1 deletions

View File

@ -185,7 +185,8 @@ class LearnerImpl : public Learner {
/*! \brief Map `tree_method` parameter to `updater` parameter */ /*! \brief Map `tree_method` parameter to `updater` parameter */
void ConfigureUpdaters() { void ConfigureUpdaters() {
// This method is not applicable to non-tree learners // This method is not applicable to non-tree learners
if (cfg_.count("booster") > 0 && cfg_.at("booster") != "gbtree") { if (cfg_.find("booster") != cfg_.cend() &&
(cfg_.at("booster") != "gbtree" && cfg_.at("booster") != "dart")) {
return; return;
} }
// `updater` parameter was manually specified // `updater` parameter was manually specified

View File

@ -33,6 +33,9 @@ TEST(learner, SelectTreeMethod) {
learner->Configure({arg("tree_method", "hist")}); learner->Configure({arg("tree_method", "hist")});
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
"grow_quantile_histmaker"); "grow_quantile_histmaker");
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "hist"}});
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
"grow_quantile_histmaker");
#ifdef XGBOOST_USE_CUDA #ifdef XGBOOST_USE_CUDA
learner->Configure({arg("tree_method", "gpu_exact")}); learner->Configure({arg("tree_method", "gpu_exact")});
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
@ -40,6 +43,9 @@ TEST(learner, SelectTreeMethod) {
learner->Configure({arg("tree_method", "gpu_hist")}); learner->Configure({arg("tree_method", "gpu_hist")});
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
"grow_gpu_hist"); "grow_gpu_hist");
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "gpu_hist"}});
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
"grow_gpu_hist");
#endif #endif
delete mat_ptr; delete mat_ptr;