change hist update to lazy
This commit is contained in:
parent
deb21351b9
commit
7a35e1a906
6
Makefile
6
Makefile
@ -10,6 +10,12 @@ else
|
||||
CFLAGS += -fopenmp
|
||||
endif
|
||||
|
||||
# by default use c++11
|
||||
ifeq ($(no_cxx11),1)
|
||||
else
|
||||
CFLAGS += -std=c++11
|
||||
endif
|
||||
|
||||
# specify tensor path
|
||||
BIN = xgboost
|
||||
OBJ = updater.o gbm.o io.o main.o
|
||||
|
||||
@ -306,6 +306,12 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
// start to work
|
||||
this->wspace.Init(this->param, 1);
|
||||
// if it is C++11, use lazy evaluation for Allreduce,
|
||||
// to gain speedup in recovery
|
||||
#if __cplusplus >= 201103L
|
||||
auto lazy_get_hist = [&]()
|
||||
#endif
|
||||
{
|
||||
thread_hist.resize(this->get_nthread());
|
||||
// start accumulating statistics
|
||||
utils::IIterator<ColBatch> *iter = p_fmat->ColIterator(fset);
|
||||
@ -330,8 +336,15 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
this->wspace.hset[0][fset.size() + wid * (fset.size()+1)]
|
||||
.data[0] = node_stats[nid];
|
||||
}
|
||||
};
|
||||
// sync the histogram
|
||||
// if it is C++11, use lazy evaluation for Allreduce
|
||||
#if __cplusplus >= 201103L
|
||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data),
|
||||
this->wspace.hset[0].data.size(), lazy_get_hist);
|
||||
#else
|
||||
this->histred.Allreduce(BeginPtr(this->wspace.hset[0].data), this->wspace.hset[0].data.size());
|
||||
#endif
|
||||
}
|
||||
virtual void ResetPositionAfterSplit(IFMatrix *p_fmat,
|
||||
const RegTree &tree) {
|
||||
@ -354,13 +367,24 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
feat2workindex[fset[i]] = -2;
|
||||
}
|
||||
}
|
||||
|
||||
this->GetNodeStats(gpair, *p_fmat, tree, info,
|
||||
&thread_stats, &node_stats);
|
||||
sketchs.resize(this->qexpand.size() * freal_set.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
sketchs[i].Init(info.num_row, this->param.sketch_eps);
|
||||
}
|
||||
// intitialize the summary array
|
||||
summary_array.resize(sketchs.size());
|
||||
// setup maximum size
|
||||
unsigned max_size = this->param.max_sketch_size();
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
summary_array[i].Reserve(max_size);
|
||||
}
|
||||
// if it is C++11, use lazy evaluation for Allreduce
|
||||
#if __cplusplus >= 201103L
|
||||
auto lazy_get_summary = [&]()
|
||||
#endif
|
||||
{// get smmary
|
||||
thread_sketch.resize(this->get_nthread());
|
||||
// number of rows in
|
||||
const size_t nrows = p_fmat->buffered_rowset().size();
|
||||
@ -383,19 +407,20 @@ class CQHistMaker: public HistMaker<TStats> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// setup maximum size
|
||||
unsigned max_size = this->param.max_sketch_size();
|
||||
// synchronize sketch
|
||||
summary_array.resize(sketchs.size());
|
||||
for (size_t i = 0; i < sketchs.size(); ++i) {
|
||||
utils::WXQuantileSketch<bst_float, bst_float>::SummaryContainer out;
|
||||
sketchs[i].GetSummary(&out);
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
utils::Assert(summary_array.size() == sketchs.size(), "shape mismatch");
|
||||
};
|
||||
if (summary_array.size() != 0) {
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
#if __cplusplus >= 201103L
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size(), lazy_get_summary);
|
||||
#else
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
#endif
|
||||
}
|
||||
// now we get the final result of sketch, setup the cut
|
||||
this->wspace.cut.clear();
|
||||
@ -623,6 +648,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
|
||||
summary_array[i].Reserve(max_size);
|
||||
summary_array[i].SetPrune(out, max_size);
|
||||
}
|
||||
|
||||
size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size);
|
||||
sreducer.Allreduce(BeginPtr(summary_array), nbytes, summary_array.size());
|
||||
// now we get the final result of sketch, setup the cut
|
||||
|
||||
@ -52,6 +52,12 @@ class TreeRefresher: public IUpdater {
|
||||
std::fill(stemp[tid].begin(), stemp[tid].end(), TStats(param));
|
||||
fvec_temp[tid].Init(trees[0]->param.num_feature);
|
||||
}
|
||||
// if it is C++11, use lazy evaluation for Allreduce,
|
||||
// to gain speedup in recovery
|
||||
#if __cplusplus >= 201103L
|
||||
auto lazy_get_stats = [&]()
|
||||
#endif
|
||||
{
|
||||
// start accumulating statistics
|
||||
utils::IIterator<RowBatch> *iter = p_fmat->RowIterator();
|
||||
iter->BeforeFirst();
|
||||
@ -84,8 +90,12 @@ class TreeRefresher: public IUpdater {
|
||||
stemp[0][nid].Add(stemp[tid][nid]);
|
||||
}
|
||||
}
|
||||
// AllReduce, add statistics up
|
||||
};
|
||||
#if __cplusplus >= 201103L
|
||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
|
||||
#else
|
||||
reducer.Allreduce(BeginPtr(stemp[0]), stemp[0].size());
|
||||
#endif
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param.learning_rate;
|
||||
param.learning_rate = lr / trees.size();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user