move sync tree to pruner, pruner is now distributed
This commit is contained in:
parent
a68ac8033e
commit
c2fa390181
@ -32,9 +32,8 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
|
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
|
||||||
// build the tree
|
// build the tree
|
||||||
builder.Update(gpair, p_fmat, info, trees[0]);
|
builder.Update(gpair, p_fmat, info, trees[0]);
|
||||||
//// prune the tree
|
//// prune the tree, note that pruner will sync the tree
|
||||||
pruner.Update(gpair, p_fmat, info, trees);
|
pruner.Update(gpair, p_fmat, info, trees);
|
||||||
this->SyncTrees(trees[0]);
|
|
||||||
// update position after the tree is pruned
|
// update position after the tree is pruned
|
||||||
builder.UpdatePosition(p_fmat, *trees[0]);
|
builder.UpdatePosition(p_fmat, *trees[0]);
|
||||||
}
|
}
|
||||||
@ -42,18 +41,6 @@ class DistColMaker : public ColMaker<TStats> {
|
|||||||
return builder.GetLeafPosition();
|
return builder.GetLeafPosition();
|
||||||
}
|
}
|
||||||
private:
|
private:
|
||||||
inline void SyncTrees(RegTree *tree) {
|
|
||||||
std::string s_model;
|
|
||||||
utils::MemoryBufferStream fs(&s_model);
|
|
||||||
int rank = sync::GetRank();
|
|
||||||
if (rank == 0) {
|
|
||||||
tree->SaveModel(fs);
|
|
||||||
sync::Bcast(&s_model, 0);
|
|
||||||
} else {
|
|
||||||
sync::Bcast(&s_model, 0);
|
|
||||||
tree->LoadModel(fs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
struct Builder : public ColMaker<TStats>::Builder {
|
struct Builder : public ColMaker<TStats>::Builder {
|
||||||
public:
|
public:
|
||||||
Builder(const TrainParam ¶m)
|
Builder(const TrainParam ¶m)
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "./updater.h"
|
#include "./updater.h"
|
||||||
|
#include "../sync/sync.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -33,9 +34,27 @@ class TreePruner: public IUpdater {
|
|||||||
this->DoPrune(*trees[i]);
|
this->DoPrune(*trees[i]);
|
||||||
}
|
}
|
||||||
param.learning_rate = lr;
|
param.learning_rate = lr;
|
||||||
|
this->SyncTrees(trees);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// synchronize the trees in different nodes, take tree from rank 0
|
||||||
|
inline void SyncTrees(const std::vector<RegTree *> &trees) {
|
||||||
|
if (sync::GetWorldSize() == 1) return;
|
||||||
|
std::string s_model;
|
||||||
|
utils::MemoryBufferStream fs(&s_model);
|
||||||
|
int rank = sync::GetRank();
|
||||||
|
if (rank == 0) {
|
||||||
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
|
trees[i]->SaveModel(fs);
|
||||||
|
}
|
||||||
|
sync::Bcast(&s_model, 0);
|
||||||
|
} else {
|
||||||
|
sync::Bcast(&s_model, 0);
|
||||||
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
|
trees[i]->LoadModel(fs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
// try to prune off current leaf
|
// try to prune off current leaf
|
||||||
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
|
inline int TryPruneLeaf(RegTree &tree, int nid, int depth, int npruned) {
|
||||||
if (tree[nid].is_root()) return npruned;
|
if (tree[nid].is_root()) return npruned;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user