De-duplicate GPU parameters. (#4454)

* Only define `gpu_id` and `n_gpus` in `LearnerTrainParam`
* Pass LearnerTrainParam through XGBoost vid factory method.
* Disable all GPU usage when GPU related parameters are not specified (fixes XGBoost choosing GPU over aggressively).
* Test learner train param io.
* Fix gpu pickling.
This commit is contained in:
Jiaming Yuan
2019-05-29 11:55:57 +08:00
committed by GitHub
parent a3fedbeaa8
commit c589eff941
69 changed files with 927 additions and 562 deletions

View File

@@ -62,7 +62,7 @@ class GBLinear : public GradientBooster {
model_.param.InitAllowUnknown(cfg);
}
param_.InitAllowUnknown(cfg);
updater_.reset(LinearUpdater::Create(param_.updater));
updater_.reset(LinearUpdater::Create(param_.updater, learner_param_));
updater_->Init(cfg);
monitor_.Init("GBLinear");
}

View File

@@ -4,6 +4,7 @@
#pragma once
#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <xgboost/feature_map.h>
#include <vector>
#include <string>

View File

@@ -13,13 +13,16 @@ DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
namespace xgboost {
GradientBooster* GradientBooster::Create(
const std::string& name,
LearnerTrainParam const* learner_param,
const std::vector<std::shared_ptr<DMatrix> >& cache_mats,
bst_float base_margin) {
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
if (e == nullptr) {
LOG(FATAL) << "Unknown gbm type " << name;
}
return (e->body)(cache_mats, base_margin);
auto p_bst = (e->body)(cache_mats, base_margin);
p_bst->learner_param_ = learner_param;
return p_bst;
}
} // namespace xgboost

View File

@@ -147,7 +147,7 @@ class GBTree : public GradientBooster {
}
// configure predictor
predictor_ = std::unique_ptr<Predictor>(Predictor::Create(tparam_.predictor));
predictor_ = std::unique_ptr<Predictor>(Predictor::Create(tparam_.predictor, learner_param_));
predictor_->Init(cfg, cache_);
monitor_.Init("GBTree");
}
@@ -252,7 +252,7 @@ class GBTree : public GradientBooster {
std::string tval = tparam_.updater_seq;
std::vector<std::string> ups = common::Split(tval, ',');
for (const std::string& pstr : ups) {
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str()));
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(pstr.c_str(), learner_param_));
up->Init(this->cfg_);
updaters_.push_back(std::move(up));
}