Serialise booster after training to reset state (#5484)
* Serialise booster after training to reset state
* Prevent process_type being set on load
* Check for correct updater sequence
This commit is contained in:
parent
4a0c8ef237
commit
093e2227e3
@ -109,7 +109,9 @@ def _train_internal(params, dtrain,
|
|||||||
else:
|
else:
|
||||||
bst.best_iteration = nboost - 1
|
bst.best_iteration = nboost - 1
|
||||||
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree
|
||||||
return bst
|
|
||||||
|
# Copy the booster (serialise, then deserialise) to reset its state and free training memory
|
||||||
|
return bst.copy()
|
||||||
|
|
||||||
|
|
||||||
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
|
||||||
|
|||||||
@ -267,6 +267,11 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair,
|
|||||||
// create the trees
|
// create the trees
|
||||||
for (int i = 0; i < tparam_.num_parallel_tree; ++i) {
|
for (int i = 0; i < tparam_.num_parallel_tree; ++i) {
|
||||||
if (tparam_.process_type == TreeProcessType::kDefault) {
|
if (tparam_.process_type == TreeProcessType::kDefault) {
|
||||||
|
CHECK(!updaters_.front()->CanModifyTree())
|
||||||
|
<< "Updater: `" << updaters_.front()->Name() << "` "
|
||||||
|
<< "can not be used to create new trees. "
|
||||||
|
<< "Set `process_type` to `update` if you want to update existing "
|
||||||
|
"trees.";
|
||||||
// create new tree
|
// create new tree
|
||||||
std::unique_ptr<RegTree> ptr(new RegTree());
|
std::unique_ptr<RegTree> ptr(new RegTree());
|
||||||
ptr->param.UpdateAllowUnknown(this->cfg_);
|
ptr->param.UpdateAllowUnknown(this->cfg_);
|
||||||
@ -319,6 +324,10 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& ne
|
|||||||
void GBTree::LoadConfig(Json const& in) {
|
void GBTree::LoadConfig(Json const& in) {
|
||||||
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
CHECK_EQ(get<String>(in["name"]), "gbtree");
|
||||||
FromJson(in["gbtree_train_param"], &tparam_);
|
FromJson(in["gbtree_train_param"], &tparam_);
|
||||||
|
// Process type cannot be kUpdate from loaded model
|
||||||
|
// This would cause all trees to be pushed to trees_to_update
|
||||||
|
// e.g. updating a model, then saving and loading it would result in an empty model
|
||||||
|
tparam_.process_type = TreeProcessType::kDefault;
|
||||||
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
|
int32_t const n_gpus = xgboost::common::AllVisibleGPUs();
|
||||||
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
|
if (n_gpus == 0 && tparam_.predictor == PredictorType::kGPUPredictor) {
|
||||||
LOG(WARNING)
|
LOG(WARNING)
|
||||||
@ -348,6 +357,13 @@ void GBTree::SaveConfig(Json* p_out) const {
|
|||||||
auto& out = *p_out;
|
auto& out = *p_out;
|
||||||
out["name"] = String("gbtree");
|
out["name"] = String("gbtree");
|
||||||
out["gbtree_train_param"] = ToJson(tparam_);
|
out["gbtree_train_param"] = ToJson(tparam_);
|
||||||
|
|
||||||
|
// Process type cannot be kUpdate from loaded model
|
||||||
|
// This would cause all trees to be pushed to trees_to_update
|
||||||
|
// e.g. updating a model, then saving and loading it would result in an empty
|
||||||
|
// model
|
||||||
|
out["gbtree_train_param"]["process_type"] = String("default");
|
||||||
|
|
||||||
out["updater"] = Object();
|
out["updater"] = Object();
|
||||||
|
|
||||||
auto& j_updaters = out["updater"];
|
auto& j_updaters = out["updater"];
|
||||||
|
|||||||
@ -63,6 +63,10 @@ TEST(GBTree, WrongUpdater) {
|
|||||||
// Hist can not be used for updating tree.
|
// Hist can not be used for updating tree.
|
||||||
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
|
learner->SetParams(Args{{"tree_method", "hist"}, {"process_type", "update"}});
|
||||||
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
|
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
|
||||||
|
// Prune can not be used for learning new tree.
|
||||||
|
learner->SetParams(
|
||||||
|
Args{{"tree_method", "prune"}, {"process_type", "default"}});
|
||||||
|
ASSERT_THROW(learner->UpdateOneIter(0, p_dmat), dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef XGBOOST_USE_CUDA
|
#ifdef XGBOOST_USE_CUDA
|
||||||
|
|||||||
@ -20,8 +20,7 @@ num_round = 20
|
|||||||
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
bst = xgb.train(param, dtrain, num_round, watchlist, early_stopping_rounds=2)
|
||||||
|
|
||||||
# Save the model, only ask process 0 to save the model.
|
# Save the model, only ask process 0 to save the model.
|
||||||
if xgb.rabit.get_rank() == 0:
|
bst.save_model("test.model{}".format(xgb.rabit.get_rank()))
|
||||||
bst.save_model("test.model")
|
|
||||||
xgb.rabit.tracker_print("Finished training\n")
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
# Notify the tracker all training has been successful
|
# Notify the tracker all training has been successful
|
||||||
|
|||||||
@ -70,8 +70,7 @@ watchlist = [(dtrain,'train')]
|
|||||||
num_round = 2
|
num_round = 2
|
||||||
bst = xgb.train(param, dtrain, num_round, watchlist)
|
bst = xgb.train(param, dtrain, num_round, watchlist)
|
||||||
|
|
||||||
if xgb.rabit.get_rank() == 0:
|
bst.save_model("test_issue3402.model{}".format(xgb.rabit.get_rank()))
|
||||||
bst.save_model("test_issue3402.model")
|
|
||||||
xgb.rabit.tracker_print("Finished training\n")
|
xgb.rabit.tracker_print("Finished training\n")
|
||||||
|
|
||||||
# Notify the tracker all training has been successful
|
# Notify the tracker all training has been successful
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user