Serialise booster after training to reset state (#5484)
* Serialise booster after training to reset state * Prevent process_type being set on load * Check for correct updater sequence
This commit is contained in:
@@ -20,9 +20,8 @@ num_round = 20
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||
|
||||
# Save the model, only ask process 0 to save the model.
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
bst.save_model("test.model")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
|
||||
@@ -70,9 +70,8 @@ watchlist = [(dtrain,'train')]
|
||||
num_round = 2
|
||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||
|
||||
if xgb.rabit.get_rank() == 0:
|
||||
bst.save_model("test_issue3402.model")
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
|
||||
xgb.rabit.tracker_print("Finished training\n")
|
||||
|
||||
# Notify the tracker all training has been successful
|
||||
# This is only needed in distributed training.
|
||||
|
||||
Reference in New Issue
Block a user