move sync tree to pruner, pruner is now distributed

This commit is contained in:
tqchen 2014-10-17 14:53:43 -07:00
parent a68ac8033e
commit c2fa390181
2 changed files with 22 additions and 16 deletions

View File

@ -32,9 +32,8 @@ class DistColMaker : public ColMaker<TStats> {
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
// build the tree
builder.Update(gpair, p_fmat, info, trees[0]);
//// prune the tree
//// prune the tree, note that pruner will sync the tree
pruner.Update(gpair, p_fmat, info, trees);
this->SyncTrees(trees[0]);
// update position after the tree is pruned
builder.UpdatePosition(p_fmat, *trees[0]);
}
@ -42,18 +41,6 @@ class DistColMaker : public ColMaker<TStats> {
return builder.GetLeafPosition();
}
private:
inline void SyncTrees(RegTree *tree) {
std::string s_model;
utils::MemoryBufferStream fs(&s_model);
int rank = sync::GetRank();
if (rank == 0) {
tree->SaveModel(fs);
sync::Bcast(&s_model, 0);
} else {
sync::Bcast(&s_model, 0);
tree->LoadModel(fs);
}
}
struct Builder : public ColMaker<TStats>::Builder {
public:
Builder(const TrainParam &param)

View File

@ -8,6 +8,7 @@
#include <vector>
#include "./param.h"
#include "./updater.h"
#include "../sync/sync.h"
namespace xgboost {
namespace tree {
@ -33,9 +34,27 @@ class TreePruner: public IUpdater {
this->DoPrune(*trees[i]);
}
param.learning_rate = lr;
}
this->SyncTrees(trees);
}
private:
// synchronize the trees in different nodes, take tree from rank 0
inline void SyncTrees(const std::vector<RegTree *> &trees) {
if (sync::GetWorldSize() == 1) return;
std::string s_model;
utils::MemoryBufferStream fs(&s_model);
int rank = sync::GetRank();
if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fs);
}
sync::Bcast(&s_model, 0);
} else {
sync::Bcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(fs);
}
}
}
// try to prune off current leaf
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
if (tree[nid].is_root()) return npruned;