JSON configuration IO. (#5111)

* Add saving/loading JSON configuration.
* Implement Python pickle interface with new IO routines.
* Basic tests for training continuation.
This commit is contained in:
Jiaming Yuan
2019-12-15 17:31:53 +08:00
committed by GitHub
parent 5aa007d7b2
commit 3136185bc5
24 changed files with 761 additions and 390 deletions

View File

@@ -99,6 +99,16 @@ class GBLinear : public GradientBooster {
model_.LoadModel(model);
}
void LoadConfig(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "gblinear");
fromJson(in["gblinear_train_param"], &param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String{"gblinear"};
out["gblinear_train_param"] = toJson(param_);
}
void DoBoost(DMatrix *p_fmat,
HostDeviceVector<GradientPair> *in_gpair,
ObjFunction* obj) override {

View File

@@ -112,7 +112,8 @@ class GBLinearModel : public Model {
<< " \"weight\": [" << std::endl;
for (unsigned i = 0; i < nfeature; ++i) {
for (int gid = 0; gid < ngroup; ++gid) {
if (i != 0 || gid != 0) fo << "," << std::endl;
if (i != 0 || gid != 0)
fo << "," << std::endl;
fo << " " << (*this)[i][gid];
}
}
@@ -134,5 +135,6 @@ class GBLinearModel : public Model {
return v;
}
};
} // namespace gbm
} // namespace xgboost

View File

@@ -34,6 +34,7 @@ DMLC_REGISTRY_FILE_TAG(gbtree);
void GBTree::Configure(const Args& cfg) {
this->cfg_ = cfg;
std::string updater_seq = tparam_.updater_seq;
tparam_.UpdateAllowUnknown(cfg);
model_.Configure(cfg);
@@ -75,24 +76,31 @@ void GBTree::Configure(const Args& cfg) {
"`tree_method` parameter instead.";
// Don't drive users to silent XGBOost.
showed_updater_warning_ = true;
} else {
this->ConfigureUpdaters();
LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
}
for (auto& up : updaters_) {
up->Configure(cfg);
this->ConfigureUpdaters();
if (updater_seq != tparam_.updater_seq) {
updaters_.clear();
this->InitUpdater(cfg);
} else {
for (auto &up : updaters_) {
up->Configure(cfg);
}
}
configured_ = true;
}
// FIXME(trivialfis): This handles updaters and predictor. Because the choice of updaters
// depends on whether external memory is used and how large is dataset. We can remove the
// dependency on DMatrix once `hist` tree method can handle external memory so that we can
// make it default.
// FIXME(trivialfis): This handles updaters. Because the choice of updaters depends on
// whether external memory is used and how large is dataset. We can remove the dependency
// on DMatrix once `hist` tree method can handle external memory so that we can make it
// default.
void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
CHECK(this->configured_);
std::string updater_seq = tparam_.updater_seq;
CHECK(tparam_.GetInitialised());
tparam_.UpdateAllowUnknown(cfg);
this->PerformTreeMethodHeuristic(fmat);
this->ConfigureUpdaters();
@@ -101,9 +109,8 @@ void GBTree::ConfigureWithKnownData(Args const& cfg, DMatrix* fmat) {
if (updater_seq != tparam_.updater_seq) {
LOG(DEBUG) << "Using updaters: " << tparam_.updater_seq;
this->updaters_.clear();
this->InitUpdater(cfg);
}
this->InitUpdater(cfg);
}
void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
@@ -141,6 +148,9 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
}
void GBTree::ConfigureUpdaters() {
if (specified_updater_) {
return;
}
// `updater` parameter was manually specified
/* Choose updaters according to tree_method parameters */
switch (tparam_.tree_method) {
@@ -289,6 +299,46 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
monitor_.Stop("CommitModel");
}
void GBTree::LoadConfig(Json const& in) {
CHECK_EQ(get<String>(in["name"]), "gbtree");
fromJson(in["gbtree_train_param"], &tparam_);
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
tparam_.UpdateAllowUnknown(Args{{"predictor", "auto"}});
}
if (n_gpus == 0 && tparam_.tree_method == TreeMethod::kGPUHist) {
tparam_.UpdateAllowUnknown(Args{{"tree_method", "hist"}});
LOG(WARNING)
<< "Loading from a raw memory buffer on CPU only machine. "
"Change tree_method to hist.";
}
auto const& j_updaters = get<Object const>(in["updater"]);
updaters_.clear();
for (auto const& kv : j_updaters) {
std::unique_ptr<TreeUpdater> up(TreeUpdater::Create(kv.first, generic_param_));
up->LoadConfig(kv.second);
updaters_.push_back(std::move(up));
}
specified_updater_ = get<Boolean>(in["specified_updater"]);
}
void GBTree::SaveConfig(Json* p_out) const {
auto& out = *p_out;
out["name"] = String("gbtree");
out["gbtree_train_param"] = toJson(tparam_);
out["updater"] = Object();
auto& j_updaters = out["updater"];
for (auto const& up : updaters_) {
j_updaters[up->Name()] = Object();
auto& j_up = j_updaters[up->Name()];
up->SaveConfig(&j_up);
}
out["specified_updater"] = Boolean{specified_updater_};
}
void GBTree::LoadModel(Json const& in) {
CHECK_EQ(get<String>(in["name"]), "gbtree");
model_.LoadModel(in["model"]);
@@ -324,7 +374,7 @@ class Dart : public GBTree {
for (size_t i = 0; i < weight_drop_.size(); ++i) {
j_weight_drop[i] = Number(weight_drop_[i]);
}
out["weight_drop"] = Array(j_weight_drop);
out["weight_drop"] = Array(std::move(j_weight_drop));
}
void LoadModel(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "dart");
@@ -352,6 +402,21 @@ class Dart : public GBTree {
}
}
void LoadConfig(Json const& in) override {
CHECK_EQ(get<String>(in["name"]), "dart");
auto const& gbtree = in["gbtree"];
GBTree::LoadConfig(gbtree);
fromJson(in["dart_train_param"], &dparam_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String("dart");
out["gbtree"] = Object();
auto& gbtree = out["gbtree"];
GBTree::SaveConfig(&gbtree);
out["dart_train_param"] = toJson(dparam_);
}
// predict the leaf scores with dropout if ntree_limit = 0
void PredictBatch(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_preds,

View File

@@ -192,6 +192,9 @@ class GBTree : public GradientBooster {
model_.Save(fo);
}
void LoadConfig(Json const& in) override;
void SaveConfig(Json* p_out) const override;
void SaveModel(Json* p_out) const override;
void LoadModel(Json const& in) override;

View File

@@ -46,7 +46,8 @@ void GBTreeModel::SaveModel(Json* p_out) const {
for (auto const& tree : trees) {
Json tree_json{Object()};
tree->SaveModel(&tree_json);
tree_json["id"] = std::to_string(t);
// The field is not used in XGBoost, but might be useful for external project.
tree_json["id"] = Integer(t);
trees_json.emplace_back(tree_json);
t++;
}