De-duplicate GPU parameters. (#4454)
* Only define `gpu_id` and `n_gpus` in `LearnerTrainParam` * Pass LearnerTrainParam through XGBoost vid factory method. * Disable all GPU usage when GPU related parameters are not specified (fixes XGBoost choosing GPU over aggressively). * Test learner train param io. * Fix gpu pickling.
This commit is contained in:
@@ -62,7 +62,7 @@ class GBLinear : public GradientBooster {
|
||||
model_.param.InitAllowUnknown(cfg);
|
||||
}
|
||||
param_.InitAllowUnknown(cfg);
|
||||
updater_.reset(LinearUpdater::Create(param_.updater));
|
||||
updater_.reset(LinearUpdater::Create(param_.updater, learner_param_));
|
||||
updater_->Init(cfg);
|
||||
monitor_.Init("GBLinear");
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
#include <dmlc/io.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/feature_map.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
@@ -13,13 +13,16 @@ DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
|
||||
namespace xgboost {
|
||||
GradientBooster* GradientBooster::Create(
|
||||
const std::string& name,
|
||||
LearnerTrainParam const* learner_param,
|
||||
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
|
||||
bst_float base_margin) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown gbm type " << name;
|
||||
}
|
||||
return (e->body)(cache_mats, base_margin);
|
||||
auto p_bst = (e->body)(cache_mats, base_margin);
|
||||
p_bst->learner_param_ = learner_param;
|
||||
return p_bst;
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -147,7 +147,7 @@ class GBTree : public GradientBooster {
|
||||
}
|
||||
|
||||
// configure predictor
|
||||
predictor_ = std::unique_ptr<Predictor>(Predictor::Create(tparam_.predictor));
|
||||
predictor_ = std::unique_ptr<Predictor>(Predictor::Create(tparam_.predictor, learner_param_));
|
||||
predictor_->Init(cfg, cache_);
|
||||
monitor_.Init("GBTree");
|
||||
}
|
||||
@@ -252,7 +252,7 @@ class GBTree : public GradientBooster {
|
||||
std::string tval = tparam_.updater_seq;
|
||||
std::vector<std::string> ups = common::Split(tval, ',');
|
||||
for (const std::string& pstr : ups) {
|
||||
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str()));
|
||||
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str(), learner_param_));
|
||||
up->Init(this->cfg_);
|
||||
updaters_.push_back(std::move(up));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user