Serialise booster after training to reset state (#5484)

* Serialise booster after training to reset state

* Prevent process_type being set on load

* Check for correct updater sequence
This commit is contained in:
Rory Mitchell
2020-04-11 16:27:12 +12:00
committed by GitHub
parent 4a0c8ef237
commit 093e2227e3
5 changed files with 27 additions and 7 deletions

View File

@@ -63,6 +63,10 @@ TEST(GBTree, WrongUpdater) {
// Hist can not be used for updating tree.
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
// Prune can not be used for learning new tree.
learner->SetParams(
Args{{"tree_method", "prune"}, {"process_type", "default"}});
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
}
#ifdef XGBOOST_USE_CUDA

View File

@@ -20,9 +20,8 @@ num_round = 20
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
# Save the model, only ask process 0 to save the model.
if xgb.rabit.get_rank() == 0:
bst.save_model("test.model")
xgb.rabit.tracker_print("Finished training\n")
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful
# This is only needed in distributed training.

View File

@@ -70,9 +70,8 @@ watchlist = [(dtrain,'train')]
num_round = 2
bst = xgb.train(param, dtrain, num_round, watchlist)
if xgb.rabit.get_rank() == 0:
bst.save_model("test_issue3402.model")
xgb.rabit.tracker_print("Finished training\n")
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
xgb.rabit.tracker_print("Finished training\n")
# Notify the tracker all training has been successful
# This is only needed in distributed training.