Fix ignoring dart in updater configuration. (#4024)
* Fix ignoring dart in updater configuration.
This commit is contained in:
parent
9897b5042f
commit
be948df23f
@ -185,7 +185,8 @@ class LearnerImpl : public Learner {
|
|||||||
/*! \brief Map `tree_method` parameter to `updater` parameter */
|
/*! \brief Map `tree_method` parameter to `updater` parameter */
|
||||||
void ConfigureUpdaters() {
|
void ConfigureUpdaters() {
|
||||||
// This method is not applicable to non-tree learners
|
// This method is not applicable to non-tree learners
|
||||||
if (cfg_.count("booster") > 0 && cfg_.at("booster") != "gbtree") {
|
if (cfg_.find("booster") != cfg_.cend() &&
|
||||||
|
(cfg_.at("booster") != "gbtree" && cfg_.at("booster") != "dart")) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// `updater` parameter was manually specified
|
// `updater` parameter was manually specified
|
||||||
|
|||||||
@ -33,6 +33,9 @@ TEST(learner, SelectTreeMethod) {
|
|||||||
learner->Configure({arg("tree_method", "hist")});
|
learner->Configure({arg("tree_method", "hist")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_quantile_histmaker");
|
"grow_quantile_histmaker");
|
||||||
|
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "hist"}});
|
||||||
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
|
"grow_quantile_histmaker");
|
||||||
#ifdef XGBOOST_USE_CUDA
|
#ifdef XGBOOST_USE_CUDA
|
||||||
learner->Configure({arg("tree_method", "gpu_exact")});
|
learner->Configure({arg("tree_method", "gpu_exact")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
@ -40,6 +43,9 @@ TEST(learner, SelectTreeMethod) {
|
|||||||
learner->Configure({arg("tree_method", "gpu_hist")});
|
learner->Configure({arg("tree_method", "gpu_hist")});
|
||||||
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
"grow_gpu_hist");
|
"grow_gpu_hist");
|
||||||
|
learner->Configure({arg{"booster", "dart"}, arg{"tree_method", "gpu_hist"}});
|
||||||
|
ASSERT_EQ(learner->GetConfigurationArguments().at("updater"),
|
||||||
|
"grow_gpu_hist");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
delete mat_ptr;
|
delete mat_ptr;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user