make clear seperation

This commit is contained in:
tqchen
2014-10-16 13:03:42 -07:00
parent 47145a7fac
commit a21df0770d
6 changed files with 63 additions and 24 deletions

View File

@@ -32,13 +32,13 @@ class DistColMaker : public ColMaker<TStats> {
utils::Check(trees.size() == 1, "DistColMaker: only support one tree at a time");
// build the tree
builder.Update(gpair, p_fmat, info, trees[0]);
// prune the tree
//// prune the tree
pruner.Update(gpair, p_fmat, info, trees);
this->SyncTrees(trees[0]);
// update position after the tree is pruned
builder.UpdatePosition(p_fmat, *trees[0]);
}
private:
inline void SyncTrees(RegTree *tree) {
std::string s_model;
@@ -63,10 +63,12 @@ class DistColMaker : public ColMaker<TStats> {
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
int nid = this->position[ridx];
if (nid < 0) {
int nid = this->DecodePosition(ridx);
while (tree[nid].is_deleted()) {
nid = tree[nid].parent();
utils::Assert(nid >=0, "distributed learning error");
}
this->position[ridx] = nid;
}
}
protected:
@@ -111,6 +113,7 @@ class DistColMaker : public ColMaker<TStats> {
}
}
}
// communicate bitmap
sync::AllReduce(BeginPtr(bitmap.data), bitmap.data.size(), sync::kBitwiseOR);
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
@@ -125,7 +128,7 @@ class DistColMaker : public ColMaker<TStats> {
if (tree[nid].default_left()) {
this->SetEncodePosition(ridx, tree[nid].cright());
} else {
this->SetEncodePosition(ridx, tree[nid].cright());
this->SetEncodePosition(ridx, tree[nid].cleft());
}
}
}