diff --git a/src/tree/updater_distcol-inl.hpp b/src/tree/updater_distcol-inl.hpp index d94cdf409..bce947fe8 100644 --- a/src/tree/updater_distcol-inl.hpp +++ b/src/tree/updater_distcol-inl.hpp @@ -32,9 +32,8 @@ class DistColMaker : public ColMaker { utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time"); // build the tree builder.Update(gpair, p_fmat, info, trees[0]); - //// prune the tree + //// prune the tree, note that pruner will sync the tree pruner.Update(gpair, p_fmat, info, trees); - this->SyncTrees(trees[0]); // update position after the tree is pruned builder.UpdatePosition(p_fmat, *trees[0]); } @@ -42,18 +41,6 @@ class DistColMaker : public ColMaker { return builder.GetLeafPosition(); } private: - inline void SyncTrees(RegTree *tree) { - std::string s_model; - utils::MemoryBufferStream fs(&s_model); - int rank = sync::GetRank(); - if (rank == 0) { - tree->SaveModel(fs); - sync::Bcast(&s_model, 0); - } else { - sync::Bcast(&s_model, 0); - tree->LoadModel(fs); - } - } struct Builder : public ColMaker::Builder { public: Builder(const TrainParam ¶m) diff --git a/src/tree/updater_prune-inl.hpp b/src/tree/updater_prune-inl.hpp index 726999f55..a68404ba7 100644 --- a/src/tree/updater_prune-inl.hpp +++ b/src/tree/updater_prune-inl.hpp @@ -8,6 +8,7 @@ #include #include "./param.h" #include "./updater.h" +#include "../sync/sync.h" namespace xgboost { namespace tree { @@ -33,9 +34,27 @@ class TreePruner: public IUpdater { this->DoPrune(*trees[i]); } param.learning_rate = lr; - } - + this->SyncTrees(trees); + } private: + // synchronize the trees in different nodes, take tree from rank 0 + inline void SyncTrees(const std::vector &trees) { + if (sync::GetWorldSize() == 1) return; + std::string s_model; + utils::MemoryBufferStream fs(&s_model); + int rank = sync::GetRank(); + if (rank == 0) { + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->SaveModel(fs); + } + sync::Bcast(&s_model, 0); + } else { + sync::Bcast(&s_model, 0); + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->LoadModel(fs); + } + } + } // try to prune off current leaf inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { if (tree[nid].is_root()) return npruned;