refresher is now distributed

This commit is contained in:
tqchen 2014-10-17 14:48:32 -07:00
parent 9df9e07f9b
commit a68ac8033e
2 changed files with 36 additions and 19 deletions

View File

@ -190,6 +190,10 @@ struct GradStats {
/*! \brief accumulate the gradient and hessian sums of b into this statistics */
inline void Add(const GradStats &b) {
  Add(b.sum_grad, b.sum_hess);
}
/*! \brief same as add, reduce is used in All Reduce */
inline void Reduce(const GradStats &b) {
this->Add(b);
}
/*! \brief set current value to a - b */
inline void SetSubstract(const GradStats &a, const GradStats &b) {
sum_grad = a.sum_grad - b.sum_grad;
@ -266,6 +270,10 @@ struct CVGradStats : public GradStats {
valid[i].Add(b.valid[i]);
}
}
/*! \brief same as add, reduce is used in All Reduce */
inline void Reduce(const CVGradStats &b) {
this->Add(b);
}
/*! \brief set current value to a - b */
inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
GradStats::SetSubstract(a, b);

View File

@ -10,6 +10,7 @@
#include "./param.h"
#include "./updater.h"
#include "../utils/omp.h"
#include "../sync/sync.h"
namespace xgboost {
namespace tree {
@ -26,7 +27,7 @@ class TreeRefresher: public IUpdater {
virtual void Update(const std::vector<bst_gpair> &gpair,
IFMatrix *p_fmat,
const BoosterInfo &info,
const std::vector<RegTree*> &trees) {
const std::vector<RegTree*> &trees) {
if (trees.size() == 0) return;
// number of threads
// thread temporary space
@ -39,15 +40,16 @@ class TreeRefresher: public IUpdater {
nthread = omp_get_num_threads();
}
fvec_temp.resize(nthread, RegTree::FVec());
stemp.resize(trees.size() * nthread, std::vector<TStats>());
stemp.resize(nthread, std::vector<TStats>());
#pragma omp parallel
{
int tid = omp_get_thread_num();
int num_nodes = 0;
for (size_t i = 0; i < trees.size(); ++i) {
std::vector<TStats> &vec = stemp[tid * trees.size() + i];
vec.resize(trees[i]->param.num_nodes, TStats(param));
std::fill(vec.begin(), vec.end(), TStats(param));
num_nodes += trees[i]->param.num_nodes;
}
stemp[tid].resize(num_nodes, TStats(param));
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
fvec_temp[tid].Init(trees[0]->param.num_feature);
}
// start accumulating statistics
@ -65,28 +67,34 @@ class TreeRefresher: public IUpdater {
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
RegTree::FVec &feats = fvec_temp[tid];
feats.Fill(inst);
int offset = 0;
for (size_t j = 0; j < trees.size(); ++j) {
AddStats(*trees[j], feats, gpair, info, ridx,
&stemp[tid * trees.size() + j]);
BeginPtr(stemp[tid]) + offset);
offset += trees[j]->param.num_nodes;
}
feats.Drop(inst);
}
}
// start updating the trees using the statistics
// aggregate the statistics
int num_nodes = static_cast<int>(stemp[0].size());
#pragma omp parallel for schedule(static)
for (int nid = 0; nid < num_nodes; ++nid) {
for (int tid = 1; tid < nthread; ++tid) {
stemp[0][nid].Add(stemp[tid][nid]);
}
}
// AllReduce, add statistics up
reducer.AllReduce(BeginPtr(stemp[0]), stemp[0].size());
// rescale learning rate according to size of trees
float lr = param.learning_rate;
param.learning_rate = lr / trees.size();
for (size_t i = 0; i < trees.size(); ++i) {
// aggregate
#pragma omp parallel for schedule(static)
for (int nid = 0; nid < trees[i]->param.num_nodes; ++nid) {
for (int tid = 1; tid < nthread; ++tid) {
stemp[i][nid].Add(stemp[tid * trees.size() + i][nid]);
}
}
int offset = 0;
for (size_t i = 0; i < trees.size(); ++i) {
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
this->Refresh(stemp[i], rid, trees[i]);
this->Refresh(BeginPtr(stemp[0]) + offset, rid, trees[i]);
}
offset += trees[i]->param.num_nodes;
}
// set learning rate back
param.learning_rate = lr;
@ -98,8 +106,7 @@ class TreeRefresher: public IUpdater {
const std::vector<bst_gpair> &gpair,
const BoosterInfo &info,
const bst_uint ridx,
std::vector<TStats> *p_gstats) {
std::vector<TStats> &gstats = *p_gstats;
TStats *gstats) {
// start from groups that belongs to current data
int pid = static_cast<int>(info.GetRoot(ridx));
gstats[pid].Add(gpair, info, ridx);
@ -110,7 +117,7 @@ class TreeRefresher: public IUpdater {
gstats[pid].Add(gpair, info, ridx);
}
}
inline void Refresh(const std::vector<TStats> &gstats,
inline void Refresh(const TStats *gstats,
int nid, RegTree *p_tree) {
RegTree &tree = *p_tree;
tree.stat(nid).base_weight = static_cast<float>(gstats[nid].CalcWeight(param));
@ -129,6 +136,8 @@ class TreeRefresher: public IUpdater {
}
// training parameter
TrainParam param;
// reducer
sync::Reducer<TStats> reducer;
};
} // namespace tree