diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py
index 759f30221..74223a166 100644
--- a/python-package/xgboost/training.py
+++ b/python-package/xgboost/training.py
@@ -109,7 +109,9 @@ def _train_internal(params, dtrain,
     else:
         bst.best_iteration = nboost - 1
     bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
-    return bst
+
+    # Copy to serialise and unserialise booster to reset state and free training memory
+    return bst.copy()
 
 
 def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 18c25dc8f..7dcf3cd9d 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -267,6 +267,11 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
   // create the trees
   for (int i = 0; i < tparam_.num_parallel_tree; ++i) {
     if (tparam_.process_type == TreeProcessType::kDefault) {
+      CHECK(!updaters_.front()->CanModifyTree())
+          << "Updater: `" << updaters_.front()->Name() << "` "
+          << "can not be used to create new trees. "
+          << "Set `process_type` to `update` if you want to update existing "
+             "trees.";
       // create new tree
       std::unique_ptr<RegTree> ptr(new RegTree());
       ptr->param.UpdateAllowUnknown(this->cfg_);
@@ -319,6 +324,10 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
 void GBTree::LoadConfig(Json const& in) {
   CHECK_EQ(get<String>(in["name"]), "gbtree");
   FromJson(in["gbtree_train_param"], &tparam_);
+  // Process type cannot be kUpdate from loaded model
+  // This would cause all trees to be pushed to trees_to_update
+  // e.g. updating a model, then saving and loading it would result in an empty model
+  tparam_.process_type = TreeProcessType::kDefault;
   int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
   if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
     LOG(WARNING)
@@ -348,6 +357,13 @@ void GBTree::SaveConfig(Json* p_out) const {
   auto& out = *p_out;
   out["name"] = String("gbtree");
   out["gbtree_train_param"] = ToJson(tparam_);
+
+  // Process type cannot be kUpdate from loaded model
+  // This would cause all trees to be pushed to trees_to_update
+  // e.g. updating a model, then saving and loading it would result in an empty
+  // model
+  out["gbtree_train_param"]["process_type"] = String("default");
+
   out["updater"] = Object();
 
   auto& j_updaters = out["updater"];
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index 44debdd9d..463253aea 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -63,6 +63,10 @@ TEST(GBTree, WrongUpdater) {
   // Hist can not be used for updating tree.
   learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
   ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
+  // Prune can not be used for learning new tree.
+  learner->SetParams(
+      Args{{"tree_method", "prune"}, {"process_type", "default"}});
+  ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
 }
 
 #ifdef XGBOOST_USE_CUDA
diff --git a/tests/distributed/test_basic.py b/tests/distributed/test_basic.py
index 0fc900834..64a06975e 100644
--- a/tests/distributed/test_basic.py
+++ b/tests/distributed/test_basic.py
@@ -20,9 +20,8 @@
 num_round = 20
 bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
 # Save the model, only ask process 0 to save the model.
-if xgb.rabit.get_rank() == 0:
-    bst.save_model("test.model")
-    xgb.rabit.tracker_print("Finished training\n")
+bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
+xgb.rabit.tracker_print("Finished training\n")
 
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
diff --git a/tests/distributed/test_issue3402.py b/tests/distributed/test_issue3402.py
index e6c498331..67aea7a85 100644
--- a/tests/distributed/test_issue3402.py
+++ b/tests/distributed/test_issue3402.py
@@ -70,9 +70,8 @@
 watchlist = [(dtrain,'train')]
 num_round = 2
 bst = xgb.train(param, dtrain, num_round, watchlist)
-if xgb.rabit.get_rank() == 0:
-    bst.save_model("test_issue3402.model")
-    xgb.rabit.tracker_print("Finished training\n")
+bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
+xgb.rabit.tracker_print("Finished training\n")
 
 # Notify the tracker all training has been successful
 # This is only needed in distributed training.
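For context, a minimal usage sketch (not part of the patch) of the behaviour this change touches: updating an existing model with `process_type="update"` and then round-tripping it through save/load, which previously produced an empty model. The toy data, parameter values, and file name are illustrative assumptions; it assumes a build with JSON model support.

import numpy as np
import xgboost as xgb

# Illustrative toy data, not taken from the patch.
X = np.random.randn(256, 8)
y = np.random.randn(256)
dtrain = xgb.DMatrix(X, label=y)

# Grow an initial model with the default process type.
booster = xgb.train({"tree_method": "hist", "max_depth": 3}, dtrain,
                    num_boost_round=4)

# Refresh the existing trees instead of growing new ones. Combining the
# settings the other way around (an updater such as prune/refresh with
# process_type "default") is what the new CHECK in BoostNewTrees rejects.
updated = xgb.train({"process_type": "update", "updater": "refresh",
                     "refresh_leaf": True},
                    dtrain, num_boost_round=4, xgb_model=booster)

# With the SaveConfig/LoadConfig change, the persisted configuration falls
# back to process_type "default", so the save/load round trip of the updated
# model no longer yields an empty model.
updated.save_model("updated_model.json")
reloaded = xgb.Booster(model_file="updated_model.json")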