diff --git a/src/tree/param.h b/src/tree/param.h index cf646a76e..8bd855554 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -190,6 +190,10 @@ struct GradStats { inline void Add(const GradStats &b) { this->Add(b.sum_grad, b.sum_hess); } + /*! \brief same as add, reduce is used in All Reduce */ + inline void Reduce(const GradStats &b) { + this->Add(b); + } /*! \brief set current value to a - b */ inline void SetSubstract(const GradStats &a, const GradStats &b) { sum_grad = a.sum_grad - b.sum_grad; @@ -266,6 +270,10 @@ struct CVGradStats : public GradStats { valid[i].Add(b.valid[i]); } } + /*! \brief same as add, reduce is used in All Reduce */ + inline void Reduce(const CVGradStats &b) { + this->Add(b); + } /*! \brief set current value to a - b */ inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) { GradStats::SetSubstract(a, b); diff --git a/src/tree/updater_refresh-inl.hpp b/src/tree/updater_refresh-inl.hpp index a37630333..579ff2bc3 100644 --- a/src/tree/updater_refresh-inl.hpp +++ b/src/tree/updater_refresh-inl.hpp @@ -10,6 +10,7 @@ #include "./param.h" #include "./updater.h" #include "../utils/omp.h" +#include "../sync/sync.h" namespace xgboost { namespace tree { @@ -26,7 +27,7 @@ class TreeRefresher: public IUpdater { virtual void Update(const std::vector &gpair, IFMatrix *p_fmat, const BoosterInfo &info, - const std::vector &trees) { + const std::vector &trees) { if (trees.size() == 0) return; // number of threads // thread temporal space @@ -39,15 +40,16 @@ class TreeRefresher: public IUpdater { nthread = omp_get_num_threads(); } fvec_temp.resize(nthread, RegTree::FVec()); - stemp.resize(trees.size() * nthread, std::vector()); + stemp.resize(nthread, std::vector()); #pragma omp parallel { int tid = omp_get_thread_num(); + int num_nodes = 0; for (size_t i = 0; i < trees.size(); ++i) { - std::vector &vec = stemp[tid * trees.size() + i]; - vec.resize(trees[i]->param.num_nodes, TStats(param)); - std::fill(vec.begin(), vec.end(), TStats(param)); + num_nodes += trees[i]->param.num_nodes; } + stemp[tid].resize(num_nodes, TStats(param)); + std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param)); fvec_temp[tid].Init(trees[0]->param.num_feature); } // start accumulating statistics @@ -65,28 +67,34 @@ class TreeRefresher: public IUpdater { const bst_uint ridx = static_cast(batch.base_rowid + i); RegTree::FVec &feats = fvec_temp[tid]; feats.Fill(inst); + int offset = 0; for (size_t j = 0; j < trees.size(); ++j) { AddStats(*trees[j], feats, gpair, info, ridx, - &stemp[tid * trees.size() + j]); + BeginPtr(stemp[tid]) + offset); + offset += trees[j]->param.num_nodes; } feats.Drop(inst); } } - // start update the trees using the statistics + // aggregate the statistics + int num_nodes = static_cast(stemp[0].size()); + #pragma omp parallel for schedule(static) + for (int nid = 0; nid < num_nodes; ++nid) { + for (int tid = 1; tid < nthread; ++tid) { + stemp[0][nid].Add(stemp[tid][nid]); + } + } + // AllReduce, add statistics up + reducer.AllReduce(BeginPtr(stemp[0]), stemp[0].size()); // rescale learning rate according to size of trees float lr = param.learning_rate; param.learning_rate = lr / trees.size(); - for (size_t i = 0; i < trees.size(); ++i) { - // aggregate - #pragma omp parallel for schedule(static) - for (int nid = 0; nid < trees[i]->param.num_nodes; ++nid) { - for (int tid = 1; tid < nthread; ++tid) { - stemp[i][nid].Add(stemp[tid * trees.size() + i][nid]); - } - } + int offset = 0; + for (size_t i = 0; i < trees.size(); ++i) { for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) { - this->Refresh(stemp[i], rid, trees[i]); + this->Refresh(BeginPtr(stemp[0]) + offset, rid, trees[i]); } + offset += trees[i]->param.num_nodes; } // set learning rate back param.learning_rate = lr; @@ -98,8 +106,7 @@ class TreeRefresher: public IUpdater { const std::vector &gpair, const BoosterInfo &info, const bst_uint ridx, - std::vector *p_gstats) { - std::vector &gstats = *p_gstats; + TStats *gstats) { // start from groups that belongs to current data int pid = static_cast(info.GetRoot(ridx)); gstats[pid].Add(gpair, info, ridx); @@ -110,7 +117,7 @@ class TreeRefresher: public IUpdater { gstats[pid].Add(gpair, info, ridx); } } - inline void Refresh(const std::vector &gstats, + inline void Refresh(const TStats *gstats, int nid, RegTree *p_tree) { RegTree &tree = *p_tree; tree.stat(nid).base_weight = static_cast(gstats[nid].CalcWeight(param)); @@ -129,6 +136,8 @@ class TreeRefresher: public IUpdater { } // training parameter TrainParam param; + // reducer + sync::Reducer reducer; }; } // namespace tree