diff --git a/src/tree/updater.cpp b/src/tree/updater.cpp index 0d0e66274..f6e669ffa 100644 --- a/src/tree/updater.cpp +++ b/src/tree/updater.cpp @@ -2,6 +2,7 @@ #define _CRT_SECURE_NO_DEPRECATE #include #include "./updater.h" +#include "./updater_sync-inl.hpp" #include "./updater_prune-inl.hpp" #include "./updater_refresh-inl.hpp" #include "./updater_colmaker-inl.hpp" @@ -13,6 +14,7 @@ namespace tree { IUpdater* CreateUpdater(const char *name) { using namespace std; if (!strcmp(name, "prune")) return new TreePruner(); + if (!strcmp(name, "sync")) return new TreeSyncher(); if (!strcmp(name, "refresh")) return new TreeRefresher(); if (!strcmp(name, "grow_colmaker")) return new ColMaker(); if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker(); diff --git a/src/tree/updater_colmaker-inl.hpp b/src/tree/updater_colmaker-inl.hpp index 6db19732e..6f8cb35d3 100644 --- a/src/tree/updater_colmaker-inl.hpp +++ b/src/tree/updater_colmaker-inl.hpp @@ -14,7 +14,7 @@ namespace xgboost { namespace tree { -/*! \brief pruner that prunes a tree after growing finishs */ +/*! \brief colunwise update to construct a tree */ template class ColMaker: public IUpdater { public: diff --git a/src/tree/updater_prune-inl.hpp b/src/tree/updater_prune-inl.hpp index a68404ba7..e7e5f9f0b 100644 --- a/src/tree/updater_prune-inl.hpp +++ b/src/tree/updater_prune-inl.hpp @@ -8,7 +8,7 @@ #include #include "./param.h" #include "./updater.h" -#include "../sync/sync.h" +#include "./updater_sync-inl.hpp" namespace xgboost { namespace tree { @@ -20,6 +20,7 @@ class TreePruner: public IUpdater { virtual void SetParam(const char *name, const char *val) { using namespace std; param.SetParam(name, val); + syncher.SetParam(name, val); if (!strcmp(name, "silent")) silent = atoi(val); } // update the tree, do pruning @@ -34,27 +35,9 @@ class TreePruner: public IUpdater { this->DoPrune(*trees[i]); } param.learning_rate = lr; - this->SyncTrees(trees); - } - private: - // synchronize the trees in different nodes, take tree from rank 0 - inline void SyncTrees(const std::vector &trees) { - if (sync::GetWorldSize() == 1) return; - std::string s_model; - utils::MemoryBufferStream fs(&s_model); - int rank = sync::GetRank(); - if (rank == 0) { - for (size_t i = 0; i < trees.size(); ++i) { - trees[i]->SaveModel(fs); - } - sync::Bcast(&s_model, 0); - } else { - sync::Bcast(&s_model, 0); - for (size_t i = 0; i < trees.size(); ++i) { - trees[i]->LoadModel(fs); - } - } + syncher.Update(gpair, p_fmat, info, trees); } + private: // try to prune off current leaf inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { if (tree[nid].is_root()) return npruned; @@ -89,6 +72,8 @@ class TreePruner: public IUpdater { } private: + // synchronizer + TreeSyncher syncher; // shutup int silent; // training parameter diff --git a/src/tree/updater_sync-inl.hpp b/src/tree/updater_sync-inl.hpp new file mode 100644 index 000000000..68a609616 --- /dev/null +++ b/src/tree/updater_sync-inl.hpp @@ -0,0 +1,54 @@ +#ifndef XGBOOST_TREE_UPDATER_SYNC_INL_HPP_ +#define XGBOOST_TREE_UPDATER_SYNC_INL_HPP_ +/*! + * \file updater_sync-inl.hpp + * \brief synchronize the tree in all distributed nodes + * \author Tianqi Chen + */ +#include +#include +#include "./updater.h" +#include "../sync/sync.h" + +namespace xgboost { +namespace tree { +/*! + * \brief syncher that synchronize the tree in all distributed nodes + * can implement various strategies, so far it is always set to node 0's tree + */ +class TreeSyncher: public IUpdater { + public: + virtual ~TreeSyncher(void) {} + virtual void SetParam(const char *name, const char *val) { + } + // update the tree, do pruning + virtual void Update(const std::vector &gpair, + IFMatrix *p_fmat, + const BoosterInfo &info, + const std::vector &trees) { + this->SyncTrees(trees); + } + + private: + // synchronize the trees in different nodes, take tree from rank 0 + inline void SyncTrees(const std::vector &trees) { + if (sync::GetWorldSize() == 1) return; + std::string s_model; + utils::MemoryBufferStream fs(&s_model); + int rank = sync::GetRank(); + if (rank == 0) { + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->SaveModel(fs); + } + sync::Bcast(&s_model, 0); + } else { + sync::Bcast(&s_model, 0); + for (size_t i = 0; i < trees.size(); ++i) { + trees[i]->LoadModel(fs); + } + } + } +}; +} // namespace tree +} // namespace xgboost +#endif // XGBOOST_TREE_UPDATER_SYNC_INL_HPP_