Pass pointer to model parameters. (#5101)
* Pass pointer to model parameters. This PR de-duplicates most of the model parameters except the one in `tree_model.h`. One difficulty is `base_score` is a model property but can be changed at runtime by objective function. Hence when performing model IO, we need to save the one provided by users, instead of the one transformed by objective. Here we created an immutable version of `LearnerModelParam` that represents the value of model parameter after configuration.
This commit is contained in:
163
src/learner.cc
163
src/learner.cc
@@ -16,17 +16,23 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/feature_map.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/feature_map.h"
|
||||
#include "xgboost/gbm.h"
|
||||
#include "xgboost/generic_parameters.h"
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/learner.h"
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/metric.h"
|
||||
#include "xgboost/objective.h"
|
||||
#include "xgboost/parameter.h"
|
||||
|
||||
#include "common/common.h"
|
||||
#include "common/io.h"
|
||||
#include "common/random.h"
|
||||
#include "common/timer.h"
|
||||
#include "common/version.h"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -69,8 +75,15 @@ bool Learner::AllowLazyCheckPoint() const {
|
||||
return gbm_->AllowLazyCheckPoint();
|
||||
}
|
||||
|
||||
/*! \brief training parameter for regression */
|
||||
struct LearnerModelParam : public dmlc::Parameter<LearnerModelParam> {
|
||||
Learner::~Learner() = default;
|
||||
|
||||
/*! \brief training parameter for regression
|
||||
*
|
||||
* Should be deprecated, but still used for being compatible with binary IO.
|
||||
* Once it's gone, `LearnerModelParam` should handle transforming `base_margin`
|
||||
* with objective by itself.
|
||||
*/
|
||||
struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy> {
|
||||
/* \brief global bias */
|
||||
bst_float base_score;
|
||||
/* \brief number of features */
|
||||
@@ -84,12 +97,28 @@ struct LearnerModelParam : public dmlc::Parameter<LearnerModelParam> {
|
||||
/*! \brief reserved field */
|
||||
int reserved[29];
|
||||
/*! \brief constructor */
|
||||
LearnerModelParam() {
|
||||
std::memset(this, 0, sizeof(LearnerModelParam));
|
||||
LearnerModelParamLegacy() {
|
||||
std::memset(this, 0, sizeof(LearnerModelParamLegacy));
|
||||
base_score = 0.5f;
|
||||
}
|
||||
// Skip other legacy fields.
|
||||
Json ToJson() const {
|
||||
Object obj;
|
||||
obj["base_score"] = std::to_string(base_score);
|
||||
obj["num_feature"] = std::to_string(num_feature);
|
||||
obj["num_class"] = std::to_string(num_class);
|
||||
return Json(std::move(obj));
|
||||
}
|
||||
void FromJson(Json const& obj) {
|
||||
auto const& j_param = get<Object const>(obj);
|
||||
std::map<std::string, std::string> m;
|
||||
m["base_score"] = get<String const>(j_param.at("base_score"));
|
||||
m["num_feature"] = get<String const>(j_param.at("num_feature"));
|
||||
m["num_class"] = get<String const>(j_param.at("num_class"));
|
||||
this->Init(m);
|
||||
}
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerModelParam) {
|
||||
DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) {
|
||||
DMLC_DECLARE_FIELD(base_score)
|
||||
.set_default(0.5f)
|
||||
.describe("Global bias of the model.");
|
||||
@@ -104,12 +133,20 @@ struct LearnerModelParam : public dmlc::Parameter<LearnerModelParam> {
|
||||
}
|
||||
};
|
||||
|
||||
LearnerModelParam::LearnerModelParam(
|
||||
LearnerModelParamLegacy const &user_param, float base_margin)
|
||||
: base_score{base_margin}, num_feature{user_param.num_feature},
|
||||
num_output_group{user_param.num_class == 0
|
||||
? 1
|
||||
: static_cast<uint32_t>(user_param.num_class)} {}
|
||||
|
||||
struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
// data split mode, can be row, col, or none.
|
||||
DataSplitMode dsplit;
|
||||
// flag to disable default metric
|
||||
int disable_default_eval_metric;
|
||||
|
||||
// FIXME(trivialfis): The following parameters belong to model itself, but can be
|
||||
// specified by users. Move them to model parameter once we can get rid of binary IO.
|
||||
std::string booster;
|
||||
std::string objective;
|
||||
|
||||
@@ -134,7 +171,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
||||
};
|
||||
|
||||
|
||||
DMLC_REGISTER_PARAMETER(LearnerModelParam);
|
||||
DMLC_REGISTER_PARAMETER(LearnerModelParamLegacy);
|
||||
DMLC_REGISTER_PARAMETER(LearnerTrainParam);
|
||||
DMLC_REGISTER_PARAMETER(GenericParameter);
|
||||
|
||||
@@ -142,14 +179,7 @@ int constexpr GenericParameter::kCpuId;
|
||||
|
||||
void GenericParameter::ConfigureGpuId(bool require_gpu) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
int32_t n_visible = common::AllVisibleGPUs();
|
||||
if (n_visible == 0) {
|
||||
// Running XGBoost compiled with CUDA on CPU only machine.
|
||||
this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
|
||||
return;
|
||||
}
|
||||
|
||||
if (this->gpu_id == kCpuId) { // 0. User didn't specify the `gpu_id'
|
||||
if (gpu_id == kCpuId) { // 0. User didn't specify the `gpu_id'
|
||||
if (require_gpu) { // 1. `tree_method' or `predictor' or both are using
|
||||
// GPU.
|
||||
// 2. Use device 0 as default.
|
||||
@@ -159,7 +189,10 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
|
||||
|
||||
// 3. When booster is loaded from a memory image (Python pickle or R
|
||||
// raw model), number of available GPUs could be different. Wrap around it.
|
||||
if (this->gpu_id != kCpuId && this->gpu_id >= n_visible) {
|
||||
int32_t n_gpus = common::AllVisibleGPUs();
|
||||
if (n_gpus == 0) {
|
||||
this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
|
||||
} else if (gpu_id != kCpuId && gpu_id >= n_gpus) {
|
||||
this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(gpu_id % n_gpus)}});
|
||||
}
|
||||
#else
|
||||
@@ -175,25 +208,25 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
|
||||
class LearnerImpl : public Learner {
|
||||
public:
|
||||
explicit LearnerImpl(std::vector<std::shared_ptr<DMatrix> > cache)
|
||||
: configured_{false}, cache_(std::move(cache)) {
|
||||
: need_configuration_{true}, cache_(std::move(cache)) {
|
||||
monitor_.Init("Learner");
|
||||
}
|
||||
// Configuration before data is known.
|
||||
void Configure() override {
|
||||
if (configured_) { return; }
|
||||
if (!this->need_configuration_) { return; }
|
||||
|
||||
monitor_.Start("Configure");
|
||||
auto old_tparam = tparam_;
|
||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||
|
||||
tparam_.UpdateAllowUnknown(args);
|
||||
|
||||
generic_param_.UpdateAllowUnknown(args);
|
||||
generic_param_.CheckDeprecated();
|
||||
mparam_.UpdateAllowUnknown(args);
|
||||
generic_parameters_.UpdateAllowUnknown(args);
|
||||
generic_parameters_.CheckDeprecated();
|
||||
|
||||
ConsoleLogger::Configure(args);
|
||||
if (generic_param_.nthread != 0) {
|
||||
omp_set_num_threads(generic_param_.nthread);
|
||||
if (generic_parameters_.nthread != 0) {
|
||||
omp_set_num_threads(generic_parameters_.nthread);
|
||||
}
|
||||
|
||||
// add additional parameters
|
||||
@@ -202,9 +235,9 @@ class LearnerImpl : public Learner {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
mparam_.InitAllowUnknown(args);
|
||||
|
||||
// set seed only before the model is initialized
|
||||
common::GlobalRandom().seed(generic_param_.seed);
|
||||
common::GlobalRandom().seed(generic_parameters_.seed);
|
||||
// must precede configure gbm since num_features is required for gbm
|
||||
this->ConfigureNumFeatures();
|
||||
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
||||
@@ -212,9 +245,12 @@ class LearnerImpl : public Learner {
|
||||
this->ConfigureGBM(old_tparam, args);
|
||||
this->ConfigureMetrics(args);
|
||||
|
||||
generic_param_.ConfigureGpuId(this->gbm_->UseGPU());
|
||||
generic_parameters_.ConfigureGpuId(this->gbm_->UseGPU());
|
||||
|
||||
this->configured_ = true;
|
||||
learner_model_param_ = LearnerModelParam(mparam_,
|
||||
obj_->ProbToMargin(mparam_.base_score));
|
||||
|
||||
this->need_configuration_ = false;
|
||||
monitor_.Stop("Configure");
|
||||
}
|
||||
|
||||
@@ -241,7 +277,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
void Load(dmlc::Stream* fi) override {
|
||||
generic_param_.UpdateAllowUnknown(Args{});
|
||||
generic_parameters_.UpdateAllowUnknown(Args{});
|
||||
tparam_.Init(std::vector<std::pair<std::string, std::string>>{});
|
||||
// TODO(tqchen) mark deprecation of old format.
|
||||
common::PeekableInStream fp(fi);
|
||||
@@ -279,9 +315,9 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
|
||||
// duplicated code with LazyInitModel
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_param_));
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_param_,
|
||||
cache_, mparam_.base_score));
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
|
||||
&learner_model_param_, cache_));
|
||||
gbm_->Load(fi);
|
||||
if (mparam_.contain_extra_attrs != 0) {
|
||||
std::vector<std::pair<std::string, std::string> > attr;
|
||||
@@ -340,7 +376,7 @@ class LearnerImpl : public Learner {
|
||||
std::vector<std::string> metr;
|
||||
fi->Read(&metr);
|
||||
for (auto name : metr) {
|
||||
metrics_.emplace_back(Metric::Create(name, &generic_param_));
|
||||
metrics_.emplace_back(Metric::Create(name, &generic_parameters_));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,7 +387,7 @@ class LearnerImpl : public Learner {
|
||||
cfg_.insert(n.cbegin(), n.cend());
|
||||
|
||||
Args args = {cfg_.cbegin(), cfg_.cend()};
|
||||
generic_param_.UpdateAllowUnknown(args);
|
||||
generic_parameters_.UpdateAllowUnknown(args);
|
||||
gbm_->Configure(args);
|
||||
obj_->Configure({cfg_.begin(), cfg_.end()});
|
||||
|
||||
@@ -364,13 +400,14 @@ class LearnerImpl : public Learner {
|
||||
tparam_.dsplit = DataSplitMode::kRow;
|
||||
}
|
||||
|
||||
this->generic_param_.ConfigureGpuId(gbm_->UseGPU());
|
||||
this->configured_ = true;
|
||||
// There's no logic for state machine for binary IO, as it has a mix of everything and
|
||||
// half loaded model.
|
||||
this->Configure();
|
||||
}
|
||||
|
||||
// rabit save model to rabit checkpoint
|
||||
void Save(dmlc::Stream* fo) const override {
|
||||
if (!this->configured_) {
|
||||
if (this->need_configuration_) {
|
||||
// Save empty model. Calling Configure in a dummy LearnerImpl avoids violating
|
||||
// constness.
|
||||
LearnerImpl empty(std::move(this->cache_));
|
||||
@@ -383,7 +420,7 @@ class LearnerImpl : public Learner {
|
||||
return;
|
||||
}
|
||||
|
||||
LearnerModelParam mparam = mparam_; // make a copy to potentially modify
|
||||
LearnerModelParamLegacy mparam = mparam_; // make a copy to potentially modify
|
||||
std::vector<std::pair<std::string, std::string> > extra_attr;
|
||||
// extra attributed to be added just before saving
|
||||
if (tparam_.objective == "count:poisson") {
|
||||
@@ -419,11 +456,12 @@ class LearnerImpl : public Learner {
|
||||
return it.first == "SAVED_PARAM_gpu_id";
|
||||
})) {
|
||||
mparam.contain_extra_attrs = 1;
|
||||
extra_attr.emplace_back("SAVED_PARAM_gpu_id", std::to_string(generic_param_.gpu_id));
|
||||
extra_attr.emplace_back("SAVED_PARAM_gpu_id",
|
||||
std::to_string(generic_parameters_.gpu_id));
|
||||
}
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
fo->Write(&mparam, sizeof(LearnerModelParam));
|
||||
fo->Write(&mparam, sizeof(LearnerModelParamLegacy));
|
||||
fo->Write(tparam_.objective);
|
||||
fo->Write(tparam_.booster);
|
||||
gbm_->Save(fo);
|
||||
@@ -459,14 +497,16 @@ class LearnerImpl : public Learner {
|
||||
std::vector<std::string> DumpModel(const FeatureMap& fmap,
|
||||
bool with_stats,
|
||||
std::string format) const override {
|
||||
CHECK(!this->need_configuration_)
|
||||
<< "The model hasn't been built yet. Are you using raw Booster interface?";
|
||||
return gbm_->DumpModel(fmap, with_stats, format);
|
||||
}
|
||||
|
||||
void UpdateOneIter(int iter, DMatrix* train) override {
|
||||
monitor_.Start("UpdateOneIter");
|
||||
this->Configure();
|
||||
if (generic_param_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(generic_param_.seed * kRandSeedMagic + iter);
|
||||
if (generic_parameters_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
@@ -485,8 +525,8 @@ class LearnerImpl : public Learner {
|
||||
HostDeviceVector<GradientPair>* in_gpair) override {
|
||||
monitor_.Start("BoostOneIter");
|
||||
this->Configure();
|
||||
if (generic_param_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(generic_param_.seed * kRandSeedMagic + iter);
|
||||
if (generic_parameters_.seed_per_iteration || rabit::IsDistributed()) {
|
||||
common::GlobalRandom().seed(generic_parameters_.seed * kRandSeedMagic + iter);
|
||||
}
|
||||
this->CheckDataSplitMode();
|
||||
this->ValidateDMatrix(train);
|
||||
@@ -503,7 +543,7 @@ class LearnerImpl : public Learner {
|
||||
std::ostringstream os;
|
||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_param_));
|
||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &generic_parameters_));
|
||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||
}
|
||||
for (size_t i = 0; i < data_sets.size(); ++i) {
|
||||
@@ -523,7 +563,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
void SetParam(const std::string& key, const std::string& value) override {
|
||||
configured_ = false;
|
||||
this->need_configuration_ = true;
|
||||
if (key == kEvalMetric) {
|
||||
if (std::find(metric_names_.cbegin(), metric_names_.cend(),
|
||||
value) == metric_names_.cend()) {
|
||||
@@ -535,7 +575,6 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
// Short hand for setting multiple parameters
|
||||
void SetParams(std::vector<std::pair<std::string, std::string>> const& args) override {
|
||||
configured_ = false;
|
||||
for (auto const& kv : args) {
|
||||
this->SetParam(kv.first, kv.second);
|
||||
}
|
||||
@@ -569,7 +608,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
GenericParameter const& GetGenericParameter() const override {
|
||||
return generic_param_;
|
||||
return generic_parameters_;
|
||||
}
|
||||
|
||||
void Predict(DMatrix* data, bool output_margin,
|
||||
@@ -617,6 +656,7 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
void ConfigureObjective(LearnerTrainParam const& old, Args* p_args) {
|
||||
// Once binary IO is gone, NONE of these config is useful.
|
||||
if (cfg_.find("num_class") != cfg_.cend() && cfg_.at("num_class") != "0") {
|
||||
cfg_["num_output_group"] = cfg_["num_class"];
|
||||
if (atoi(cfg_["num_class"].c_str()) > 1 && cfg_.count("objective") == 0) {
|
||||
@@ -627,13 +667,13 @@ class LearnerImpl : public Learner {
|
||||
if (cfg_.find("max_delta_step") == cfg_.cend() &&
|
||||
cfg_.find("objective") != cfg_.cend() &&
|
||||
tparam_.objective == "count:poisson") {
|
||||
// max_delta_step is a duplicated parameter in Poisson regression and tree param.
|
||||
// Rename one of them once binary IO is gone.
|
||||
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
|
||||
}
|
||||
if (obj_ == nullptr || tparam_.objective != old.objective) {
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_param_));
|
||||
obj_.reset(ObjFunction::Create(tparam_.objective, &generic_parameters_));
|
||||
}
|
||||
// reset the base score
|
||||
mparam_.base_score = obj_->ProbToMargin(mparam_.base_score);
|
||||
auto& args = *p_args;
|
||||
args = {cfg_.cbegin(), cfg_.cend()}; // renew
|
||||
obj_->Configure(args);
|
||||
@@ -645,7 +685,7 @@ class LearnerImpl : public Learner {
|
||||
return m->Name() != name;
|
||||
};
|
||||
if (std::all_of(metrics_.begin(), metrics_.end(), DupCheck)) {
|
||||
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &generic_param_)));
|
||||
metrics_.emplace_back(std::unique_ptr<Metric>(Metric::Create(name, &generic_parameters_)));
|
||||
mparam_.contain_eval_metrics = 1;
|
||||
}
|
||||
}
|
||||
@@ -656,8 +696,8 @@ class LearnerImpl : public Learner {
|
||||
|
||||
void ConfigureGBM(LearnerTrainParam const& old, Args const& args) {
|
||||
if (gbm_ == nullptr || old.booster != tparam_.booster) {
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_param_,
|
||||
cache_, mparam_.base_score));
|
||||
gbm_.reset(GradientBooster::Create(tparam_.booster, &generic_parameters_,
|
||||
&learner_model_param_, cache_));
|
||||
}
|
||||
gbm_->Configure(args);
|
||||
}
|
||||
@@ -682,7 +722,8 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
CHECK_NE(mparam_.num_feature, 0)
|
||||
<< "0 feature is supplied. Are you using raw Booster interface?";
|
||||
// setup
|
||||
learner_model_param_.num_feature = mparam_.num_feature;
|
||||
// Remove these once binary IO is gone.
|
||||
cfg_["num_feature"] = common::ToString(mparam_.num_feature);
|
||||
cfg_["num_class"] = common::ToString(mparam_.num_class);
|
||||
}
|
||||
@@ -701,7 +742,8 @@ class LearnerImpl : public Learner {
|
||||
}
|
||||
|
||||
// model parameter
|
||||
LearnerModelParam mparam_;
|
||||
LearnerModelParamLegacy mparam_;
|
||||
LearnerModelParam learner_model_param_;
|
||||
LearnerTrainParam tparam_;
|
||||
// configurations
|
||||
std::map<std::string, std::string> cfg_;
|
||||
@@ -713,8 +755,7 @@ class LearnerImpl : public Learner {
|
||||
std::map<DMatrix*, HostDeviceVector<bst_float>> preds_;
|
||||
// gradient pairs
|
||||
HostDeviceVector<GradientPair> gpair_;
|
||||
|
||||
bool configured_;
|
||||
bool need_configuration_;
|
||||
|
||||
private:
|
||||
/*! \brief random number transformation seed. */
|
||||
|
||||
Reference in New Issue
Block a user