Serialise booster after training to reset state (#5484)
* Serialise booster after training to reset state * Prevent process_type being set on load * Check for correct updater sequence
This commit is contained in:
@@ -267,6 +267,11 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
|
||||
// create the trees
|
||||
for (int i = 0; i < tparam_.num_parallel_tree; ++i) {
|
||||
if (tparam_.process_type == TreeProcessType::kDefault) {
|
||||
CHECK(!updaters_.front()->CanModifyTree())
|
||||
<< "Updater: `" << updaters_.front()->Name() << "` "
|
||||
<< "can not be used to create new trees. "
|
||||
<< "Set `process_type` to `update` if you want to update existing "
|
||||
"trees.";
|
||||
// create new tree
|
||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||
ptr->param.UpdateAllowUnknown(this->cfg_);
|
||||
@@ -319,6 +324,10 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
||||
void GBTree::LoadConfig(Json const& in) {
|
||||
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
||||
FromJson(in["gbtree_train_param"], &tparam_);
|
||||
// Process type cannot be kUpdate from loaded model
|
||||
// This would cause all trees to be pushed to trees_to_update
|
||||
// e.g. updating a model, then saving and loading it would result in an empty model
|
||||
tparam_.process_type = TreeProcessType::kDefault;
|
||||
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
|
||||
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
|
||||
LOG(WARNING)
|
||||
@@ -348,6 +357,13 @@ void GBTree::SaveConfig(Json* p_out) const {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String("gbtree");
|
||||
out["gbtree_train_param"] = ToJson(tparam_);
|
||||
|
||||
// Process type cannot be kUpdate from loaded model
|
||||
// This would cause all trees to be pushed to trees_to_update
|
||||
// e.g. updating a model, then saving and loading it would result in an empty
|
||||
// model
|
||||
out["gbtree_train_param"]["process_type"] = String("default");
|
||||
|
||||
out["updater"] = Object();
|
||||
|
||||
auto& j_updaters = out["updater"];
|
||||
|
||||
Reference in New Issue
Block a user