move sync tree to pruner, pruner is now distributed
This commit is contained in:
parent
a68ac8033e
commit
c2fa390181
@ -32,9 +32,8 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
|
||||
// build the tree
|
||||
builder.Update(gpair, p_fmat, info, trees[0]);
|
||||
//// prune the tree
|
||||
//// prune the tree, note that pruner will sync the tree
|
||||
pruner.Update(gpair, p_fmat, info, trees);
|
||||
this->SyncTrees(trees[0]);
|
||||
// update position after the tree is pruned
|
||||
builder.UpdatePosition(p_fmat, *trees[0]);
|
||||
}
|
||||
@ -42,18 +41,6 @@ class DistColMaker : public ColMaker<TStats> {
|
||||
return builder.GetLeafPosition();
|
||||
}
|
||||
private:
|
||||
inline void SyncTrees(RegTree *tree) {
|
||||
std::string s_model;
|
||||
utils::MemoryBufferStream fs(&s_model);
|
||||
int rank = sync::GetRank();
|
||||
if (rank == 0) {
|
||||
tree->SaveModel(fs);
|
||||
sync::Bcast(&s_model, 0);
|
||||
} else {
|
||||
sync::Bcast(&s_model, 0);
|
||||
tree->LoadModel(fs);
|
||||
}
|
||||
}
|
||||
struct Builder : public ColMaker<TStats>::Builder {
|
||||
public:
|
||||
Builder(const TrainParam ¶m)
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include "./param.h"
|
||||
#include "./updater.h"
|
||||
#include "../sync/sync.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
@ -33,9 +34,27 @@ class TreePruner: public IUpdater {
|
||||
this->DoPrune(*trees[i]);
|
||||
}
|
||||
param.learning_rate = lr;
|
||||
}
|
||||
|
||||
this->SyncTrees(trees);
|
||||
}
|
||||
private:
|
||||
// synchronize the trees in different nodes, take tree from rank 0
|
||||
inline void SyncTrees(const std::vector<RegTree *> &trees) {
|
||||
if (sync::GetWorldSize() == 1) return;
|
||||
std::string s_model;
|
||||
utils::MemoryBufferStream fs(&s_model);
|
||||
int rank = sync::GetRank();
|
||||
if (rank == 0) {
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->SaveModel(fs);
|
||||
}
|
||||
sync::Bcast(&s_model, 0);
|
||||
} else {
|
||||
sync::Bcast(&s_model, 0);
|
||||
for (size_t i = 0; i < trees.size(); ++i) {
|
||||
trees[i]->LoadModel(fs);
|
||||
}
|
||||
}
|
||||
}
|
||||
// try to prune off current leaf
|
||||
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
|
||||
if (tree[nid].is_root()) return npruned;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user