change hist update to lazy

tqchen 2014-12-20 05:02:38 -08:00
parent deb21351b9
commit 7a35e1a906
3 changed files with 127 additions and 85 deletions

@@ -10,6 +10,12 @@ else
 CFLAGS += -fopenmp
 endif
+# by default use c++11
+ifeq ($(no_cxx11),1)
+else
+CFLAGS += -std=c++11
+endif
 # specify tensor path
 BIN = xgboost
 OBJ = updater.o gbm.o io.o main.o
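
With this change the build passes -std=c++11 by default (opt out with no_cxx11=1), and the updater sources then pick the lambda-based path at compile time with the `__cplusplus >= 201103L` guard seen in the hunks below. A tiny standalone probe of that selection (illustrative only, not xgboost code):

    #include <cstdio>

    int main() {
      // With the new default CFLAGS the compiler defines __cplusplus >= 201103L,
      // which is exactly what the updater sources test.
      std::printf("__cplusplus = %ld\n", static_cast<long>(__cplusplus));
    #if __cplusplus >= 201103L
      std::printf("lambda-based lazy Allreduce path will be compiled\n");
    #else
      std::printf("falling back to the eager Allreduce path\n");
    #endif
      return 0;
    }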

@@ -306,6 +306,12 @@ class CQHistMaker: public HistMaker<TStats> {
 }
 // start to work
 this->wspace.Init(this->param, 1);
+// if it is C++11, use lazy evaluation for Allreduce,
+// to gain speedup in recovery
+#if __cplusplus >= 201103L
+auto lazy_get_hist = [&]()
+#endif
+{
 thread_hist.resize(this->get_nthread());
 // start accumulating statistics
 utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
@@ -330,8 +336,15 @@ class CQHistMaker: public HistMaker<TStats> {
 this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
 .data[0] = node_stats[nid];
 }
+};
 // sync the histogram
+// if it is C++11, use lazy evaluation for Allreduce
+#if __cplusplus >= 201103L
+this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data),
+    this->wspace.hset[0].data.size(), lazy_get_hist);
+#else
 this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
+#endif
 }
 virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
 const RegTree &tree) {
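
Note how the guarded lambda header keeps the braces valid under both language modes: with C++11 the block is the body of lazy_get_hist and the trailing `};` ends the lambda definition, while without C++11 the #if lines drop out, the same block runs in place, and the trailing semicolon is just an empty statement. A minimal sketch of that dual reading (illustrative only, not xgboost code):

    #include <cstdio>

    int main() {
      int hist = 0;
    #if __cplusplus >= 201103L
      auto lazy_get_hist = [&]()
    #endif
      {
        hist += 1;   // C++11: lambda body (deferred); otherwise: runs right here
      };
    #if __cplusplus >= 201103L
      lazy_get_hist();  // the C++11 path invokes the closure explicitly
    #endif
      std::printf("hist = %d\n", hist);
      return 0;
    }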
@@ -354,13 +367,24 @@ class CQHistMaker: public HistMaker<TStats> {
 feat2workindex[fset[i]] = -2;
 }
 }
 this->GetNodeStats(gpair, *p_fmat, tree, info,
 &thread_stats, &node_stats);
 sketchs.resize(this->qexpand.size() * freal_set.size());
 for (size_t i = 0; i < sketchs.size(); ++i) {
 sketchs[i].Init(info.num_row, this->param.sketch_eps);
 }
+// initialize the summary array
+summary_array.resize(sketchs.size());
+// setup maximum size
+unsigned max_size = this->param.max_sketch_size();
+for (size_t i = 0; i < sketchs.size(); ++i) {
+summary_array[i].Reserve(max_size);
+}
+// if it is C++11, use lazy evaluation for Allreduce
+#if __cplusplus >= 201103L
+auto lazy_get_summary = [&]()
+#endif
+{ // get summary
 thread_sketch.resize(this->get_nthread());
 // number of rows in
 const size_t nrows = p_fmat->buffered_rowset().size();
@@ -383,19 +407,20 @@ class CQHistMaker: public HistMaker<TStats> {
 }
 }
 }
-// setup maximum size
-unsigned max_size = this->param.max_sketch_size();
-// synchronize sketch
-summary_array.resize(sketchs.size());
 for (size_t i = 0; i < sketchs.size(); ++i) {
 utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
 sketchs[i].GetSummary(&out);
-summary_array[i].Reserve(max_size);
 summary_array[i].SetPrune(out, max_size);
 }
+utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch");
+};
 if (summary_array.size() != 0) {
 size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
+#if __cplusplus >= 201103L
+sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary);
+#else
 sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
+#endif
 }
 // now we get the final result of sketch, setup the cut
 this->wspace.cut.clear();
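
Two things matter in the hunks above: the summary buffers are now allocated and Reserve()d before the Allreduce call rather than inside the closure, and the closure is only needed when the reducer cannot restore a previous result. The toy reducer below sketches that contract under stated assumptions (single process, a fake checkpoint); it is not rabit's real Allreduce signature:

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Toy reducer: skips the expensive prepare step when a prior result can be
    // restored (as happens during failure recovery). Not the real reducer API.
    class ToyReducer {
     public:
      void Allreduce(double *buf, size_t n, std::function<void()> prepare) {
        if (has_checkpoint_) {
          // Recovery path: restore the saved result; prepare() never runs,
          // which is why buf must already have its final size.
          std::copy(saved_.begin(), saved_.end(), buf);
          return;
        }
        prepare();                        // fill buf with local statistics
        saved_.assign(buf, buf + n);      // stand-in for the reduced result
        has_checkpoint_ = true;
      }
     private:
      bool has_checkpoint_ = false;
      std::vector<double> saved_;
    };

    int main() {
      std::vector<double> summary(4);     // sized up front, outside the lambda
      ToyReducer sreducer;
      auto lazy_get_summary = [&]() {
        for (size_t i = 0; i < summary.size(); ++i) summary[i] = double(i) * 0.5;
      };
      sreducer.Allreduce(summary.data(), summary.size(), lazy_get_summary);  // prepares
      sreducer.Allreduce(summary.data(), summary.size(), lazy_get_summary);  // restores
      std::printf("summary[3] = %g\n", summary[3]);
      return 0;
    }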
@@ -623,6 +648,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
 summary_array[i].Reserve(max_size);
 summary_array[i].SetPrune(out, max_size);
 }
 size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
 sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
 // now we get the final result of sketch, setup the cut

@@ -52,6 +52,12 @@ class TreeRefresher: public IUpdater {
 std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
 fvec_temp[tid].Init(trees[0]->param.num_feature);
 }
+// if it is C++11, use lazy evaluation for Allreduce,
+// to gain speedup in recovery
+#if __cplusplus >= 201103L
+auto lazy_get_stats = [&]()
+#endif
+{
 // start accumulating statistics
 utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
 iter->BeforeFirst();
@@ -84,8 +90,12 @@ class TreeRefresher: public IUpdater {
 stemp[0][nid].Add(stemp[tid][nid]);
 }
 }
-// AllReduce, add statistics up
+};
+#if __cplusplus >= 201103L
+reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
+#else
 reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
+#endif
 // rescale learning rate according to size of trees
 float lr = param.learning_rate;
 param.learning_rate = lr / trees.size();
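
All three call sites use `[&]` capture, so the closure writes into buffers that are declared before it and stay alive across the (possibly deferred) call, which is why stemp, wspace.hset, and summary_array are all set up ahead of the lambda definitions. A small illustration of that lifetime rule (illustrative only, not xgboost code):

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<double> stemp0(4, 0.0);      // destination declared first
      auto lazy_get_stats = [&]() {            // captures stemp0 by reference
        for (size_t i = 0; i < stemp0.size(); ++i) stemp0[i] += 2.0;
      };
      // The reducer may or may not invoke the closure; here we invoke it once:
      lazy_get_stats();
      std::printf("stemp0[0] = %g\n", stemp0[0]);
      return 0;
    }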