add in sync

This commit is contained in:
tqchen 2014-11-16 22:01:22 -08:00
parent 8ed585a7a2
commit d11445e0b1
4 changed files with 63 additions and 22 deletions

View File

@ -2,6 +2,7 @@
#define _CRT_SECURE_NO_DEPRECATE #define _CRT_SECURE_NO_DEPRECATE
#include <cstring> #include <cstring>
#include "./updater.h" #include "./updater.h"
#include "./updater_sync-inl.hpp"
#include "./updater_prune-inl.hpp" #include "./updater_prune-inl.hpp"
#include "./updater_refresh-inl.hpp" #include "./updater_refresh-inl.hpp"
#include "./updater_colmaker-inl.hpp" #include "./updater_colmaker-inl.hpp"
@ -13,6 +14,7 @@ namespace tree {
IUpdater* CreateUpdater(const char *name) { IUpdater* CreateUpdater(const char *name) {
using namespace std; using namespace std;
if (!strcmp(name, "prune")) return new TreePruner(); if (!strcmp(name, "prune")) return new TreePruner();
if (!strcmp(name, "sync")) return new TreeSyncher();
if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>(); if (!strcmp(name, "refresh")) return new TreeRefresher<GradStats>();
if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>(); if (!strcmp(name, "grow_colmaker")) return new ColMaker<GradStats>();
if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>(); if (!strcmp(name, "grow_qhistmaker")) return new QuantileHistMaker<GradStats>();

View File

@ -14,7 +14,7 @@
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
/*! \brief pruner that prunes a tree after growing finishs */ /*! \brief colunwise update to construct a tree */
template<typename TStats> template<typename TStats>
class ColMaker: public IUpdater { class ColMaker: public IUpdater {
public: public:

View File

@ -8,7 +8,7 @@
#include <vector> #include <vector>
#include "./param.h" #include "./param.h"
#include "./updater.h" #include "./updater.h"
#include "../sync/sync.h" #include "./updater_sync-inl.hpp"
namespace xgboost { namespace xgboost {
namespace tree { namespace tree {
@ -20,6 +20,7 @@ class TreePruner: public IUpdater {
virtual void SetParam(const char *name, const char *val) { virtual void SetParam(const char *name, const char *val) {
using namespace std; using namespace std;
param.SetParam(name, val); param.SetParam(name, val);
syncher.SetParam(name, val);
if (!strcmp(name, "silent")) silent = atoi(val); if (!strcmp(name, "silent")) silent = atoi(val);
} }
// update the tree, do pruning // update the tree, do pruning
@ -34,27 +35,9 @@ class TreePruner: public IUpdater {
this->DoPrune(*trees[i]); this->DoPrune(*trees[i]);
} }
param.learning_rate = lr; param.learning_rate = lr;
this->SyncTrees(trees); syncher.Update(gpair, p_fmat, info, trees);
} }
private: private:
// synchronize the trees in different nodes, take tree from rank 0
inline void SyncTrees(const std::vector<RegTree *> &trees) {
if (sync::GetWorldSize() == 1) return;
std::string s_model;
utils::MemoryBufferStream fs(&s_model);
int rank = sync::GetRank();
if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fs);
}
sync::Bcast(&s_model, 0);
} else {
sync::Bcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(fs);
}
}
}
// try to prune off current leaf // try to prune off current leaf
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) { inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
if (tree[nid].is_root()) return npruned; if (tree[nid].is_root()) return npruned;
@ -89,6 +72,8 @@ class TreePruner: public IUpdater {
} }
private: private:
// synchronizer
TreeSyncher syncher;
// shutup // shutup
int silent; int silent;
// training parameter // training parameter

View File

@ -0,0 +1,54 @@
#ifndef XGBOOST_TREE_UPDATER_SYNC_INL_HPP_
#define XGBOOST_TREE_UPDATER_SYNC_INL_HPP_
/*!
* \file updater_sync-inl.hpp
* \brief synchronize the tree in all distributed nodes
* \author Tianqi Chen
*/
#include <vector>
#include <limits>
#include "./updater.h"
#include "../sync/sync.h"
namespace xgboost {
namespace tree {
/*!
* \brief syncher that synchronize the tree in all distributed nodes
* can implement various strategies, so far it is always set to node 0's tree
*/
class TreeSyncher: public IUpdater {
public:
virtual ~TreeSyncher(void) {}
virtual void SetParam(const char *name, const char *val) {
}
// update the tree, do pruning
virtual void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
this->SyncTrees(trees);
}
private:
// synchronize the trees in different nodes, take tree from rank 0
inline void SyncTrees(const std::vector<RegTree *> &trees) {
if (sync::GetWorldSize() == 1) return;
std::string s_model;
utils::MemoryBufferStream fs(&s_model);
int rank = sync::GetRank();
if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fs);
}
sync::Bcast(&s_model, 0);
} else {
sync::Bcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(fs);
}
}
}
};
} // namespace tree
} // namespace xgboost
#endif // XGBOOST_TREE_UPDATER_SYNC_INL_HPP_