refresher is now distributed
This commit is contained in:
parent
9df9e07f9b
commit
a68ac8033e
@ -190,6 +190,10 @@ struct GradStats {
|
|||||||
inline void Add(const GradStats &b) {
|
inline void Add(const GradStats &b) {
|
||||||
this->Add(b.sum_grad, b.sum_hess);
|
this->Add(b.sum_grad, b.sum_hess);
|
||||||
}
|
}
|
||||||
|
/*! \brief same as add, reduce is used in All Reduce */
|
||||||
|
inline void Reduce(const GradStats &b) {
|
||||||
|
this->Add(b);
|
||||||
|
}
|
||||||
/*! \brief set current value to a - b */
|
/*! \brief set current value to a - b */
|
||||||
inline void SetSubstract(const GradStats &a, const GradStats &b) {
|
inline void SetSubstract(const GradStats &a, const GradStats &b) {
|
||||||
sum_grad = a.sum_grad - b.sum_grad;
|
sum_grad = a.sum_grad - b.sum_grad;
|
||||||
@ -266,6 +270,10 @@ struct CVGradStats : public GradStats {
|
|||||||
valid[i].Add(b.valid[i]);
|
valid[i].Add(b.valid[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
/*! \brief same as add, reduce is used in All Reduce */
|
||||||
|
inline void Reduce(const CVGradStats &b) {
|
||||||
|
this->Add(b);
|
||||||
|
}
|
||||||
/*! \brief set current value to a - b */
|
/*! \brief set current value to a - b */
|
||||||
inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
|
inline void SetSubstract(const CVGradStats &a, const CVGradStats &b) {
|
||||||
GradStats::SetSubstract(a, b);
|
GradStats::SetSubstract(a, b);
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "./param.h"
|
#include "./param.h"
|
||||||
#include "./updater.h"
|
#include "./updater.h"
|
||||||
#include "../utils/omp.h"
|
#include "../utils/omp.h"
|
||||||
|
#include "../sync/sync.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace tree {
|
namespace tree {
|
||||||
@ -39,15 +40,16 @@ class TreeRefresher: public IUpdater {
|
|||||||
nthread = omp_get_num_threads();
|
nthread = omp_get_num_threads();
|
||||||
}
|
}
|
||||||
fvec_temp.resize(nthread, RegTree::FVec());
|
fvec_temp.resize(nthread, RegTree::FVec());
|
||||||
stemp.resize(trees.size() * nthread, std::vector<TStats>());
|
stemp.resize(nthread, std::vector<TStats>());
|
||||||
#pragma omp parallel
|
#pragma omp parallel
|
||||||
{
|
{
|
||||||
int tid = omp_get_thread_num();
|
int tid = omp_get_thread_num();
|
||||||
|
int num_nodes = 0;
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
std::vector<TStats> &vec = stemp[tid * trees.size() + i];
|
num_nodes += trees[i]->param.num_nodes;
|
||||||
vec.resize(trees[i]->param.num_nodes, TStats(param));
|
|
||||||
std::fill(vec.begin(), vec.end(), TStats(param));
|
|
||||||
}
|
}
|
||||||
|
stemp[tid].resize(num_nodes, TStats(param));
|
||||||
|
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
|
||||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||||
}
|
}
|
||||||
// start accumulating statistics
|
// start accumulating statistics
|
||||||
@ -65,28 +67,34 @@ class TreeRefresher: public IUpdater {
|
|||||||
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
const bst_uint ridx = static_cast<bst_uint>(batch.base_rowid + i);
|
||||||
RegTree::FVec &feats = fvec_temp[tid];
|
RegTree::FVec &feats = fvec_temp[tid];
|
||||||
feats.Fill(inst);
|
feats.Fill(inst);
|
||||||
|
int offset = 0;
|
||||||
for (size_t j = 0; j < trees.size(); ++j) {
|
for (size_t j = 0; j < trees.size(); ++j) {
|
||||||
AddStats(*trees[j], feats, gpair, info, ridx,
|
AddStats(*trees[j], feats, gpair, info, ridx,
|
||||||
&stemp[tid * trees.size() + j]);
|
BeginPtr(stemp[tid]) + offset);
|
||||||
|
offset += trees[j]->param.num_nodes;
|
||||||
}
|
}
|
||||||
feats.Drop(inst);
|
feats.Drop(inst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// start update the trees using the statistics
|
// aggregate the statistics
|
||||||
|
int num_nodes = static_cast<int>(stemp[0].size());
|
||||||
|
#pragma omp parallel for schedule(static)
|
||||||
|
for (int nid = 0; nid < num_nodes; ++nid) {
|
||||||
|
for (int tid = 1; tid < nthread; ++tid) {
|
||||||
|
stemp[0][nid].Add(stemp[tid][nid]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// AllReduce, add statistics up
|
||||||
|
reducer.AllReduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||||
// rescale learning rate according to size of trees
|
// rescale learning rate according to size of trees
|
||||||
float lr = param.learning_rate;
|
float lr = param.learning_rate;
|
||||||
param.learning_rate = lr / trees.size();
|
param.learning_rate = lr / trees.size();
|
||||||
|
int offset = 0;
|
||||||
for (size_t i = 0; i < trees.size(); ++i) {
|
for (size_t i = 0; i < trees.size(); ++i) {
|
||||||
// aggregate
|
|
||||||
#pragma omp parallel for schedule(static)
|
|
||||||
for (int nid = 0; nid < trees[i]->param.num_nodes; ++nid) {
|
|
||||||
for (int tid = 1; tid < nthread; ++tid) {
|
|
||||||
stemp[i][nid].Add(stemp[tid * trees.size() + i][nid]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
|
for (int rid = 0; rid < trees[i]->param.num_roots; ++rid) {
|
||||||
this->Refresh(stemp[i], rid, trees[i]);
|
this->Refresh(BeginPtr(stemp[0]) + offset, rid, trees[i]);
|
||||||
}
|
}
|
||||||
|
offset += trees[i]->param.num_nodes;
|
||||||
}
|
}
|
||||||
// set learning rate back
|
// set learning rate back
|
||||||
param.learning_rate = lr;
|
param.learning_rate = lr;
|
||||||
@ -98,8 +106,7 @@ class TreeRefresher: public IUpdater {
|
|||||||
const std::vector<bst_gpair> &gpair,
|
const std::vector<bst_gpair> &gpair,
|
||||||
const BoosterInfo &info,
|
const BoosterInfo &info,
|
||||||
const bst_uint ridx,
|
const bst_uint ridx,
|
||||||
std::vector<TStats> *p_gstats) {
|
TStats *gstats) {
|
||||||
std::vector<TStats> &gstats = *p_gstats;
|
|
||||||
// start from groups that belongs to current data
|
// start from groups that belongs to current data
|
||||||
int pid = static_cast<int>(info.GetRoot(ridx));
|
int pid = static_cast<int>(info.GetRoot(ridx));
|
||||||
gstats[pid].Add(gpair, info, ridx);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
@ -110,7 +117,7 @@ class TreeRefresher: public IUpdater {
|
|||||||
gstats[pid].Add(gpair, info, ridx);
|
gstats[pid].Add(gpair, info, ridx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
inline void Refresh(const std::vector<TStats> &gstats,
|
inline void Refresh(const TStats *gstats,
|
||||||
int nid, RegTree *p_tree) {
|
int nid, RegTree *p_tree) {
|
||||||
RegTree &tree = *p_tree;
|
RegTree &tree = *p_tree;
|
||||||
tree.stat(nid).base_weight = static_cast<float>(gstats[nid].CalcWeight(param));
|
tree.stat(nid).base_weight = static_cast<float>(gstats[nid].CalcWeight(param));
|
||||||
@ -129,6 +136,8 @@ class TreeRefresher: public IUpdater {
|
|||||||
}
|
}
|
||||||
// training parameter
|
// training parameter
|
||||||
TrainParam param;
|
TrainParam param;
|
||||||
|
// reducer
|
||||||
|
sync::Reducer<TStats> reducer;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tree
|
} // namespace tree
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user