diff --git a/src/learner.cc b/src/learner.cc index 8e075281a..468beee30 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -185,7 +185,8 @@ class LearnerImpl : public Learner { /*! \brief Map `tree_method` parameter to `updater` parameter */ void ConfigureUpdaters() { // This method is not applicable to non-tree learners - if (cfg_.count("booster") > 0 && cfg_.at("booster") != "gbtree") { + if (cfg_.find("booster") != cfg_.cend() && + (cfg_.at("booster") != "gbtree" && cfg_.at("booster") != "dart")) { return; } // `updater` parameter was manually specified diff --git a/tests/cpp/test_learner.cc b/tests/cpp/test_learner.cc index 6ea92b7c1..46e68b9ec 100644 --- a/tests/cpp/test_learner.cc +++ b/tests/cpp/test_learner.cc @@ -33,6 +33,9 @@ TEST(learner, SelectTreeMethod) { learner->Configure({arg("tree_method", "hist")}); ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), "grow_quantile_histmaker"); + learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "hist"}}); + ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), + "grow_quantile_histmaker"); #ifdef XGBOOST_USE_CUDA learner->Configure({arg("tree_method", "gpu_exact")}); ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), @@ -40,6 +43,9 @@ TEST(learner, SelectTreeMethod) { learner->Configure({arg("tree_method", "gpu_hist")}); ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), "grow_gpu_hist"); + learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "gpu_hist"}}); + ASSERT_EQ(learner->GetConfigurationArguments().at("updater"), + "grow_gpu_hist"); #endif delete mat_ptr;